diff options
| author | Grail Finder <wohilas@gmail.com> | 2026-03-09 07:07:36 +0300 |
|---|---|---|
| committer | Grail Finder <wohilas@gmail.com> | 2026-03-09 07:07:36 +0300 |
| commit | 0e42a6f069ceea40485162c014c04cf718568cfe (patch) | |
| tree | 583a6a6cb91b315e506990a03fdda1b32d0fe985 /rag/embedder.go | |
| parent | 2687f38d00ceaa4f61034e3e02b9b59d08efc017 (diff) | |
| parent | a1b5f9cdc59938901123650fc0900067ac3447ca (diff) | |
Merge branch 'master' into feat/agent-flow
Diffstat (limited to 'rag/embedder.go')
| -rw-r--r-- | rag/embedder.go | 314 |
1 file changed, 307 insertions(+), 7 deletions(-)
diff --git a/rag/embedder.go b/rag/embedder.go index 1d29877..5a4aae0 100644 --- a/rag/embedder.go +++ b/rag/embedder.go @@ -9,6 +9,13 @@ import ( "gf-lt/models" "log/slog" "net/http" + "os" + "sync" + "time" + + "github.com/sugarme/tokenizer" + "github.com/sugarme/tokenizer/pretrained" + "github.com/yalue/onnxruntime_go" ) // Embedder defines the interface for embedding text @@ -27,8 +34,10 @@ type APIEmbedder struct { func NewAPIEmbedder(l *slog.Logger, cfg *config.Config) *APIEmbedder { return &APIEmbedder{ logger: l, - client: &http.Client{}, - cfg: cfg, + client: &http.Client{ + Timeout: 30 * time.Second, + }, + cfg: cfg, } } @@ -134,11 +143,302 @@ func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) { return embeddings, nil } -// TODO: ONNXEmbedder implementation would go here -// This would require: // 1. Loading ONNX models locally // 2. Using a Go ONNX runtime (like gorgonia/onnx or similar) // 3. Converting text to embeddings without external API calls -// -// For now, we'll focus on the API implementation which is already working in the current system, -// and can be extended later when we have ONNX runtime integration +type ONNXEmbedder struct { + session *onnxruntime_go.DynamicAdvancedSession + tokenizer *tokenizer.Tokenizer + tokenizerPath string + dims int + logger *slog.Logger + mu sync.Mutex + modelPath string +} + +var onnxInitOnce sync.Once +var onnxReady bool +var onnxLibPath string +var cudaLibPath string + +var onnxLibPaths = []string{ + "/usr/lib/libonnxruntime.so", + "/usr/lib/libonnxruntime.so.1.24.2", + "/usr/local/lib/libonnxruntime.so", + "/usr/lib/x86_64-linux-gnu/libonnxruntime.so", + "/opt/onnxruntime/lib/libonnxruntime.so", +} + +var cudaLibPaths = []string{ + "/usr/lib/libonnxruntime_providers_cuda.so", + "/usr/local/lib/libonnxruntime_providers_cuda.so", + "/opt/onnxruntime/lib/libonnxruntime_providers_cuda.so", +} + +func findONNXLibrary() string { + for _, path := range onnxLibPaths { + if _, err := 
os.Stat(path); err == nil { + return path + } + } + return "" +} + +func findCUDALibrary() string { + for _, path := range cudaLibPaths { + if _, err := os.Stat(path); err == nil { + return path + } + } + return "" +} + +func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) { + // Check if model and tokenizer files exist + if _, err := os.Stat(modelPath); err != nil { + return nil, fmt.Errorf("ONNX model not found: %w", err) + } + if _, err := os.Stat(tokenizerPath); err != nil { + return nil, fmt.Errorf("tokenizer not found: %w", err) + } + + // Find ONNX library + onnxLibPath = findONNXLibrary() + if onnxLibPath == "" { + return nil, errors.New("ONNX runtime library not found in standard locations") + } + + // Find CUDA provider library (optional) + cudaLibPath = findCUDALibrary() + if cudaLibPath == "" { + fmt.Println("WARNING: CUDA provider library not found, will use CPU") + } + emb := &ONNXEmbedder{ + tokenizerPath: tokenizerPath, + dims: dims, + logger: logger, + modelPath: modelPath, + } + return emb, nil +} + +func (e *ONNXEmbedder) ensureInitialized() error { + if e.session != nil { + return nil + } + e.mu.Lock() + defer e.mu.Unlock() + if e.session != nil { + return nil + } + // Load tokenizer lazily + if e.tokenizer == nil { + tok, err := pretrained.FromFile(e.tokenizerPath) + if err != nil { + return fmt.Errorf("failed to load tokenizer: %w", err) + } + e.tokenizer = tok + } + onnxInitOnce.Do(func() { + onnxruntime_go.SetSharedLibraryPath(onnxLibPath) + if err := onnxruntime_go.InitializeEnvironment(); err != nil { + e.logger.Error("failed to initialize ONNX runtime", "error", err) + onnxReady = false + return + } + // Register CUDA provider if available + if cudaLibPath != "" { + if err := onnxruntime_go.RegisterExecutionProviderLibrary("CUDA", cudaLibPath); err != nil { + e.logger.Warn("failed to register CUDA provider", "error", err) + } + } + onnxReady = true + }) + if !onnxReady { + return 
errors.New("ONNX runtime not ready") + } + // Create session options + opts, err := onnxruntime_go.NewSessionOptions() + if err != nil { + return fmt.Errorf("failed to create session options: %w", err) + } + defer func() { + _ = opts.Destroy() + }() + + // Try to add CUDA provider + useCUDA := cudaLibPath != "" + if useCUDA { + cudaOpts, err := onnxruntime_go.NewCUDAProviderOptions() + if err != nil { + e.logger.Warn("failed to create CUDA provider options, falling back to CPU", "error", err) + useCUDA = false + } else { + defer func() { + _ = cudaOpts.Destroy() + }() + if err := cudaOpts.Update(map[string]string{"device_id": "0"}); err != nil { + e.logger.Warn("failed to update CUDA options, falling back to CPU", "error", err) + useCUDA = false + } else if err := opts.AppendExecutionProviderCUDA(cudaOpts); err != nil { + e.logger.Warn("failed to append CUDA provider, falling back to CPU", "error", err) + useCUDA = false + } + } + } + if useCUDA { + e.logger.Info("Using CUDA for ONNX inference") + } else { + e.logger.Info("Using CPU for ONNX inference") + } + + // Create session with options + session, err := onnxruntime_go.NewDynamicAdvancedSession( + e.getModelPath(), + []string{"input_ids", "attention_mask"}, + []string{"sentence_embedding"}, + opts, + ) + if err != nil { + return fmt.Errorf("failed to create ONNX session: %w", err) + } + e.session = session + return nil +} + +func (e *ONNXEmbedder) getModelPath() string { + return e.modelPath +} + +func (e *ONNXEmbedder) Destroy() error { + e.mu.Lock() + defer e.mu.Unlock() + if e.session != nil { + if err := e.session.Destroy(); err != nil { + return fmt.Errorf("failed to destroy ONNX session: %w", err) + } + e.session = nil + e.logger.Info("ONNX session destroyed, VRAM freed") + } + return nil +} + +func (e *ONNXEmbedder) Embed(text string) ([]float32, error) { + if err := e.ensureInitialized(); err != nil { + return nil, err + } + // 1. 
Tokenize + encoding, err := e.tokenizer.EncodeSingle(text) + if err != nil { + return nil, fmt.Errorf("tokenization failed: %w", err) + } + // 2. Convert to int64 and create attention mask + ids := encoding.Ids + inputIDs := make([]int64, len(ids)) + attentionMask := make([]int64, len(ids)) + for i, id := range ids { + inputIDs[i] = int64(id) + attentionMask[i] = 1 + } + // 3. Create input tensors (shape: [1, seq_len]) + seqLen := int64(len(inputIDs)) + inputIDsTensor, err := onnxruntime_go.NewTensor[int64]( + onnxruntime_go.NewShape(1, seqLen), + inputIDs, + ) + if err != nil { + return nil, fmt.Errorf("failed to create input_ids tensor: %w", err) + } + defer func() { _ = inputIDsTensor.Destroy() }() + maskTensor, err := onnxruntime_go.NewTensor[int64]( + onnxruntime_go.NewShape(1, seqLen), + attentionMask, + ) + if err != nil { + return nil, fmt.Errorf("failed to create attention_mask tensor: %w", err) + } + defer func() { _ = maskTensor.Destroy() }() + // 4. Create output tensor + outputTensor, err := onnxruntime_go.NewEmptyTensor[float32]( + onnxruntime_go.NewShape(1, int64(e.dims)), + ) + if err != nil { + return nil, fmt.Errorf("failed to create output tensor: %w", err) + } + defer func() { _ = outputTensor.Destroy() }() + // 5. Run inference + err = e.session.Run( + []onnxruntime_go.Value{inputIDsTensor, maskTensor}, + []onnxruntime_go.Value{outputTensor}, + ) + if err != nil { + return nil, fmt.Errorf("inference failed: %w", err) + } + // 6. 
Copy output data + outputData := outputTensor.GetData() + embedding := make([]float32, len(outputData)) + copy(embedding, outputData) + return embedding, nil +} + +func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) { + if err := e.ensureInitialized(); err != nil { + return nil, err + } + encodings := make([]*tokenizer.Encoding, len(texts)) + maxLen := 0 + for i, txt := range texts { + enc, err := e.tokenizer.EncodeSingle(txt) + if err != nil { + return nil, err + } + encodings[i] = enc + if l := len(enc.Ids); l > maxLen { + maxLen = l + } + } + batchSize := len(texts) + inputIDs := make([]int64, batchSize*maxLen) + attentionMask := make([]int64, batchSize*maxLen) + for i, enc := range encodings { + ids := enc.Ids + offset := i * maxLen + for j, id := range ids { + inputIDs[offset+j] = int64(id) + attentionMask[offset+j] = 1 + } + // Remaining positions are already zero (padding) + } + // Create tensors with shape [batchSize, maxLen] + inputTensor, _ := onnxruntime_go.NewTensor[int64]( + onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)), + inputIDs, + ) + defer func() { _ = inputTensor.Destroy() }() + maskTensor, _ := onnxruntime_go.NewTensor[int64]( + onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)), + attentionMask, + ) + defer func() { _ = maskTensor.Destroy() }() + outputTensor, _ := onnxruntime_go.NewEmptyTensor[float32]( + onnxruntime_go.NewShape(int64(batchSize), int64(e.dims)), + ) + defer func() { _ = outputTensor.Destroy() }() + err := e.session.Run( + []onnxruntime_go.Value{inputTensor, maskTensor}, + []onnxruntime_go.Value{outputTensor}, + ) + if err != nil { + return nil, err + } + // Extract embeddings per batch item + data := outputTensor.GetData() + embeddings := make([][]float32, batchSize) + for i := 0; i < batchSize; i++ { + start := i * e.dims + emb := make([]float32, e.dims) + copy(emb, data[start:start+e.dims]) + embeddings[i] = emb + } + return embeddings, nil +} |
