summaryrefslogtreecommitdiff
path: root/rag/embedder.go
diff options
context:
space:
mode:
authorGrail Finder <wohilas@gmail.com>2026-03-09 07:07:36 +0300
committerGrail Finder <wohilas@gmail.com>2026-03-09 07:07:36 +0300
commit0e42a6f069ceea40485162c014c04cf718568cfe (patch)
tree583a6a6cb91b315e506990a03fdda1b32d0fe985 /rag/embedder.go
parent2687f38d00ceaa4f61034e3e02b9b59d08efc017 (diff)
parenta1b5f9cdc59938901123650fc0900067ac3447ca (diff)
Merge branch 'master' into feat/agent-flow
Diffstat (limited to 'rag/embedder.go')
-rw-r--r--rag/embedder.go314
1 files changed, 307 insertions, 7 deletions
diff --git a/rag/embedder.go b/rag/embedder.go
index 1d29877..5a4aae0 100644
--- a/rag/embedder.go
+++ b/rag/embedder.go
@@ -9,6 +9,13 @@ import (
"gf-lt/models"
"log/slog"
"net/http"
+ "os"
+ "sync"
+ "time"
+
+ "github.com/sugarme/tokenizer"
+ "github.com/sugarme/tokenizer/pretrained"
+ "github.com/yalue/onnxruntime_go"
)
// Embedder defines the interface for embedding text
@@ -27,8 +34,10 @@ type APIEmbedder struct {
func NewAPIEmbedder(l *slog.Logger, cfg *config.Config) *APIEmbedder {
return &APIEmbedder{
logger: l,
- client: &http.Client{},
- cfg: cfg,
+ client: &http.Client{
+ Timeout: 30 * time.Second,
+ },
+ cfg: cfg,
}
}
@@ -134,11 +143,302 @@ func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
return embeddings, nil
}
-// TODO: ONNXEmbedder implementation would go here
-// This would require:
// 1. Loading ONNX models locally
// 2. Using a Go ONNX runtime (like gorgonia/onnx or similar)
// 3. Converting text to embeddings without external API calls
-//
-// For now, we'll focus on the API implementation which is already working in the current system,
-// and can be extended later when we have ONNX runtime integration
// ONNXEmbedder generates text embeddings locally by running an ONNX model
// through the onnxruntime shared library, tokenizing input with a
// HuggingFace-style tokenizer file. The runtime session is created lazily
// on first use (see ensureInitialized) and can be torn down with Destroy
// to release the model's memory.
type ONNXEmbedder struct {
	session *onnxruntime_go.DynamicAdvancedSession // nil until ensureInitialized succeeds
	tokenizer *tokenizer.Tokenizer // loaded lazily from tokenizerPath
	tokenizerPath string // path to the tokenizer definition file
	dims int // embedding dimensionality expected from the model output
	logger *slog.Logger
	mu sync.Mutex // guards session/tokenizer creation and teardown
	modelPath string // path to the .onnx model file
}
+
// Process-wide ONNX runtime state. The onnxruntime environment can be
// initialized at most once per process, so library discovery results and
// readiness are shared by every ONNXEmbedder instance.
var onnxInitOnce sync.Once // one-time runtime environment initialization
var onnxReady bool // set inside onnxInitOnce.Do; true when InitializeEnvironment succeeded
var onnxLibPath string // resolved libonnxruntime path, set by NewONNXEmbedder
var cudaLibPath string // resolved CUDA provider path ("" means CPU only), set by NewONNXEmbedder

// Candidate install locations probed for the ONNX runtime shared library.
var onnxLibPaths = []string{
	"/usr/lib/libonnxruntime.so",
	"/usr/lib/libonnxruntime.so.1.24.2",
	"/usr/local/lib/libonnxruntime.so",
	"/usr/lib/x86_64-linux-gnu/libonnxruntime.so",
	"/opt/onnxruntime/lib/libonnxruntime.so",
}

// Candidate install locations probed for the optional CUDA execution
// provider library.
var cudaLibPaths = []string{
	"/usr/lib/libonnxruntime_providers_cuda.so",
	"/usr/local/lib/libonnxruntime_providers_cuda.so",
	"/opt/onnxruntime/lib/libonnxruntime_providers_cuda.so",
}
+
+func findONNXLibrary() string {
+ for _, path := range onnxLibPaths {
+ if _, err := os.Stat(path); err == nil {
+ return path
+ }
+ }
+ return ""
+}
+
+func findCUDALibrary() string {
+ for _, path := range cudaLibPaths {
+ if _, err := os.Stat(path); err == nil {
+ return path
+ }
+ }
+ return ""
+}
+
+func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) {
+ // Check if model and tokenizer files exist
+ if _, err := os.Stat(modelPath); err != nil {
+ return nil, fmt.Errorf("ONNX model not found: %w", err)
+ }
+ if _, err := os.Stat(tokenizerPath); err != nil {
+ return nil, fmt.Errorf("tokenizer not found: %w", err)
+ }
+
+ // Find ONNX library
+ onnxLibPath = findONNXLibrary()
+ if onnxLibPath == "" {
+ return nil, errors.New("ONNX runtime library not found in standard locations")
+ }
+
+ // Find CUDA provider library (optional)
+ cudaLibPath = findCUDALibrary()
+ if cudaLibPath == "" {
+ fmt.Println("WARNING: CUDA provider library not found, will use CPU")
+ }
+ emb := &ONNXEmbedder{
+ tokenizerPath: tokenizerPath,
+ dims: dims,
+ logger: logger,
+ modelPath: modelPath,
+ }
+ return emb, nil
+}
+
+func (e *ONNXEmbedder) ensureInitialized() error {
+ if e.session != nil {
+ return nil
+ }
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if e.session != nil {
+ return nil
+ }
+ // Load tokenizer lazily
+ if e.tokenizer == nil {
+ tok, err := pretrained.FromFile(e.tokenizerPath)
+ if err != nil {
+ return fmt.Errorf("failed to load tokenizer: %w", err)
+ }
+ e.tokenizer = tok
+ }
+ onnxInitOnce.Do(func() {
+ onnxruntime_go.SetSharedLibraryPath(onnxLibPath)
+ if err := onnxruntime_go.InitializeEnvironment(); err != nil {
+ e.logger.Error("failed to initialize ONNX runtime", "error", err)
+ onnxReady = false
+ return
+ }
+ // Register CUDA provider if available
+ if cudaLibPath != "" {
+ if err := onnxruntime_go.RegisterExecutionProviderLibrary("CUDA", cudaLibPath); err != nil {
+ e.logger.Warn("failed to register CUDA provider", "error", err)
+ }
+ }
+ onnxReady = true
+ })
+ if !onnxReady {
+ return errors.New("ONNX runtime not ready")
+ }
+ // Create session options
+ opts, err := onnxruntime_go.NewSessionOptions()
+ if err != nil {
+ return fmt.Errorf("failed to create session options: %w", err)
+ }
+ defer func() {
+ _ = opts.Destroy()
+ }()
+
+ // Try to add CUDA provider
+ useCUDA := cudaLibPath != ""
+ if useCUDA {
+ cudaOpts, err := onnxruntime_go.NewCUDAProviderOptions()
+ if err != nil {
+ e.logger.Warn("failed to create CUDA provider options, falling back to CPU", "error", err)
+ useCUDA = false
+ } else {
+ defer func() {
+ _ = cudaOpts.Destroy()
+ }()
+ if err := cudaOpts.Update(map[string]string{"device_id": "0"}); err != nil {
+ e.logger.Warn("failed to update CUDA options, falling back to CPU", "error", err)
+ useCUDA = false
+ } else if err := opts.AppendExecutionProviderCUDA(cudaOpts); err != nil {
+ e.logger.Warn("failed to append CUDA provider, falling back to CPU", "error", err)
+ useCUDA = false
+ }
+ }
+ }
+ if useCUDA {
+ e.logger.Info("Using CUDA for ONNX inference")
+ } else {
+ e.logger.Info("Using CPU for ONNX inference")
+ }
+
+ // Create session with options
+ session, err := onnxruntime_go.NewDynamicAdvancedSession(
+ e.getModelPath(),
+ []string{"input_ids", "attention_mask"},
+ []string{"sentence_embedding"},
+ opts,
+ )
+ if err != nil {
+ return fmt.Errorf("failed to create ONNX session: %w", err)
+ }
+ e.session = session
+ return nil
+}
+
// getModelPath returns the filesystem path of the ONNX model file this
// embedder was constructed with.
func (e *ONNXEmbedder) getModelPath() string {
	return e.modelPath
}
+
+func (e *ONNXEmbedder) Destroy() error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if e.session != nil {
+ if err := e.session.Destroy(); err != nil {
+ return fmt.Errorf("failed to destroy ONNX session: %w", err)
+ }
+ e.session = nil
+ e.logger.Info("ONNX session destroyed, VRAM freed")
+ }
+ return nil
+}
+
// Embed returns a single embedding vector for text, tokenizing it and
// running one inference pass. The runtime, tokenizer and session are
// initialized lazily on the first call.
// NOTE(review): e.tokenizer and e.session.Run are used without holding
// e.mu — this assumes both are safe for concurrent use; confirm before
// calling Embed from multiple goroutines.
func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
	if err := e.ensureInitialized(); err != nil {
		return nil, err
	}
	// 1. Tokenize
	encoding, err := e.tokenizer.EncodeSingle(text)
	if err != nil {
		return nil, fmt.Errorf("tokenization failed: %w", err)
	}
	// 2. Convert to int64 and create attention mask. Every position is a
	// real token in the single-text case, so the mask is all ones.
	ids := encoding.Ids
	inputIDs := make([]int64, len(ids))
	attentionMask := make([]int64, len(ids))
	for i, id := range ids {
		inputIDs[i] = int64(id)
		attentionMask[i] = 1
	}
	// 3. Create input tensors (shape: [1, seq_len])
	seqLen := int64(len(inputIDs))
	inputIDsTensor, err := onnxruntime_go.NewTensor[int64](
		onnxruntime_go.NewShape(1, seqLen),
		inputIDs,
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create input_ids tensor: %w", err)
	}
	defer func() { _ = inputIDsTensor.Destroy() }()
	maskTensor, err := onnxruntime_go.NewTensor[int64](
		onnxruntime_go.NewShape(1, seqLen),
		attentionMask,
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create attention_mask tensor: %w", err)
	}
	defer func() { _ = maskTensor.Destroy() }()
	// 4. Create output tensor sized to the configured embedding width.
	outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](
		onnxruntime_go.NewShape(1, int64(e.dims)),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create output tensor: %w", err)
	}
	defer func() { _ = outputTensor.Destroy() }()
	// 5. Run inference
	err = e.session.Run(
		[]onnxruntime_go.Value{inputIDsTensor, maskTensor},
		[]onnxruntime_go.Value{outputTensor},
	)
	if err != nil {
		return nil, fmt.Errorf("inference failed: %w", err)
	}
	// 6. Copy output data out of the tensor's buffer before the deferred
	// Destroy above releases it.
	outputData := outputTensor.GetData()
	embedding := make([]float32, len(outputData))
	copy(embedding, outputData)
	return embedding, nil
}
+
+func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) {
+ if err := e.ensureInitialized(); err != nil {
+ return nil, err
+ }
+ encodings := make([]*tokenizer.Encoding, len(texts))
+ maxLen := 0
+ for i, txt := range texts {
+ enc, err := e.tokenizer.EncodeSingle(txt)
+ if err != nil {
+ return nil, err
+ }
+ encodings[i] = enc
+ if l := len(enc.Ids); l > maxLen {
+ maxLen = l
+ }
+ }
+ batchSize := len(texts)
+ inputIDs := make([]int64, batchSize*maxLen)
+ attentionMask := make([]int64, batchSize*maxLen)
+ for i, enc := range encodings {
+ ids := enc.Ids
+ offset := i * maxLen
+ for j, id := range ids {
+ inputIDs[offset+j] = int64(id)
+ attentionMask[offset+j] = 1
+ }
+ // Remaining positions are already zero (padding)
+ }
+ // Create tensors with shape [batchSize, maxLen]
+ inputTensor, _ := onnxruntime_go.NewTensor[int64](
+ onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)),
+ inputIDs,
+ )
+ defer func() { _ = inputTensor.Destroy() }()
+ maskTensor, _ := onnxruntime_go.NewTensor[int64](
+ onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)),
+ attentionMask,
+ )
+ defer func() { _ = maskTensor.Destroy() }()
+ outputTensor, _ := onnxruntime_go.NewEmptyTensor[float32](
+ onnxruntime_go.NewShape(int64(batchSize), int64(e.dims)),
+ )
+ defer func() { _ = outputTensor.Destroy() }()
+ err := e.session.Run(
+ []onnxruntime_go.Value{inputTensor, maskTensor},
+ []onnxruntime_go.Value{outputTensor},
+ )
+ if err != nil {
+ return nil, err
+ }
+ // Extract embeddings per batch item
+ data := outputTensor.GetData()
+ embeddings := make([][]float32, batchSize)
+ for i := 0; i < batchSize; i++ {
+ start := i * e.dims
+ emb := make([]float32, e.dims)
+ copy(emb, data[start:start+e.dims])
+ embeddings[i] = emb
+ }
+ return embeddings, nil
+}