From fbc955ca37836553ef4b7c365b84e3dfa859c501 Mon Sep 17 00:00:00 2001 From: Grail Finder Date: Thu, 5 Mar 2026 14:13:58 +0300 Subject: Enha: local onnx --- rag/embedder.go | 65 ++++++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 60 insertions(+), 5 deletions(-) (limited to 'rag/embedder.go') diff --git a/rag/embedder.go b/rag/embedder.go index 1d29877..386d508 100644 --- a/rag/embedder.go +++ b/rag/embedder.go @@ -9,6 +9,10 @@ import ( "gf-lt/models" "log/slog" "net/http" + + "github.com/takara-ai/go-tokenizers/tokenizers" + + "github.com/yalue/onnxruntime_go" ) // Embedder defines the interface for embedding text @@ -134,11 +138,62 @@ func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) { return embeddings, nil } -// TODO: ONNXEmbedder implementation would go here -// This would require: // 1. Loading ONNX models locally // 2. Using a Go ONNX runtime (like gorgonia/onnx or similar) // 3. Converting text to embeddings without external API calls -// -// For now, we'll focus on the API implementation which is already working in the current system, -// and can be extended later when we have ONNX runtime integration + +type ONNXEmbedder struct { + session *onnxruntime_go.DynamicAdvancedSession + tokenizer *tokenizers.Tokenizer + dims int // 768, 512, 256, or 128 for Matryoshka +} + +func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) { + // Batch processing + inputs := e.prepareBatch(texts) + outputs := make([][]float32, len(texts)) + + // Run batch inference (much faster) + err := e.session.Run(inputs, outputs) + return outputs, err +} + +func NewONNXEmbedder(modelPath string) (*ONNXEmbedder, error) { + // Load ONNX model + session, err := onnxruntime_go.NewDynamicAdvancedSession( + modelPath, // onnx/embedgemma/model_q4.onnx + []string{"input_ids", "attention_mask"}, + []string{"sentence_embedding"}, + nil, + ) + if err != nil { + return nil, err + } + // Load tokenizer (from Hugging Face) + tokenizer, err := tokenizers.FromFile("./tokenizer.json") + return &ONNXEmbedder{ + session: session, + tokenizer: tokenizer, + }, nil +} + +func (e *ONNXEmbedder) Embed(text string) ([]float32, error) { + // Tokenize + tokens := e.tokenizer.Encode(text, true) + // Prepare inputs + inputIDs := []int64{tokens.GetIds()} + attentionMask := []int64{tokens.GetAttentionMask()} + // Run inference + output := onnxruntime_go.NewEmptyTensor[float32]( + onnxruntime_go.NewShape(1, 768), + ) + err := e.session.Run( + map[string]any{ + "input_ids": inputIDs, + "attention_mask": attentionMask, + }, + []string{"sentence_embedding"}, + []any{&output}, + ) + return output.GetData(), nil +} -- cgit v1.2.3