summaryrefslogtreecommitdiff
path: root/rag/embedder.go
diff options
context:
space:
mode:
authorGrail Finder <wohilas@gmail.com>2026-03-05 14:13:58 +0300
committerGrail Finder <wohilas@gmail.com>2026-03-05 14:13:58 +0300
commitfbc955ca37836553ef4b7c365b84e3dfa859c501 (patch)
treeb5b5d67f7e7ab43c6c91ea9a3e2519b341b650ba /rag/embedder.go
parentc65c11bcfbc563611743d02039420533bcfe9d05 (diff)
Enha: local onnx
Diffstat (limited to 'rag/embedder.go')
-rw-r--r--rag/embedder.go65
1 file changed, 60 insertions, 5 deletions
diff --git a/rag/embedder.go b/rag/embedder.go
index 1d29877..386d508 100644
--- a/rag/embedder.go
+++ b/rag/embedder.go
@@ -9,6 +9,10 @@ import (
"gf-lt/models"
"log/slog"
"net/http"
+
+ "github.com/takara-ai/go-tokenizers/tokenizers"
+
+ "github.com/yalue/onnxruntime_go"
)
// Embedder defines the interface for embedding text
@@ -134,11 +138,62 @@ func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
return embeddings, nil
}
-// TODO: ONNXEmbedder implementation would go here
-// This would require:
// 1. Loading ONNX models locally
// 2. Using a Go ONNX runtime (like gorgonia/onnx or similar)
// 3. Converting text to embeddings without external API calls
-//
-// For now, we'll focus on the API implementation which is already working in the current system,
-// and can be extended later when we have ONNX runtime integration
+
+type ONNXEmbedder struct {
+ session *onnxruntime_go.DynamicAdvancedSession
+ tokenizer *tokenizers.Tokenizer
+ dims int // 768, 512, 256, or 128 for Matryoshka
+}
+
+func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) {
+ // Batch processing
+ inputs := e.prepareBatch(texts)
+ outputs := make([][]float32, len(texts))
+
+ // Run batch inference (much faster)
+ err := e.session.Run(inputs, outputs)
+ return outputs, err
+}
+
+func NewONNXEmbedder(modelPath string) (*ONNXEmbedder, error) {
+ // Load ONNX model
+ session, err := onnxruntime_go.NewDynamicAdvancedSession(
+ modelPath, // onnx/embedgemma/model_q4.onnx
+ []string{"input_ids", "attention_mask"},
+ []string{"sentence_embedding"},
+ nil,
+ )
+ if err != nil {
+ return nil, err
+ }
+ // Load tokenizer (from Hugging Face)
+ tokenizer, err := tokenizers.FromFile("./tokenizer.json")
+ return &ONNXEmbedder{
+ session: session,
+ tokenizer: tokenizer,
+ }, nil
+}
+
+func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
+ // Tokenize
+ tokens := e.tokenizer.Encode(text, true)
+ // Prepare inputs
+ inputIDs := []int64{tokens.GetIds()}
+ attentionMask := []int64{tokens.GetAttentionMask()}
+ // Run inference
+ output := onnxruntime_go.NewEmptyTensor[float32](
+ onnxruntime_go.NewShape(1, 768),
+ )
+ err := e.session.Run(
+ map[string]any{
+ "input_ids": inputIDs,
+ "attention_mask": attentionMask,
+ },
+ []string{"sentence_embedding"},
+ []any{&output},
+ )
+ return output.GetData(), nil
+}