summary | refs | log | tree | commit | diff
path: root/rag
diff options
context:
space:
mode:
author: Grail Finder <wohilas@gmail.com> 2025-11-22 14:56:24 +0300
committer: Grail Finder <wohilas@gmail.com> 2025-11-22 14:56:24 +0300
commit: 50d7bfced396485f1d313cce11a73c8f386f7956 (patch)
tree: 2e9e2af2ad74d39a48c66be4eb38ebf100ced2af /rag
parent: 5fe03fa66c30f5d7ca6cdf9de1b1cfa2c38d6a45 (diff)
Enha: embedgemma model
Diffstat (limited to 'rag')
-rw-r--r-- rag/embedder.go | 87
-rw-r--r-- rag/rag.go | 50
2 files changed, 104 insertions, 33 deletions
diff --git a/rag/embedder.go b/rag/embedder.go
index 4849941..bed1b41 100644
--- a/rag/embedder.go
+++ b/rag/embedder.go
@@ -6,14 +6,15 @@ import (
"errors"
"fmt"
"gf-lt/config"
+ "gf-lt/models"
"log/slog"
"net/http"
)
// Embedder defines the interface for embedding text
type Embedder interface {
- Embed(text []string) ([][]float32, error)
- EmbedSingle(text string) ([]float32, error)
+ Embed(text string) ([]float32, error)
+ EmbedSlice(lines []string) ([][]float32, error)
}
// APIEmbedder implements embedder using an API (like Hugging Face, OpenAI, etc.)
@@ -31,62 +32,107 @@ func NewAPIEmbedder(l *slog.Logger, cfg *config.Config) *APIEmbedder {
}
}
-func (a *APIEmbedder) Embed(text []string) ([][]float32, error) {
+func (a *APIEmbedder) Embed(text string) ([]float32, error) {
payload, err := json.Marshal(
- map[string]any{"inputs": text, "options": map[string]bool{"wait_for_model": true}},
+ map[string]any{"input": text, "encoding_format": "float"},
)
if err != nil {
a.logger.Error("failed to marshal payload", "err", err.Error())
return nil, err
}
-
req, err := http.NewRequest("POST", a.cfg.EmbedURL, bytes.NewReader(payload))
if err != nil {
a.logger.Error("failed to create new req", "err", err.Error())
return nil, err
}
-
if a.cfg.HFToken != "" {
req.Header.Add("Authorization", "Bearer "+a.cfg.HFToken)
}
-
resp, err := a.client.Do(req)
if err != nil {
a.logger.Error("failed to embed text", "err", err.Error())
return nil, err
}
defer resp.Body.Close()
-
if resp.StatusCode != 200 {
err = fmt.Errorf("non 200 response; code: %v", resp.StatusCode)
a.logger.Error(err.Error())
return nil, err
}
-
- var emb [][]float32
- if err := json.NewDecoder(resp.Body).Decode(&emb); err != nil {
+ embResp := &models.LCPEmbedResp{}
+ if err := json.NewDecoder(resp.Body).Decode(&embResp); err != nil {
a.logger.Error("failed to decode embedding response", "err", err.Error())
return nil, err
}
-
- if len(emb) == 0 {
+ if len(embResp.Data) == 0 || len(embResp.Data[0].Embedding) == 0 {
err = errors.New("empty embedding response")
a.logger.Error("empty embedding response")
return nil, err
}
-
- return emb, nil
+ return embResp.Data[0].Embedding, nil
}
-func (a *APIEmbedder) EmbedSingle(text string) ([]float32, error) {
- result, err := a.Embed([]string{text})
+func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) {
+ payload, err := json.Marshal(
+ map[string]any{"input": lines, "encoding_format": "float"},
+ )
if err != nil {
+ a.logger.Error("failed to marshal payload", "err", err.Error())
return nil, err
}
- if len(result) == 0 {
- return nil, errors.New("no embeddings returned")
+ req, err := http.NewRequest("POST", a.cfg.EmbedURL, bytes.NewReader(payload))
+ if err != nil {
+ a.logger.Error("failed to create new req", "err", err.Error())
+ return nil, err
+ }
+ if a.cfg.HFToken != "" {
+ req.Header.Add("Authorization", "Bearer "+a.cfg.HFToken)
}
- return result[0], nil
+ resp, err := a.client.Do(req)
+ if err != nil {
+ a.logger.Error("failed to embed text", "err", err.Error())
+ return nil, err
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != 200 {
+ err = fmt.Errorf("non 200 response; code: %v", resp.StatusCode)
+ a.logger.Error(err.Error())
+ return nil, err
+ }
+ embResp := &models.LCPEmbedResp{}
+ if err := json.NewDecoder(resp.Body).Decode(&embResp); err != nil {
+ a.logger.Error("failed to decode embedding response", "err", err.Error())
+ return nil, err
+ }
+ if len(embResp.Data) == 0 {
+ err = errors.New("empty embedding response")
+ a.logger.Error("empty embedding response")
+ return nil, err
+ }
+
+ // Collect all embeddings from the response
+ embeddings := make([][]float32, len(embResp.Data))
+ for i := range embResp.Data {
+ if len(embResp.Data[i].Embedding) == 0 {
+ err = fmt.Errorf("empty embedding at index %d", i)
+ a.logger.Error("empty embedding", "index", i)
+ return nil, err
+ }
+ embeddings[i] = embResp.Data[i].Embedding
+ }
+
+ // Sort embeddings by index to match the order of input lines
+ // API responses may not be in order
+ for _, data := range embResp.Data {
+ if data.Index >= len(embeddings) || data.Index < 0 {
+ err = fmt.Errorf("invalid embedding index %d", data.Index)
+ a.logger.Error("invalid embedding index", "index", data.Index)
+ return nil, err
+ }
+ embeddings[data.Index] = data.Embedding
+ }
+
+ return embeddings, nil
}
// TODO: ONNXEmbedder implementation would go here
@@ -97,4 +143,3 @@ func (a *APIEmbedder) EmbedSingle(text string) ([]float32, error) {
//
// For now, we'll focus on the API implementation which is already working in the current system,
// and can be extended later when we have ONNX runtime integration
-
diff --git a/rag/rag.go b/rag/rag.go
index 018cd9a..c560f33 100644
--- a/rag/rag.go
+++ b/rag/rag.go
@@ -148,10 +148,12 @@ func (r *RAG) LoadRAG(fpath string) error {
for w := 0; w < int(r.cfg.RAGWorkers); w++ {
go r.batchToVectorAsync(lock, w, batchCh, vectorCh, errCh, doneCh, path.Base(fpath))
}
-
// Wait for embedding to be done
<-doneCh
-
+ err = <-errCh
+ if err != nil {
+ return err
+ }
// Write vectors to storage
return r.writeVectors(vectorCh)
}
@@ -178,9 +180,11 @@ func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error {
func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[int][]string,
vectorCh chan<- []models.VectorRow, errCh chan error, doneCh chan bool, filename string) {
+ var err error
defer func() {
if len(doneCh) == 0 {
doneCh <- true
+ errCh <- err
}
}()
@@ -201,7 +205,7 @@ func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[in
}
}
lock.Unlock()
- case err := <-errCh:
+ case err = <-errCh:
r.logger.Error("got an error from error channel", "error", err)
lock.Unlock()
return
@@ -215,7 +219,23 @@ func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[in
}
func (r *RAG) fetchEmb(lines []string, errCh chan error, vectorCh chan<- []models.VectorRow, slug, filename string) error {
- embeddings, err := r.embedder.Embed(lines)
+ // Filter out empty lines before sending to embedder
+ nonEmptyLines := make([]string, 0, len(lines))
+ for _, line := range lines {
+ trimmed := strings.TrimSpace(line)
+ if trimmed != "" {
+ nonEmptyLines = append(nonEmptyLines, trimmed)
+ }
+ }
+
+ // Skip if no non-empty lines
+ if len(nonEmptyLines) == 0 {
+ // Send empty result but don't error
+ vectorCh <- []models.VectorRow{}
+ return nil
+ }
+
+ embeddings, err := r.embedder.EmbedSlice(nonEmptyLines)
if err != nil {
r.logger.Error("failed to embed lines", "err", err.Error())
errCh <- err
@@ -229,15 +249,22 @@ func (r *RAG) fetchEmb(lines []string, errCh chan error, vectorCh chan<- []model
return err
}
- vectors := make([]models.VectorRow, len(embeddings))
- for i, emb := range embeddings {
- vector := models.VectorRow{
- Embeddings: emb,
- RawText: lines[i],
+ if len(embeddings) != len(nonEmptyLines) {
+ err := errors.New("mismatch between number of lines and embeddings returned")
+ r.logger.Error("embedding mismatch", "err", err.Error())
+ errCh <- err
+ return err
+ }
+
+ // Create a VectorRow for each line in the batch
+ vectors := make([]models.VectorRow, len(nonEmptyLines))
+ for i, line := range nonEmptyLines {
+ vectors[i] = models.VectorRow{
+ Embeddings: embeddings[i],
+ RawText: line,
Slug: fmt.Sprintf("%s_%d", slug, i),
FileName: filename,
}
- vectors[i] = vector
}
vectorCh <- vectors
@@ -245,7 +272,7 @@ func (r *RAG) fetchEmb(lines []string, errCh chan error, vectorCh chan<- []model
}
func (r *RAG) LineToVector(line string) ([]float32, error) {
- return r.embedder.EmbedSingle(line)
+ return r.embedder.Embed(line)
}
func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) {
@@ -259,4 +286,3 @@ func (r *RAG) ListLoaded() ([]string, error) {
func (r *RAG) RemoveFile(filename string) error {
return r.storage.RemoveEmbByFileName(filename)
}
-