Diffstat (limited to 'rag/embedder.go')
-rw-r--r--  rag/embedder.go | 181
1 file changed, 145 insertions(+), 36 deletions(-)
diff --git a/rag/embedder.go b/rag/embedder.go
index 386d508..396f04b 100644
--- a/rag/embedder.go
+++ b/rag/embedder.go
@@ -10,8 +10,8 @@ import (
"log/slog"
"net/http"
- "github.com/takara-ai/go-tokenizers/tokenizers"
-
+ "github.com/sugarme/tokenizer"
+ "github.com/sugarme/tokenizer/pretrained"
"github.com/yalue/onnxruntime_go"
)
@@ -141,59 +141,168 @@ func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
// 1. Loading ONNX models locally
// 2. Using a Go ONNX runtime (like gorgonia/onnx or similar)
// 3. Converting text to embeddings without external API calls
-
type ONNXEmbedder struct {
session *onnxruntime_go.DynamicAdvancedSession
- tokenizer *tokenizers.Tokenizer
- dims int // 768, 512, 256, or 128 for Matryoshka
+ tokenizer *tokenizer.Tokenizer
+ dims int // embedding dimension (e.g., 768)
+ logger *slog.Logger
}
-func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) {
- // Batch processing
- inputs := e.prepareBatch(texts)
- outputs := make([][]float32, len(texts))
-
- // Run batch inference (much faster)
- err := e.session.Run(inputs, outputs)
- return outputs, err
-}
-
-func NewONNXEmbedder(modelPath string) (*ONNXEmbedder, error) {
- // Load ONNX model
+func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) {
+ // Load tokenizer using sugarme/tokenizer
+ tok, err := pretrained.FromFile(tokenizerPath)
+ if err != nil {
+ return nil, fmt.Errorf("failed to load tokenizer: %w", err)
+ }
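+ // NOTE: assumes onnxruntime_go.InitializeEnvironment() (and, on most
+ // platforms, SetSharedLibraryPath) was already called at startup; sessions
+ // cannot be created before the runtime is initialized.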
+ // Create ONNX session
session, err := onnxruntime_go.NewDynamicAdvancedSession(
modelPath, // onnx/embedgemma/model_q4.onnx
[]string{"input_ids", "attention_mask"},
[]string{"sentence_embedding"},
- nil,
+ nil, // no custom SessionOptions
)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("failed to create ONNX session: %w", err)
}
- // Load tokenizer (from Hugging Face)
- tokenizer, err := tokenizers.FromFile("./tokenizer.json")
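+ // NOTE: the session holds native resources; callers should eventually
+ // release them via session.Destroy() (e.g., from a Close method).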
return &ONNXEmbedder{
session: session,
- tokenizer: tokenizer,
+ tokenizer: tok,
+ dims: dims,
+ logger: logger,
}, nil
}
func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
- // Tokenize
- tokens := e.tokenizer.Encode(text, true)
- // Prepare inputs
- inputIDs := []int64{tokens.GetIds()}
- attentionMask := []int64{tokens.GetAttentionMask()}
- // Run inference
- output := onnxruntime_go.NewEmptyTensor[float32](
- onnxruntime_go.NewShape(1, 768),
- )
+ // 1. Tokenize (EncodeSingle takes a plain string; true adds special tokens)
+ encoding, err := e.tokenizer.EncodeSingle(text, true)
+ if err != nil {
+ return nil, fmt.Errorf("tokenization failed: %w", err)
+ }
+ // Convert the tokenizer's []int ids to the []int64 the ONNX model expects
+ inputIDs := make([]int64, len(encoding.GetIds()))
+ for i, id := range encoding.GetIds() {
+ inputIDs[i] = int64(id)
+ }
+ attentionMask := make([]int64, len(encoding.GetAttentionMask()))
+ for i, m := range encoding.GetAttentionMask() {
+ attentionMask[i] = int64(m)
+ }
+ // 2. Create input tensors (shape: [1, seq_len])
+ seqLen := int64(len(inputIDs))
+ inputIDsTensor, err := onnxruntime_go.NewTensor(onnxruntime_go.NewShape(1, seqLen), inputIDs)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create input_ids tensor: %w", err)
+ }
+ defer inputIDsTensor.Destroy()
+ maskTensor, err := onnxruntime_go.NewTensor(onnxruntime_go.NewShape(1, seqLen), attentionMask)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create attention_mask tensor: %w", err)
+ }
+ defer maskTensor.Destroy()
+ // 3. Create output tensor (shape: [1, dims])
+ outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](onnxruntime_go.NewShape(1, int64(e.dims)))
+ if err != nil {
+ return nil, fmt.Errorf("failed to create output tensor: %w", err)
+ }
+ defer outputTensor.Destroy()
- err := e.session.Run(
- map[string]any{
- "input_ids": inputIDs,
- "attention_mask": attentionMask,
- },
+ // 4. Run inference. DynamicAdvancedSession.Run takes input and output
+ // tensor slices, matched by position to the names passed to
+ // NewDynamicAdvancedSession.
+ err = e.session.Run(
+ []onnxruntime_go.Value{inputIDsTensor, maskTensor},
+ []onnxruntime_go.Value{outputTensor},
+ )
+ if err != nil {
+ return nil, fmt.Errorf("inference failed: %w", err)
+ }
+ // 5. Extract data
+ outputData := outputTensor.GetData()
+ // GetData returns a slice backed by the tensor's native memory, which the
+ // deferred Destroy frees, so copy it before returning.
+ embedding := make([]float32, len(outputData))
+ copy(embedding, outputData)
+ return embedding, nil
+}
+
+// EmbedSlice embeds a batch of texts in a single padded inference call.
+func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) {
+ if len(texts) == 0 {
+ return nil, nil
+ }
+ // 1. Tokenize all texts and find max length for padding
+ encodings := make([]*tokenizer.Encoding, len(texts))
+ maxLen := 0
+ for i, txt := range texts {
+ enc, err := e.tokenizer.EncodeSingle(txt, true)
+ if err != nil {
+ return nil, fmt.Errorf("tokenization failed at index %d: %w", i, err)
+ }
+ encodings[i] = enc
+ if l := len(enc.GetIds()); l > maxLen {
+ maxLen = l
+ }
+ }
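+ // NOTE: no truncation is applied here; inputs longer than the model's
+ // maximum sequence length are passed through as-is.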
+ // 2. Build padded input_ids and attention_mask (shape: [batch, maxLen])
+ batchSize := len(texts)
+ inputIDs := make([]int64, batchSize*maxLen)
+ attentionMask := make([]int64, batchSize*maxLen)
+ for i, enc := range encodings {
+ ids := enc.GetIds()
+ mask := enc.GetAttentionMask()
+ offset := i * maxLen
+ // copy actual tokens
+ for j := 0; j < len(ids); j++ {
+ inputIDs[offset+j] = int64(ids[j])
+ attentionMask[offset+j] = int64(mask[j])
+ }
+ // remaining positions stay zero (assumes the model's pad token id is 0)
+ }
+ // 3. Create tensors
+ inputIDsTensor, err := onnxruntime_go.NewTensor(
+ onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)),
+ inputIDs,
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer inputIDsTensor.Destroy()
+ maskTensor, err := onnxruntime_go.NewTensor(
+ onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)),
+ attentionMask,
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer maskTensor.Destroy()
+ outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](
+ onnxruntime_go.NewShape(int64(batchSize), int64(e.dims)),
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer outputTensor.Destroy()
- []string{"sentence_embedding"},
- []any{&output},
- )
- return output.GetData(), nil
+ // 4. Run batch inference; tensors are matched by position to the input
+ // and output names given at session creation
+ err = e.session.Run(
+ []onnxruntime_go.Value{inputIDsTensor, maskTensor},
+ []onnxruntime_go.Value{outputTensor},
+ )
+ if err != nil {
+ return nil, err
+ }
+ // 5. Extract batch results
+ outputData := outputTensor.GetData()
+ embeddings := make([][]float32, batchSize)
+ for i := 0; i < batchSize; i++ {
+ start := i * e.dims
+ emb := make([]float32, e.dims)
+ copy(emb, outputData[start:start+e.dims])
+ embeddings[i] = emb
+ }
+ return embeddings, nil
}
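
For reference, a minimal usage sketch of the new embedder: a hypothetical companion file in package rag, not part of this commit. The shared-library path, tokenizer path, and dimension are illustrative assumptions.

package rag

import (
	"fmt"
	"log/slog"
	"os"

	ort "github.com/yalue/onnxruntime_go"
)

// exampleUsage is a hypothetical driver for the new embedder.
func exampleUsage() error {
	// The ONNX runtime library must be loaded and initialized before any
	// session is created.
	ort.SetSharedLibraryPath("/usr/lib/libonnxruntime.so") // adjust per platform
	if err := ort.InitializeEnvironment(); err != nil {
		return err
	}
	defer ort.DestroyEnvironment()

	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
	emb, err := NewONNXEmbedder(
		"onnx/embedgemma/model_q4.onnx", // model path referenced in the diff
		"tokenizer.json",                // hypothetical tokenizer.json location
		768,                             // matches the model's output dimension
		logger,
	)
	if err != nil {
		return err
	}

	vecs, err := emb.EmbedSlice([]string{"hello world", "embeddings in Go"})
	if err != nil {
		return err
	}
	fmt.Printf("%d vectors of dim %d\n", len(vecs), len(vecs[0]))
	return nil
}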