From fbc955ca37836553ef4b7c365b84e3dfa859c501 Mon Sep 17 00:00:00 2001
From: Grail Finder
Date: Thu, 5 Mar 2026 14:13:58 +0300
Subject: Enha: local onnx

---
 rag/embedder.go | 82 ++++++++++++++++++++++++++++++++++++++++++++++++++++++---
 rag/rag.go      |  2 +-
 2 files changed, 78 insertions(+), 6 deletions(-)

(limited to 'rag')

diff --git a/rag/embedder.go b/rag/embedder.go
index 1d29877..386d508 100644
--- a/rag/embedder.go
+++ b/rag/embedder.go
@@ -9,6 +9,10 @@ import (
 	"gf-lt/models"
 	"log/slog"
 	"net/http"
+
+	"github.com/takara-ai/go-tokenizers/tokenizers"
+
+	"github.com/yalue/onnxruntime_go"
 )
 
 // Embedder defines the interface for embedding text
@@ -134,11 +138,79 @@ func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
 	return embeddings, nil
 }
 
-// TODO: ONNXEmbedder implementation would go here
-// This would require:
 // 1. Loading ONNX models locally
 // 2. Using a Go ONNX runtime (like gorgonia/onnx or similar)
 // 3. Converting text to embeddings without external API calls
-//
-// For now, we'll focus on the API implementation which is already working in the current system,
-// and can be extended later when we have ONNX runtime integration
+
+// ONNXEmbedder embeds text locally through an ONNX Runtime session,
+// without calling an external embedding API.
+type ONNXEmbedder struct {
+	session   *onnxruntime_go.DynamicAdvancedSession
+	tokenizer *tokenizers.Tokenizer
+	dims      int // 768, 512, 256, or 128 for Matryoshka
+}
+
+// EmbedSlice embeds all texts with a single batched inference call.
+func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) {
+	// Batch processing
+	inputs := e.prepareBatch(texts)
+	outputs := make([][]float32, len(texts))
+
+	// Run batch inference (much faster)
+	if err := e.session.Run(inputs, outputs); err != nil {
+		// Do not hand back partially filled embeddings on failure.
+		return nil, err
+	}
+	return outputs, nil
+}
+
+// NewONNXEmbedder loads the ONNX model at modelPath and the tokenizer
+// definition from ./tokenizer.json.
+func NewONNXEmbedder(modelPath string) (*ONNXEmbedder, error) {
+	// Load ONNX model
+	session, err := onnxruntime_go.NewDynamicAdvancedSession(
+		modelPath, // onnx/embedgemma/model_q4.onnx
+		[]string{"input_ids", "attention_mask"},
+		[]string{"sentence_embedding"},
+		nil,
+	)
+	if err != nil {
+		return nil, err
+	}
+	// Load tokenizer (from Hugging Face)
+	tokenizer, err := tokenizers.FromFile("./tokenizer.json")
+	if err != nil {
+		// Release the native session so a tokenizer failure does not leak it.
+		_ = session.Destroy()
+		return nil, err
+	}
+	return &ONNXEmbedder{
+		session:   session,
+		tokenizer: tokenizer,
+	}, nil
+}
+
+// Embed returns the sentence embedding for a single text.
+func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
+	// Tokenize
+	tokens := e.tokenizer.Encode(text, true)
+	// Prepare inputs
+	inputIDs := []int64{tokens.GetIds()}
+	attentionMask := []int64{tokens.GetAttentionMask()}
+	// Run inference
+	output := onnxruntime_go.NewEmptyTensor[float32](
+		onnxruntime_go.NewShape(1, 768),
+	)
+	err := e.session.Run(
+		map[string]any{
+			"input_ids":      inputIDs,
+			"attention_mask": attentionMask,
+		},
+		[]string{"sentence_embedding"},
+		[]any{&output},
+	)
+	if err != nil {
+		return nil, err
+	}
+	return output.GetData(), nil
+}
diff --git a/rag/rag.go b/rag/rag.go
index b63cb08..3d0f38f 100644
--- a/rag/rag.go
+++ b/rag/rag.go
@@ -246,7 +246,7 @@ func (r *RAG) extractImportantPhrases(query string) string {
 				break
 			}
 		}
-		if isImportant || len(word) > 3 {
+		if isImportant || len(word) >= 3 {
 			important = append(important, word)
 		}
 	}
-- 
cgit v1.2.3