summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGrail Finder <wohilas@gmail.com>2026-03-06 09:32:45 +0300
committerGrail Finder <wohilas@gmail.com>2026-03-06 09:32:45 +0300
commit4ef0a215119924347c2219f4677f11a96358307f (patch)
treea04b9701ff5e6d68c67ce76688b29219bbec74a1
parentd2caebdb4fd3ad148aad20866503b7d46d546404 (diff)
Enha (onnx): unload model if noop for 30s
-rw-r--r--rag/embedder.go13
-rw-r--r--rag/rag.go43
2 files changed, 56 insertions, 0 deletions
diff --git a/rag/embedder.go b/rag/embedder.go
index 13f6a6e..39f4b5c 100644
--- a/rag/embedder.go
+++ b/rag/embedder.go
@@ -308,6 +308,19 @@ func (e *ONNXEmbedder) getModelPath() string {
return e.modelPath
}
+func (e *ONNXEmbedder) Destroy() error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if e.session != nil {
+ if err := e.session.Destroy(); err != nil {
+ return fmt.Errorf("failed to destroy ONNX session: %w", err)
+ }
+ e.session = nil
+ e.logger.Info("ONNX session destroyed, VRAM freed")
+ }
+ return nil
+}
+
func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
if err := e.ensureInitialized(); err != nil {
return nil, err
diff --git a/rag/rag.go b/rag/rag.go
index fa30303..d64a3e1 100644
--- a/rag/rag.go
+++ b/rag/rag.go
@@ -12,6 +12,7 @@ import (
"sort"
"strings"
"sync"
+ "time"
"github.com/neurosnap/sentences/english"
)
@@ -32,6 +33,8 @@ type RAG struct {
storage *VectorStorage
mu sync.Mutex
fallbackMsg string
+ idleTimer *time.Timer
+ idleTimeout time.Duration
}
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) {
@@ -58,6 +61,7 @@ func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) {
embedder: embedder,
storage: NewVectorStorage(l, s),
fallbackMsg: fallbackMsg,
+ idleTimeout: 30 * time.Second,
}
// Note: Vector tables are created via database migrations, not at runtime
@@ -187,6 +191,7 @@ func (r *RAG) LoadRAG(fpath string) error {
}
}
r.logger.Debug("finished writing vectors", "batches", batchCount)
+ r.resetIdleTimer()
select {
case LongJobStatusCh <- FinishedRAGStatus:
default:
@@ -196,10 +201,12 @@ func (r *RAG) LoadRAG(fpath string) error {
}
func (r *RAG) LineToVector(line string) ([]float32, error) {
+ r.resetIdleTimer()
return r.embedder.Embed(line)
}
func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) {
+ r.resetIdleTimer()
return r.storage.SearchClosest(emb.Embedding)
}
@@ -208,6 +215,7 @@ func (r *RAG) ListLoaded() ([]string, error) {
}
func (r *RAG) RemoveFile(filename string) error {
+ r.resetIdleTimer()
return r.storage.RemoveEmbByFileName(filename)
}
@@ -471,3 +479,38 @@ func Init(c *config.Config, l *slog.Logger, s storage.FullRepo) error {
func GetInstance() *RAG {
return ragInstance
}
+
+func (r *RAG) resetIdleTimer() {
+ if r.idleTimer != nil {
+ r.idleTimer.Stop()
+ }
+ r.idleTimer = time.AfterFunc(r.idleTimeout, func() {
+ r.freeONNXMemory()
+ })
+}
+
+func (r *RAG) freeONNXMemory() {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ if onnx, ok := r.embedder.(*ONNXEmbedder); ok {
+ if err := onnx.Destroy(); err != nil {
+ r.logger.Error("failed to free ONNX memory", "error", err)
+ } else {
+ r.logger.Info("freed ONNX VRAM after idle timeout")
+ }
+ }
+}
+
+func (r *RAG) Destroy() {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ if r.idleTimer != nil {
+ r.idleTimer.Stop()
+ r.idleTimer = nil
+ }
+ if onnx, ok := r.embedder.(*ONNXEmbedder); ok {
+ if err := onnx.Destroy(); err != nil {
+ r.logger.Error("failed to destroy ONNX embedder", "error", err)
+ }
+ }
+}