diff options
| author | Grail Finder <wohilas@gmail.com> | 2026-03-05 14:27:19 +0300 |
|---|---|---|
| committer | Grail Finder <wohilas@gmail.com> | 2026-03-05 14:27:19 +0300 |
| commit | 7c56e27dbe904b3c08b3eee375542011458e297c (patch) | |
| tree | 5ae4d5a56f282d751cec3a54a9aa1335105a56cd /rag | |
| parent | fbc955ca37836553ef4b7c365b84e3dfa859c501 (diff) | |
Dep: trying sugarme tokenizer
Diffstat (limited to 'rag')
| -rw-r--r-- | rag/embedder.go | 181 |
1 files changed, 145 insertions, 36 deletions
diff --git a/rag/embedder.go b/rag/embedder.go index 386d508..396f04b 100644 --- a/rag/embedder.go +++ b/rag/embedder.go @@ -10,8 +10,8 @@ import ( "log/slog" "net/http" - "github.com/takara-ai/go-tokenizers/tokenizers" - + "github.com/sugarme/tokenizer" + "github.com/sugarme/tokenizer/pretrained" "github.com/yalue/onnxruntime_go" ) @@ -141,59 +141,168 @@ func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) { // 1. Loading ONNX models locally // 2. Using a Go ONNX runtime (like gorgonia/onnx or similar) // 3. Converting text to embeddings without external API calls - type ONNXEmbedder struct { session *onnxruntime_go.DynamicAdvancedSession - tokenizer *tokenizers.Tokenizer - dims int // 768, 512, 256, or 128 for Matryoshka + tokenizer *tokenizer.Tokenizer + dims int // embedding dimension (e.g., 768) + logger *slog.Logger } -func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) { - // Batch processing - inputs := e.prepareBatch(texts) - outputs := make([][]float32, len(texts)) - - // Run batch inference (much faster) - err := e.session.Run(inputs, outputs) - return outputs, err -} - -func NewONNXEmbedder(modelPath string) (*ONNXEmbedder, error) { - // Load ONNX model +func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Logger) (*ONNXEmbedder, error) { + // Load tokenizer using sugarme/tokenizer + tok, err := pretrained.FromFile(tokenizerPath) + if err != nil { + return nil, fmt.Errorf("failed to load tokenizer: %w", err) + } + // Create ONNX session session, err := onnxruntime_go.NewDynamicAdvancedSession( modelPath, // onnx/embedgemma/model_q4.onnx []string{"input_ids", "attention_mask"}, []string{"sentence_embedding"}, - nil, + nil, // optional options ) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create ONNX session: %w", err) } - // Load tokenizer (from Hugging Face) - tokenizer, err := tokenizers.FromFile("./tokenizer.json") return &ONNXEmbedder{ session: session, - tokenizer: tokenizer, + tokenizer: tok, + dims: dims, + logger: logger, }, nil } func (e *ONNXEmbedder) Embed(text string) ([]float32, error) { - // Tokenize - tokens := e.tokenizer.Encode(text, true) - // Prepare inputs - inputIDs := []int64{tokens.GetIds()} - attentionMask := []int64{tokens.GetAttentionMask()} - // Run inference - output := onnxruntime_go.NewEmptyTensor[float32]( - onnxruntime_go.NewShape(1, 768), + // 1. Tokenize + encoding, err := e.tokenizer.Encode(text, true) // true = add special tokens + if err != nil { + return nil, fmt.Errorf("tokenization failed: %w", err) + } + // Convert []int32 to []int64 for ONNX + inputIDs := make([]int64, len(encoding.GetIDs())) + for i, id := range encoding.GetIDs() { + inputIDs[i] = int64(id) + } + attentionMask := make([]int64, len(encoding.GetAttentionMask())) + for i, m := range encoding.GetAttentionMask() { + attentionMask[i] = int64(m) + } + // 2. Create input tensors (shape: [1, seq_len]) + seqLen := int64(len(inputIDs)) + inputIDsTensor, err := onnxruntime_go.NewTensor(onnxruntime_go.NewShape(1, seqLen), inputIDs) + if err != nil { + return nil, fmt.Errorf("failed to create input_ids tensor: %w", err) + } + defer inputIDsTensor.Destroy() + maskTensor, err := onnxruntime_go.NewTensor(onnxruntime_go.NewShape(1, seqLen), attentionMask) + if err != nil { + return nil, fmt.Errorf("failed to create attention_mask tensor: %w", err) + } + defer maskTensor.Destroy() + // 3. Create output tensor (shape: [1, dims]) + outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](onnxruntime_go.NewShape(1, int64(e.dims))) + if err != nil { + return nil, fmt.Errorf("failed to create output tensor: %w", err) + } + defer outputTensor.Destroy() + // 4. Run inference + err = e.session.Run( + map[string]*onnxruntime_go.Tensor{ + "input_ids": inputIDsTensor, + "attention_mask": maskTensor, + }, + []string{"sentence_embedding"}, + []*onnxruntime_go.Tensor{outputTensor}, ) - err := e.session.Run( - map[string]any{ - "input_ids": inputIDs, - "attention_mask": attentionMask, + if err != nil { + return nil, fmt.Errorf("inference failed: %w", err) + } + // 5. Extract data + outputData := outputTensor.GetData() + // outputTensor is owned by us, but GetData returns a slice that remains valid until Destroy. + // We need to copy if we want to keep it after Destroy (we defer Destroy, so copy now). + embedding := make([]float32, len(outputData)) + copy(embedding, outputData) + return embedding, nil +} + +// EmbedSlice (batch) – to be implemented properly +func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) { + if len(texts) == 0 { + return nil, nil + } + // 1. Tokenize all texts and find max length for padding + encodings := make([]*tokenizer.Encoding, len(texts)) + maxLen := 0 + for i, txt := range texts { + enc, err := e.tokenizer.Encode(txt, true) + if err != nil { + return nil, fmt.Errorf("tokenization failed at index %d: %w", i, err) + } + encodings[i] = enc + if l := len(enc.GetIDs()); l > maxLen { + maxLen = l + } + } + // 2. Build padded input_ids and attention_mask (shape: [batch, maxLen]) + batchSize := len(texts) + inputIDs := make([]int64, batchSize*maxLen) + attentionMask := make([]int64, batchSize*maxLen) + for i, enc := range encodings { + ids := enc.GetIDs() + mask := enc.GetAttentionMask() + offset := i * maxLen + // copy actual tokens + for j := 0; j < len(ids); j++ { + inputIDs[offset+j] = int64(ids[j]) + attentionMask[offset+j] = int64(mask[j]) + } + // remaining positions (padding) are already zero-initialized + } + // 3. Create tensors + inputIDsTensor, err := onnxruntime_go.NewTensor( + onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)), + inputIDs, + ) + if err != nil { + return nil, err + } + defer inputIDsTensor.Destroy() + maskTensor, err := onnxruntime_go.NewTensor( + onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)), + attentionMask, + ) + if err != nil { + return nil, err + } + defer maskTensor.Destroy() + outputTensor, err := onnxruntime_go.NewEmptyTensor[float32]( + onnxruntime_go.NewShape(int64(batchSize), int64(e.dims)), + ) + if err != nil { + return nil, err + } + defer outputTensor.Destroy() + // 4. Run + err = e.session.Run( + map[string]*onnxruntime_go.Tensor{ + "input_ids": inputIDsTensor, + "attention_mask": maskTensor, }, []string{"sentence_embedding"}, - []any{&output}, + []*onnxruntime_go.Tensor{outputTensor}, ) - return output.GetData(), nil + if err != nil { + return nil, err + } + // 5. Extract batch results + outputData := outputTensor.GetData() + embeddings := make([][]float32, batchSize) + for i := 0; i < batchSize; i++ { + start := i * e.dims + emb := make([]float32, e.dims) + copy(emb, outputData[start:start+e.dims]) + embeddings[i] = emb + } + return embeddings, nil } |
