summary | refs | log | tree | commit | diff
path: root/rag/rag.go
diff options
context:
space:
mode:
Diffstat (limited to 'rag/rag.go')
-rw-r--r-- rag/rag.go 50
1 files changed, 38 insertions, 12 deletions
diff --git a/rag/rag.go b/rag/rag.go
index 018cd9a..c560f33 100644
--- a/rag/rag.go
+++ b/rag/rag.go
@@ -148,10 +148,12 @@ func (r *RAG) LoadRAG(fpath string) error {
for w := 0; w < int(r.cfg.RAGWorkers); w++ {
go r.batchToVectorAsync(lock, w, batchCh, vectorCh, errCh, doneCh, path.Base(fpath))
}
-
// Wait for embedding to be done
<-doneCh
-
+ err = <-errCh
+ if err != nil {
+ return err
+ }
// Write vectors to storage
return r.writeVectors(vectorCh)
}
@@ -178,9 +180,11 @@ func (r *RAG) writeVectors(vectorCh chan []models.VectorRow) error {
func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[int][]string,
vectorCh chan<- []models.VectorRow, errCh chan error, doneCh chan bool, filename string) {
+ var err error
defer func() {
if len(doneCh) == 0 {
doneCh <- true
+ errCh <- err
}
}()
@@ -201,7 +205,7 @@ func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[in
}
}
lock.Unlock()
- case err := <-errCh:
+ case err = <-errCh:
r.logger.Error("got an error from error channel", "error", err)
lock.Unlock()
return
@@ -215,7 +219,23 @@ func (r *RAG) batchToVectorAsync(lock *sync.Mutex, id int, inputCh <-chan map[in
}
func (r *RAG) fetchEmb(lines []string, errCh chan error, vectorCh chan<- []models.VectorRow, slug, filename string) error {
- embeddings, err := r.embedder.Embed(lines)
+ // Filter out empty lines before sending to embedder
+ nonEmptyLines := make([]string, 0, len(lines))
+ for _, line := range lines {
+ trimmed := strings.TrimSpace(line)
+ if trimmed != "" {
+ nonEmptyLines = append(nonEmptyLines, trimmed)
+ }
+ }
+
+ // Skip if no non-empty lines
+ if len(nonEmptyLines) == 0 {
+ // Send empty result but don't error
+ vectorCh <- []models.VectorRow{}
+ return nil
+ }
+
+ embeddings, err := r.embedder.EmbedSlice(nonEmptyLines)
if err != nil {
r.logger.Error("failed to embed lines", "err", err.Error())
errCh <- err
@@ -229,15 +249,22 @@ func (r *RAG) fetchEmb(lines []string, errCh chan error, vectorCh chan<- []model
return err
}
- vectors := make([]models.VectorRow, len(embeddings))
- for i, emb := range embeddings {
- vector := models.VectorRow{
- Embeddings: emb,
- RawText: lines[i],
+ if len(embeddings) != len(nonEmptyLines) {
+ err := errors.New("mismatch between number of lines and embeddings returned")
+ r.logger.Error("embedding mismatch", "err", err.Error())
+ errCh <- err
+ return err
+ }
+
+ // Create a VectorRow for each line in the batch
+ vectors := make([]models.VectorRow, len(nonEmptyLines))
+ for i, line := range nonEmptyLines {
+ vectors[i] = models.VectorRow{
+ Embeddings: embeddings[i],
+ RawText: line,
Slug: fmt.Sprintf("%s_%d", slug, i),
FileName: filename,
}
- vectors[i] = vector
}
vectorCh <- vectors
@@ -245,7 +272,7 @@ func (r *RAG) fetchEmb(lines []string, errCh chan error, vectorCh chan<- []model
}
func (r *RAG) LineToVector(line string) ([]float32, error) {
- return r.embedder.EmbedSingle(line)
+ return r.embedder.Embed(line)
}
func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) {
@@ -259,4 +286,3 @@ func (r *RAG) ListLoaded() ([]string, error) {
func (r *RAG) RemoveFile(filename string) error {
return r.storage.RemoveEmbByFileName(filename)
}
-