diff options
| author | Grail Finder <wohilas@gmail.com> | 2026-03-05 20:02:46 +0300 |
|---|---|---|
| committer | Grail Finder <wohilas@gmail.com> | 2026-03-05 20:02:46 +0300 |
| commit | efc92d884c36498220e2b8d5ad9e02f84e42d953 (patch) | |
| tree | 6361de7107d077e39c1aeb312013041d40f73167 /rag/embedder.go | |
| parent | ac8c8bb0558a00cf0d025ab8522aaa57b8cba7de (diff) | |
Chore: onnx library lookup
Diffstat (limited to 'rag/embedder.go')
| -rw-r--r-- | rag/embedder.go | 113 |
1 file changed, 89 insertions, 24 deletions
diff --git a/rag/embedder.go b/rag/embedder.go index b0a3226..59dbfd2 100644 --- a/rag/embedder.go +++ b/rag/embedder.go @@ -9,6 +9,7 @@ import ( "gf-lt/models" "log/slog" "net/http" + "os" "sync" "github.com/sugarme/tokenizer" @@ -143,47 +144,111 @@ func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) { // 2. Using a Go ONNX runtime (like gorgonia/onnx or similar) // 3. Converting text to embeddings without external API calls type ONNXEmbedder struct { - session *onnxruntime_go.DynamicAdvancedSession - tokenizer *tokenizer.Tokenizer - dims int // embedding dimension (e.g., 768) - logger *slog.Logger + session *onnxruntime_go.DynamicAdvancedSession + tokenizer *tokenizer.Tokenizer + tokenizerPath string + dims int + logger *slog.Logger + mu sync.Mutex + modelPath string } var onnxInitOnce sync.Once +var onnxReady bool +var onnxLibPath string + +var onnxLibPaths = []string{ + "/usr/lib/libonnxruntime.so", + "/usr/local/lib/libonnxruntime.so", + "/usr/lib/x86_64-linux-gnu/libonnxruntime.so", + "/opt/onnxruntime/lib/libonnxruntime.so", +} + +func findONNXLibrary() string { + for _, path := range onnxLibPaths { + if _, err := os.Stat(path); err == nil { + return path + } + } + return "" +} func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) { - // Initialize ONNX runtime environment once - onnxInitOnce.Do(func() { - onnxruntime_go.SetSharedLibraryPath("/usr/local/lib/libonnxruntime.so") - err := onnxruntime_go.InitializeEnvironment() + // Check if model and tokenizer files exist + if _, err := os.Stat(modelPath); err != nil { + return nil, fmt.Errorf("ONNX model not found: %w", err) + } + if _, err := os.Stat(tokenizerPath); err != nil { + return nil, fmt.Errorf("tokenizer not found: %w", err) + } + + // Find ONNX library + onnxLibPath = findONNXLibrary() + if onnxLibPath == "" { + return nil, errors.New("ONNX runtime library not found in standard locations") + } + + emb := &ONNXEmbedder{ + 
tokenizerPath: tokenizerPath, + dims: dims, + logger: logger, + modelPath: modelPath, + } + return emb, nil +} + +func (e *ONNXEmbedder) ensureInitialized() error { + if e.session != nil { + return nil + } + e.mu.Lock() + defer e.mu.Unlock() + if e.session != nil { + return nil + } + + // Load tokenizer lazily + if e.tokenizer == nil { + tok, err := pretrained.FromFile(e.tokenizerPath) if err != nil { - logger.Error("failed to initialize ONNX runtime", "error", err) + return fmt.Errorf("failed to load tokenizer: %w", err) + } + e.tokenizer = tok + } + + onnxInitOnce.Do(func() { + onnxruntime_go.SetSharedLibraryPath(onnxLibPath) + if err := onnxruntime_go.InitializeEnvironment(); err != nil { + e.logger.Error("failed to initialize ONNX runtime", "error", err) + onnxReady = false + return } + onnxReady = true }) - // Load tokenizer using sugarme/tokenizer - tok, err := pretrained.FromFile(tokenizerPath) - if err != nil { - return nil, fmt.Errorf("failed to load tokenizer: %w", err) + if !onnxReady { + return errors.New("ONNX runtime not ready") } - // Create ONNX session session, err := onnxruntime_go.NewDynamicAdvancedSession( - modelPath, // onnx/embedgemma/model_q4.onnx + e.getModelPath(), []string{"input_ids", "attention_mask"}, []string{"sentence_embedding"}, - nil, // optional options + nil, ) if err != nil { - return nil, fmt.Errorf("failed to create ONNX session: %w", err) - } - return &ONNXEmbedder{ - session: session, - tokenizer: tok, - dims: dims, - logger: logger, - }, nil + return fmt.Errorf("failed to create ONNX session: %w", err) + } + e.session = session + return nil +} + +func (e *ONNXEmbedder) getModelPath() string { + return e.modelPath } func (e *ONNXEmbedder) Embed(text string) ([]float32, error) { + if err := e.ensureInitialized(); err != nil { + return nil, err + } // 1. Tokenize encoding, err := e.tokenizer.EncodeSingle(text) if err != nil { |
