diff options
| author | Grail Finder <wohilas@gmail.com> | 2025-11-22 14:56:24 +0300 |
|---|---|---|
| committer | Grail Finder <wohilas@gmail.com> | 2025-11-22 14:56:24 +0300 |
| commit | 50d7bfced396485f1d313cce11a73c8f386f7956 (patch) | |
| tree | 2e9e2af2ad74d39a48c66be4eb38ebf100ced2af /rag | |
| parent | 5fe03fa66c30f5d7ca6cdf9de1b1cfa2c38d6a45 (diff) | |
Enha: embedgemma model
Diffstat (limited to 'rag')
| -rw-r--r-- | rag/embedder.go | 87 | ||||
| -rw-r--r-- | rag/rag.go | 50 |
2 files changed, 104 insertions, 33 deletions
diff --git a/rag/embedder.go b/rag/embedder.go index 4849941..bed1b41 100644 --- a/rag/embedder.go +++ b/rag/embedder.go @@ -6,14 +6,15 @@ import ( "errors" "fmt" "gf-lt/config" + "gf-lt/models" "log/slog" "net/http" ) // Embedder defines the interface for embedding text type Embedder interface { - Embed(text []string) ([][]float32, error) - EmbedSingle(text string) ([]float32, error) + Embed(text string) ([]float32, error) + EmbedSlice(lines []string) ([][]float32, error) } // APIEmbedder implements embedder using an API (like Hugging Face, OpenAI, etc.) @@ -31,62 +32,107 @@ func NewAPIEmbedder(l *slog.Logger, cfg *config.Config) *APIEmbedder { } } -func (a *APIEmbedder) Embed(text []string) ([][]float32, error) { +func (a *APIEmbedder) Embed(text string) ([]float32, error) { payload, err := json.Marshal( - map[string]any{"inputs": text, "options": map[string]bool{"wait_for_model": true}}, + map[string]any{"input": text, "encoding_format": "float"}, ) if err != nil { a.logger.Error("failed to marshal payload", "err", err.Error()) return nil, err } - req, err := http.NewRequest("POST", a.cfg.EmbedURL, bytes.NewReader(payload)) if err != nil { a.logger.Error("failed to create new req", "err", err.Error()) return nil, err } - if a.cfg.HFToken != "" { req.Header.Add("Authorization", "Bearer "+a.cfg.HFToken) } - resp, err := a.client.Do(req) if err != nil { a.logger.Error("failed to embed text", "err", err.Error()) return nil, err } defer resp.Body.Close() - if resp.StatusCode != 200 { err = fmt.Errorf("non 200 response; code: %v", resp.StatusCode) a.logger.Error(err.Error()) return nil, err } - - var emb [][]float32 - if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil { + embResp := &models.LCPEmbedResp{} + if err := json.NewDecoder(resp.Body).Decode(&embResp); err != nil { a.logger.Error("failed to decode embedding response", "err", err.Error()) return nil, err } - - if len(emb) == 0 { + if len(embResp.Data) == 0 || len(embResp.Data[0].Embedding) == 0 { err = errors.New("empty embedding response") a.logger.Error("empty embedding response") return nil, err } - - return emb, nil + return embResp.Data[0].Embedding, nil } -func (a *APIEmbedder) EmbedSingle(text string) ([]float32, error) { - result, err := a.Embed([]string{text}) +func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) { + payload, err := json.Marshal( + map[string]any{"input": lines, "encoding_format": "float"}, + ) if err != nil { + a.logger.Error("failed to marshal payload", "err", err.Error()) return nil, err } - if len(result) == 0 { - return nil, errors.New("no embeddings returned") + req, err := http.NewRequest("POST", a.cfg.EmbedURL, bytes.NewReader(payload)) + if err != nil { + a.logger.Error("failed to create new req", "err", err.Error()) + return nil, err + } + if a.cfg.HFToken != "" { + req.Header.Add("Authorization", "Bearer "+a.cfg.HFToken) } - return result[0], nil + resp, err := a.client.Do(req) + if err != nil { + a.logger.Error("failed to embed text", "err", err.Error()) + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + err = fmt.Errorf("non 200 response; code: %v", resp.StatusCode) + a.logger.Error(err.Error()) + return nil, err + } + embResp := &models.LCPEmbedResp{} + if err := json.NewDecoder(resp.Body).Decode(&embResp); err != nil { + a.logger.Error("failed to decode embedding response", "err", err.Error()) + return nil, err + } + if len(embResp.Data) == 0 { + err = errors.New("empty embedding response") + a.logger.Error("empty embedding response") + return nil, err + } + + // Collect all embeddings from the response + embeddings := make([][]float32, len(embResp.Data)) + for i := range embResp.Data { + if len(embResp.Data[i].Embedding) == 0 { + err = fmt.Errorf("empty embedding at index %d", i) + a.logger.Error("empty embedding", "index", i) + return nil, err + } + embeddings[i] = embResp.Data[i].Embedding + } + + // Sort embeddings by index to match the order of input lines + // API responses may not be in order + for _, data := range embResp.Data { + if data.Index >= len(embeddings) || data.Index < 0 { + err = fmt.Errorf("invalid embedding index %d", data.Index) + a.logger.Error("invalid embedding index", "index", data.Index) + return nil, err + } + embeddings[data.Index] = data.Embedding + } + + return embeddings, nil } // TODO: ONNXEmbedder implementation would go here @@ -97,4 +143,3 @@ func (a *APIEmbedder) EmbedSingle(text string) ([]float32, error) { // // For now, we'll focus on the API implementation which is already working in the current system, // and can be extended later when we have ONNX runtime integration - @@ -148,10 +148,12 @@ func (r *RAG) LoadRAG(fpath string) error { for w := 0; w < int(r.cfg.RAGWorkers); w++ { go r.batchToVectorAsync(lock, w, batchCh, vectorCh, errCh, doneCh, path.Base(fpath)) } - // Wait for embedding to be done <-doneCh - + err = <-errCh + if err != nil { + return err + } // Write vectors to storage return r.writeVectors(vectorCh) } @@ -178,9 +180,11 @@ func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error { func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[int][]string, vectorCh chan<- []models.VectorRow, errCh chan error, doneCh chan bool, filename string) { + var err error defer func() { if len(doneCh) == 0 { doneCh <- true + errCh <- err } }() @@ -201,7 +205,7 @@ func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[in } } lock.Unlock() - case err := <-errCh: + case err = <-errCh: r.logger.Error("got an error from error channel", "error", err) lock.Unlock() return @@ -215,7 +219,23 @@ func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[in } func (r *RAG) fetchEmb(lines []string, errCh chan error, vectorCh chan<- []models.VectorRow, slug, filename string) error { - embeddings, err := r.embedder.Embed(lines) + // Filter out empty lines before sending to embedder + nonEmptyLines := make([]string, 0, len(lines)) + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if trimmed != "" { + nonEmptyLines = append(nonEmptyLines, trimmed) + } + } + + // Skip if no non-empty lines + if len(nonEmptyLines) == 0 { + // Send empty result but don't error + vectorCh <- []models.VectorRow{} + return nil + } + + embeddings, err := r.embedder.EmbedSlice(nonEmptyLines) if err != nil { r.logger.Error("failed to embed lines", "err", err.Error()) errCh <- err @@ -229,15 +249,22 @@ func (r *RAG) fetchEmb(lines []string, errCh chan error, vectorCh chan<- []model return err } - vectors := make([]models.VectorRow, len(embeddings)) - for i, emb := range embeddings { - vector := models.VectorRow{ - Embeddings: emb, - RawText: lines[i], + if len(embeddings) != len(nonEmptyLines) { + err := errors.New("mismatch between number of lines and embeddings returned") + r.logger.Error("embedding mismatch", "err", err.Error()) + errCh <- err + return err + } + + // Create a VectorRow for each line in the batch + vectors := make([]models.VectorRow, len(nonEmptyLines)) + for i, line := range nonEmptyLines { + vectors[i] = models.VectorRow{ + Embeddings: embeddings[i], + RawText: line, Slug: fmt.Sprintf("%s_%d", slug, i), FileName: filename, } - vectors[i] = vector } vectorCh <- vectors @@ -245,7 +272,7 @@ func (r *RAG) fetchEmb(lines []string, errCh chan error, vectorCh chan<- []model } func (r *RAG) LineToVector(line string) ([]float32, error) { - return r.embedder.EmbedSingle(line) + return r.embedder.Embed(line) } func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) { @@ -259,4 +286,3 @@ func (r *RAG) ListLoaded() ([]string, error) { func (r *RAG) RemoveFile(filename string) error { return r.storage.RemoveEmbByFileName(filename) } - |
