summaryrefslogtreecommitdiff
path: root/rag/embedder.go
blob: 48499417863958bb8ff9c40fb173fd144e1198be (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
package rag

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"gf-lt/config"
	"log/slog"
	"net/http"
)

// Embedder defines the interface for embedding text.
type Embedder interface {
	// Embed takes a batch of strings and returns a slice of float32 vectors,
	// or an error. Presumably the result holds one vector per input string,
	// in input order; confirm against the concrete implementation.
	Embed(text []string) ([][]float32, error)
	// EmbedSingle takes a single string and returns one float32 vector, or
	// an error.
	EmbedSingle(text string) ([]float32, error)
}

// APIEmbedder implements Embedder using an HTTP API (like Hugging Face,
// OpenAI, etc.).
type APIEmbedder struct {
	// logger receives an Error entry for each failure encountered in Embed.
	logger *slog.Logger
	// client sends the embedding request.
	client *http.Client
	// cfg supplies the endpoint (EmbedURL) and the optional bearer token
	// (HFToken) used by Embed.
	cfg    *config.Config
}

// NewAPIEmbedder builds an APIEmbedder that logs through l and reads its
// endpoint settings from cfg.
//
// The HTTP client is a zero-value http.Client, so no request timeout is
// configured on it.
func NewAPIEmbedder(l *slog.Logger, cfg *config.Config) *APIEmbedder {
	embedder := new(APIEmbedder)
	embedder.logger = l
	embedder.client = new(http.Client)
	embedder.cfg = cfg
	return embedder
}

// Embed requests embeddings for every string in text from the endpoint
// configured in cfg.EmbedURL and returns one vector per input string.
//
// The request is a JSON POST of the form
//
//	{"inputs": text, "options": {"wait_for_model": true}}
//
// and carries an "Authorization: Bearer <token>" header when cfg.HFToken is
// non-empty. The response body must decode as a JSON array of float arrays.
//
// An error is logged and returned when text is empty, the payload cannot be
// marshalled, the request cannot be built or sent, the status code is not
// 200, the body cannot be decoded, the response is empty, or the number of
// returned vectors differs from len(text).
func (a *APIEmbedder) Embed(text []string) ([][]float32, error) {
	// Nothing to embed: avoid a pointless round trip that would post
	// "inputs": null (or []) to the API.
	if len(text) == 0 {
		err := errors.New("no input texts to embed")
		a.logger.Error(err.Error())
		return nil, err
	}

	payload, err := json.Marshal(
		map[string]any{"inputs": text, "options": map[string]bool{"wait_for_model": true}},
	)
	if err != nil {
		a.logger.Error("failed to marshal payload", "err", err.Error())
		return nil, err
	}

	req, err := http.NewRequest(http.MethodPost, a.cfg.EmbedURL, bytes.NewReader(payload))
	if err != nil {
		a.logger.Error("failed to create new req", "err", err.Error())
		return nil, err
	}

	// Declare the body format explicitly; servers are not obliged to guess it.
	req.Header.Set("Content-Type", "application/json")
	if a.cfg.HFToken != "" {
		req.Header.Set("Authorization", "Bearer "+a.cfg.HFToken)
	}

	resp, err := a.client.Do(req)
	if err != nil {
		a.logger.Error("failed to embed text", "err", err.Error())
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		err = fmt.Errorf("non 200 response; code: %v", resp.StatusCode)
		a.logger.Error(err.Error())
		return nil, err
	}

	var emb [][]float32
	if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil {
		a.logger.Error("failed to decode embedding response", "err", err.Error())
		return nil, err
	}

	if len(emb) == 0 {
		err = errors.New("empty embedding response")
		a.logger.Error("empty embedding response")
		return nil, err
	}

	// Callers pair vectors with inputs by position, so a short or long
	// response must not be passed through silently.
	if len(emb) != len(text) {
		err = fmt.Errorf("embedding count mismatch: got %d, want %d", len(emb), len(text))
		a.logger.Error(err.Error())
		return nil, err
	}

	return emb, nil
}

// EmbedSingle embeds one string by passing it to Embed as a one-element
// batch and returning the first vector of the result. Any error from Embed
// is returned unchanged; an empty result yields its own error.
func (a *APIEmbedder) EmbedSingle(text string) ([]float32, error) {
	vectors, err := a.Embed([]string{text})
	switch {
	case err != nil:
		return nil, err
	case len(vectors) == 0:
		return nil, errors.New("no embeddings returned")
	}
	return vectors[0], nil
}

// TODO: ONNXEmbedder implementation would go here
// This would require:
// 1. Loading ONNX models locally
// 2. Using a Go ONNX runtime (like gorgonia/onnx or similar)
// 3. Converting text to embeddings without external API calls
//
// For now, we focus on the API implementation, which already works in the current system;
// ONNX support can be added later, once an ONNX runtime integration is available.