summaryrefslogtreecommitdiff
path: root/rag/embedder.go
blob: bed1b412a3b128abe8d0973f951d3be655c73cbd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
package rag

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"gf-lt/config"
	"gf-lt/models"
	"log/slog"
	"net/http"
)

// Embedder defines the interface for embedding text.
type Embedder interface {
	// Embed returns the embedding vector for a single text.
	Embed(text string) ([]float32, error)
	// EmbedSlice embeds several texts and returns their embedding vectors.
	EmbedSlice(lines []string) ([][]float32, error)
}

// APIEmbedder implements embedder using an API (like Hugging Face, OpenAI, etc.)
type APIEmbedder struct {
	// logger receives an error record for every failure path of the methods.
	logger *slog.Logger
	// client sends the embedding requests.
	client *http.Client
	// cfg supplies EmbedURL (the request target) and HFToken (optional bearer token).
	cfg *config.Config
}

// NewAPIEmbedder builds an APIEmbedder that logs through l, reads its endpoint
// settings from cfg, and talks to the API over a fresh zero-value http.Client.
//
// NOTE(review): the zero-value client has no Timeout, so a request only ends
// when the transport or the server ends it — consider configuring one.
func NewAPIEmbedder(l *slog.Logger, cfg *config.Config) *APIEmbedder {
	embedder := new(APIEmbedder)
	embedder.logger = l
	embedder.client = new(http.Client)
	embedder.cfg = cfg
	return embedder
}

// Embed requests an embedding for a single text.
//
// It POSTs the JSON body {"input": text, "encoding_format": "float"} to
// cfg.EmbedURL, adding a bearer Authorization header when cfg.HFToken is
// non-empty, and returns the first embedding found in the response.
//
// An error is logged and returned when the payload cannot be marshalled, the
// request cannot be created or sent, the status code is not 200, the body
// cannot be decoded, or the response contains no embedding.
func (a *APIEmbedder) Embed(text string) ([]float32, error) {
	payload, err := json.Marshal(
		map[string]any{"input": text, "encoding_format": "float"},
	)
	if err != nil {
		a.logger.Error("failed to marshal payload", "err", err.Error())
		return nil, fmt.Errorf("marshal embedding payload: %w", err)
	}
	req, err := http.NewRequest(http.MethodPost, a.cfg.EmbedURL, bytes.NewReader(payload))
	if err != nil {
		a.logger.Error("failed to create new req", "err", err.Error())
		return nil, fmt.Errorf("create embedding request: %w", err)
	}
	// The body is JSON; say so explicitly, since servers that dispatch on
	// Content-Type may otherwise reject the request.
	req.Header.Set("Content-Type", "application/json")
	if a.cfg.HFToken != "" {
		req.Header.Set("Authorization", "Bearer "+a.cfg.HFToken)
	}
	resp, err := a.client.Do(req)
	if err != nil {
		a.logger.Error("failed to embed text", "err", err.Error())
		return nil, fmt.Errorf("send embedding request: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		err = fmt.Errorf("non 200 response; code: %v", resp.StatusCode)
		a.logger.Error(err.Error())
		return nil, err
	}
	// Decode into a value rather than a pointer-to-pointer: with **T a JSON
	// `null` body would set the pointer to nil and the checks below would panic.
	var embResp models.LCPEmbedResp
	if err := json.NewDecoder(resp.Body).Decode(&embResp); err != nil {
		a.logger.Error("failed to decode embedding response", "err", err.Error())
		return nil, fmt.Errorf("decode embedding response: %w", err)
	}
	if len(embResp.Data) == 0 || len(embResp.Data[0].Embedding) == 0 {
		err = errors.New("empty embedding response")
		a.logger.Error("empty embedding response")
		return nil, err
	}
	return embResp.Data[0].Embedding, nil
}

// EmbedSlice requests embeddings for all of lines in a single API call and
// returns them ordered like lines (result[i] belongs to lines[i]).
//
// It POSTs the JSON body {"input": lines, "encoding_format": "float"} to
// cfg.EmbedURL, adding a bearer Authorization header when cfg.HFToken is
// non-empty. Response items are placed by their Index field because the API
// may return them out of order.
//
// An error is logged and returned when the payload cannot be marshalled, the
// request cannot be created or sent, the status code is not 200, the body
// cannot be decoded, the response is empty or its item count differs from
// len(lines), an item has an empty embedding, or an item's index is out of
// range or repeated.
func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
	payload, err := json.Marshal(
		map[string]any{"input": lines, "encoding_format": "float"},
	)
	if err != nil {
		a.logger.Error("failed to marshal payload", "err", err.Error())
		return nil, fmt.Errorf("marshal embedding payload: %w", err)
	}
	req, err := http.NewRequest(http.MethodPost, a.cfg.EmbedURL, bytes.NewReader(payload))
	if err != nil {
		a.logger.Error("failed to create new req", "err", err.Error())
		return nil, fmt.Errorf("create embedding request: %w", err)
	}
	// The body is JSON; say so explicitly, since servers that dispatch on
	// Content-Type may otherwise reject the request.
	req.Header.Set("Content-Type", "application/json")
	if a.cfg.HFToken != "" {
		req.Header.Set("Authorization", "Bearer "+a.cfg.HFToken)
	}
	resp, err := a.client.Do(req)
	if err != nil {
		a.logger.Error("failed to embed text", "err", err.Error())
		return nil, fmt.Errorf("send embedding request: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		err = fmt.Errorf("non 200 response; code: %v", resp.StatusCode)
		a.logger.Error(err.Error())
		return nil, err
	}
	// Decode into a value rather than a pointer-to-pointer: with **T a JSON
	// `null` body would set the pointer to nil and the checks below would panic.
	var embResp models.LCPEmbedResp
	if err := json.NewDecoder(resp.Body).Decode(&embResp); err != nil {
		a.logger.Error("failed to decode embedding response", "err", err.Error())
		return nil, fmt.Errorf("decode embedding response: %w", err)
	}
	if len(embResp.Data) == 0 {
		err = errors.New("empty embedding response")
		a.logger.Error("empty embedding response")
		return nil, err
	}
	// Every input line needs exactly one embedding; a short or long response
	// would otherwise be handed back as if it were complete.
	if len(embResp.Data) != len(lines) {
		err = fmt.Errorf("embedding count mismatch: got %d, want %d", len(embResp.Data), len(lines))
		a.logger.Error("embedding count mismatch", "got", len(embResp.Data), "want", len(lines))
		return nil, err
	}

	// Place each embedding at the slot named by its index. A slot that is
	// already filled means the server repeated an index, in which case the
	// mapping back to the input lines is ambiguous and must not be guessed.
	embeddings := make([][]float32, len(lines))
	for i, data := range embResp.Data {
		if len(data.Embedding) == 0 {
			err = fmt.Errorf("empty embedding at index %d", i)
			a.logger.Error("empty embedding", "index", i)
			return nil, err
		}
		if data.Index < 0 || data.Index >= len(embeddings) {
			err = fmt.Errorf("invalid embedding index %d", data.Index)
			a.logger.Error("invalid embedding index", "index", data.Index)
			return nil, err
		}
		if embeddings[data.Index] != nil {
			err = fmt.Errorf("duplicate embedding index %d", data.Index)
			a.logger.Error("duplicate embedding index", "index", data.Index)
			return nil, err
		}
		embeddings[data.Index] = data.Embedding
	}

	return embeddings, nil
}

// TODO: ONNXEmbedder implementation would go here
// This would require:
// 1. Loading ONNX models locally
// 2. Using a Go ONNX runtime (like gorgonia/onnx or similar)
// 3. Converting text to embeddings without external API calls
//
// For now, we focus on the API implementation, which already works in the current system;
// an ONNX-based embedder can be added later, once ONNX runtime integration is available.