Diffstat (limited to 'rag/embedder.go')
-rw-r--r--  rag/embedder.go  105
1 file changed, 43 insertions(+), 62 deletions(-)
diff --git a/rag/embedder.go b/rag/embedder.go
index 396f04b..988d91e 100644
--- a/rag/embedder.go
+++ b/rag/embedder.go
@@ -174,134 +174,115 @@ func NewONNXEmbedder(modelPath, tokenizerPath string, dims int, logger *slog.Log
func (e *ONNXEmbedder) Embed(text string) ([]float32, error) {
// 1. Tokenize
- encoding, err := e.tokenizer.Encode(text, true) // true = add special tokens
+ encoding, err := e.tokenizer.EncodeSingle(text)
if err != nil {
return nil, fmt.Errorf("tokenization failed: %w", err)
}
- // Convert []int32 to []int64 for ONNX
- inputIDs := make([]int64, len(encoding.GetIDs()))
- for i, id := range encoding.GetIDs() {
+ // 2. Convert to int64 and create attention mask
+ ids := encoding.Ids
+ inputIDs := make([]int64, len(ids))
+ attentionMask := make([]int64, len(ids))
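+ // Single-text encoding is unpadded, so every position is a real token and the mask is all ones.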
+ for i, id := range ids {
inputIDs[i] = int64(id)
+ attentionMask[i] = 1
}
- attentionMask := make([]int64, len(encoding.GetAttentionMask()))
- for i, m := range encoding.GetAttentionMask() {
- attentionMask[i] = int64(m)
- }
- // 2. Create input tensors (shape: [1, seq_len])
+ // 3. Create input tensors (shape: [1, seq_len])
seqLen := int64(len(inputIDs))
- inputIDsTensor, err := onnxruntime_go.NewTensor(onnxruntime_go.NewShape(1, seqLen), inputIDs)
+ inputIDsTensor, err := onnxruntime_go.NewTensor[int64](
+ onnxruntime_go.NewShape(1, seqLen),
+ inputIDs,
+ )
if err != nil {
return nil, fmt.Errorf("failed to create input_ids tensor: %w", err)
}
defer inputIDsTensor.Destroy()
- maskTensor, err := onnxruntime_go.NewTensor(onnxruntime_go.NewShape(1, seqLen), attentionMask)
+ maskTensor, err := onnxruntime_go.NewTensor[int64](
+ onnxruntime_go.NewShape(1, seqLen),
+ attentionMask,
+ )
if err != nil {
return nil, fmt.Errorf("failed to create attention_mask tensor: %w", err)
}
defer maskTensor.Destroy()
- // 3. Create output tensor (shape: [1, dims])
- outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](onnxruntime_go.NewShape(1, int64(e.dims)))
+ // 4. Create output tensor
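+ // Assumes the model exports a pooled "sentence_embedding" output of shape [1, dims].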
+ outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](
+ onnxruntime_go.NewShape(1, int64(e.dims)),
+ )
if err != nil {
return nil, fmt.Errorf("failed to create output tensor: %w", err)
}
defer outputTensor.Destroy()
- // 4. Run inference
+ // 5. Run inference
err = e.session.Run(
- map[string]*onnxruntime_go.Tensor{
- "input_ids": inputIDsTensor,
- "attention_mask": maskTensor,
- },
+ []onnxruntime_go.Value{inputIDsTensor, maskTensor},
[]string{"sentence_embedding"},
- []*onnxruntime_go.Tensor{outputTensor},
+ []onnxruntime_go.Value{outputTensor},
)
if err != nil {
return nil, fmt.Errorf("inference failed: %w", err)
}
- // 5. Extract data
+ // 6. Copy output data
outputData := outputTensor.GetData()
- // outputTensor is owned by us, but GetData returns a slice that remains valid until Destroy.
- // We need to copy if we want to keep it after Destroy (we defer Destroy, so copy now).
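+ // GetData returns a view into the tensor's buffer; copy before the deferred Destroy frees it.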
embedding := make([]float32, len(outputData))
copy(embedding, outputData)
return embedding, nil
}
-// EmbedSlice (batch) – to be implemented properly
func (e *ONNXEmbedder) EmbedSlice(texts []string) ([][]float32, error) {
if len(texts) == 0 {
return nil, nil
}
- // 1. Tokenize all texts and find max length for padding
encodings := make([]*tokenizer.Encoding, len(texts))
maxLen := 0
for i, txt := range texts {
- enc, err := e.tokenizer.Encode(txt, true)
+ enc, err := e.tokenizer.EncodeSingle(txt)
if err != nil {
return nil, fmt.Errorf("tokenization failed at index %d: %w", i, err)
}
encodings[i] = enc
- if l := len(enc.GetIDs()); l > maxLen {
+ if l := len(enc.Ids); l > maxLen {
maxLen = l
}
}
- // 2. Build padded input_ids and attention_mask (shape: [batch, maxLen])
batchSize := len(texts)
inputIDs := make([]int64, batchSize*maxLen)
attentionMask := make([]int64, batchSize*maxLen)
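+ // Flattened row-major buffers: row i occupies [i*maxLen, (i+1)*maxLen).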
for i, enc := range encodings {
- ids := enc.GetIDs()
- mask := enc.GetAttentionMask()
+ ids := enc.Ids
offset := i * maxLen
- // copy actual tokens
- for j := 0; j < len(ids); j++ {
- inputIDs[offset+j] = int64(ids[j])
- attentionMask[offset+j] = int64(mask[j])
+ for j, id := range ids {
+ inputIDs[offset+j] = int64(id)
+ attentionMask[offset+j] = 1
}
- // remaining positions (padding) are already zero-initialized
+ // Remaining positions are already zero (padding)
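+ // Mask stays 0 there, and id 0 is assumed to be the model's pad token (true for BERT-style vocabularies).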
}
- // 3. Create tensors
- inputIDsTensor, err := onnxruntime_go.NewTensor(
+ // Create tensors with shape [batchSize, maxLen]
+ inputTensor, err := onnxruntime_go.NewTensor[int64](
onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)),
inputIDs,
)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("failed to create input_ids tensor: %w", err)
}
- defer inputIDsTensor.Destroy()
+ defer inputTensor.Destroy()
- maskTensor, err := onnxruntime_go.NewTensor(
+ maskTensor, err := onnxruntime_go.NewTensor[int64](
onnxruntime_go.NewShape(int64(batchSize), int64(maxLen)),
attentionMask,
)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("failed to create attention_mask tensor: %w", err)
}
defer maskTensor.Destroy()
outputTensor, err := onnxruntime_go.NewEmptyTensor[float32](
onnxruntime_go.NewShape(int64(batchSize), int64(e.dims)),
)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("failed to create output tensor: %w", err)
}
defer outputTensor.Destroy()
- // 4. Run
err = e.session.Run(
- map[string]*onnxruntime_go.Tensor{
- "input_ids": inputIDsTensor,
- "attention_mask": maskTensor,
- },
+ []onnxruntime_go.Value{inputTensor, maskTensor},
[]string{"sentence_embedding"},
- []*onnxruntime_go.Tensor{outputTensor},
+ []onnxruntime_go.Value{outputTensor},
)
if err != nil {
return nil, err
}
- // 5. Extract batch results
- outputData := outputTensor.GetData()
+ // Extract embeddings per batch item
+ data := outputTensor.GetData()
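+ // data is row-major [batchSize, dims]; embedding i starts at i*e.dims.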
embeddings := make([][]float32, batchSize)
for i := 0; i < batchSize; i++ {
start := i * e.dims
emb := make([]float32, e.dims)
- copy(emb, outputData[start:start+e.dims])
+ copy(emb, data[start:start+e.dims])
embeddings[i] = emb
}
return embeddings, nil
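
For orientation, a minimal caller sketch (not part of this commit; the module import path, model files, and dimension are placeholders, and it assumes NewONNXEmbedder returns (*ONNXEmbedder, error) with ONNX Runtime initialized during construction):

package main

import (
	"fmt"
	"log"
	"log/slog"

	"example.com/yourmodule/rag" // import path is an assumption
)

func main() {
	// 384 is illustrative; use the output width of your ONNX model.
	emb, err := rag.NewONNXEmbedder("models/encoder.onnx", "models/tokenizer.json", 384, slog.Default())
	if err != nil {
		log.Fatal(err)
	}
	vecs, err := emb.EmbedSlice([]string{"first document", "second document"})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("%d embeddings, %d dims each\n", len(vecs), len(vecs[0]))
}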