summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--rag/rag.go250
-rw-r--r--tables.go3
2 files changed, 57 insertions, 196 deletions
diff --git a/rag/rag.go b/rag/rag.go
index 71c4ce8..f554924 100644
--- a/rag/rag.go
+++ b/rag/rag.go
@@ -29,6 +29,7 @@ type RAG struct {
cfg *config.Config
embedder Embedder
storage *VectorStorage
+ mu sync.Mutex
}
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG {
@@ -53,6 +54,8 @@ func wordCounter(sentence string) int {
}
func (r *RAG) LoadRAG(fpath string) error {
+ r.mu.Lock()
+ defer r.mu.Unlock()
data, err := os.ReadFile(fpath)
if err != nil {
return err
@@ -62,9 +65,7 @@ func (r *RAG) LoadRAG(fpath string) error {
case LongJobStatusCh <- LoadedFileRAGStatus:
default:
r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", LoadedFileRAGStatus)
- // Channel is full or closed, ignore the message to prevent panic
}
-
fileText := string(data)
tokenizer, err := english.NewSentenceTokenizer(nil)
if err != nil {
@@ -75,19 +76,16 @@ func (r *RAG) LoadRAG(fpath string) error {
for i, s := range sentences {
sents[i] = s.Text
}
-
// Group sentences into paragraphs based on word limit
paragraphs := []string{}
par := strings.Builder{}
for i := 0; i < len(sents); i++ {
- // Only add sentences that aren't empty
if strings.TrimSpace(sents[i]) != "" {
if par.Len() > 0 {
- par.WriteString(" ") // Add space between sentences
+ par.WriteString(" ")
}
par.WriteString(sents[i])
}
-
if wordCounter(par.String()) > int(r.cfg.RAGWordLimit) {
paragraph := strings.TrimSpace(par.String())
if paragraph != "" {
@@ -96,7 +94,6 @@ func (r *RAG) LoadRAG(fpath string) error {
par.Reset()
}
}
-
// Handle any remaining content in the paragraph buffer
if par.Len() > 0 {
paragraph := strings.TrimSpace(par.String())
@@ -104,217 +101,82 @@ func (r *RAG) LoadRAG(fpath string) error {
paragraphs = append(paragraphs, paragraph)
}
}
-
// Adjust batch size if needed
if len(paragraphs) < r.cfg.RAGBatchSize && len(paragraphs) > 0 {
r.cfg.RAGBatchSize = len(paragraphs)
}
-
if len(paragraphs) == 0 {
return errors.New("no valid paragraphs found in file")
}
-
- var (
- maxChSize = 100
- left = 0
- right = r.cfg.RAGBatchSize
- batchCh = make(chan map[int][]string, maxChSize)
- vectorCh = make(chan []models.VectorRow, maxChSize)
- errCh = make(chan error, 1)
- doneCh = make(chan struct{})
- wg = new(sync.WaitGroup)
- )
-
- defer close(doneCh)
- defer close(errCh)
- defer close(batchCh)
-
- // Fill input channel with batches
- ctn := 0
- totalParagraphs := len(paragraphs)
- for {
- if right > totalParagraphs {
- batchCh <- map[int][]string{left: paragraphs[left:]}
- break
+ // Process paragraphs in batches synchronously
+ batchCount := 0
+ for i := 0; i < len(paragraphs); i += r.cfg.RAGBatchSize {
+ end := i + r.cfg.RAGBatchSize
+ if end > len(paragraphs) {
+ end = len(paragraphs)
}
- batchCh <- map[int][]string{left: paragraphs[left:right]}
- left, right = right, right+r.cfg.RAGBatchSize
- ctn++
- }
-
- finishedBatchesMsg := fmt.Sprintf("finished batching batches#: %d; paragraphs: %d; sentences: %d\n", ctn+1, len(paragraphs), len(sents))
- r.logger.Debug(finishedBatchesMsg)
- select {
- case LongJobStatusCh <- finishedBatchesMsg:
- default:
- r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", finishedBatchesMsg)
- // Channel is full or closed, ignore the message to prevent panic
- }
-
- // Start worker goroutines with WaitGroup
- wg.Add(int(r.cfg.RAGWorkers))
- for w := 0; w < int(r.cfg.RAGWorkers); w++ {
- go func(workerID int) {
- defer wg.Done()
- r.batchToVectorAsync(workerID, batchCh, vectorCh, errCh, doneCh, path.Base(fpath))
- }(w)
- }
-
- // Close batchCh to signal workers no more data is coming
- close(batchCh)
-
- // Wait for all workers to finish, then close vectorCh
- go func() {
- wg.Wait()
- close(vectorCh)
- }()
-
- // Check for errors from workers - this will block until an error occurs or all workers finish
- select {
- case err := <-errCh:
- if err != nil {
- r.logger.Error("error during RAG processing", "error", err)
- return err
- }
- default:
- // No immediate error, continue
- }
-
- // Write vectors to storage - this will block until vectorCh is closed
- return r.writeVectors(vectorCh, errCh)
-}
-
-func (r *RAG) writeVectors(vectorCh chan []models.VectorRow, errCh chan error) error {
- // Use a select to handle both vectorCh and errCh
- for {
- select {
- case err := <-errCh:
- if err != nil {
- r.logger.Error("error during RAG processing in writeVectors", "error", err)
- return err
- }
- case batch, ok := <-vectorCh:
- if !ok {
- r.logger.Debug("vector channel closed, finished writing vectors")
- select {
- case LongJobStatusCh <- FinishedRAGStatus:
- default:
- r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", FinishedRAGStatus)
- }
- return nil
- }
- for _, vector := range batch {
- if err := r.storage.WriteVector(&vector); err != nil {
- r.logger.Error("failed to write vector to DB", "error", err, "slug", vector.Slug)
- select {
- case LongJobStatusCh <- ErrRAGStatus:
- default:
- r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", ErrRAGStatus)
- }
- return err
- }
+ batch := paragraphs[i:end]
+ batchCount++
+ // Filter empty paragraphs
+ nonEmptyBatch := make([]string, 0, len(batch))
+ for _, p := range batch {
+ if strings.TrimSpace(p) != "" {
+ nonEmptyBatch = append(nonEmptyBatch, strings.TrimSpace(p))
}
- r.logger.Debug("wrote batch to db", "size", len(batch))
}
- }
-}
-
-func (r *RAG) batchToVectorAsync(id int, inputCh <-chan map[int][]string,
- vectorCh chan<- []models.VectorRow, errCh chan error, doneCh <-chan struct{}, filename string) {
- var err error
-
- defer func() {
+ if len(nonEmptyBatch) == 0 {
+ continue
+ }
+ // Embed the batch
+ embeddings, err := r.embedder.EmbedSlice(nonEmptyBatch)
if err != nil {
+ r.logger.Error("failed to embed batch", "error", err, "batch", batchCount)
select {
- case errCh <- err:
+ case LongJobStatusCh <- ErrRAGStatus:
default:
- r.logger.Warn("errCh channel full or closed, skipping error propagation", "worker", id, "error", err)
+ r.logger.Warn("LongJobStatusCh channel full, dropping message")
}
+ return fmt.Errorf("failed to embed batch %d: %w", batchCount, err)
}
- }()
-
- for {
- select {
- case <-doneCh:
- r.logger.Debug("worker received done signal", "worker", id)
- return
- case linesMap, ok := <-inputCh:
- if !ok {
- r.logger.Debug("input channel closed, worker exiting", "worker", id)
- return
+ if len(embeddings) != len(nonEmptyBatch) {
+ err := errors.New("embedding count mismatch")
+ r.logger.Error("embedding mismatch", "expected", len(nonEmptyBatch), "got", len(embeddings))
+ return err
+ }
+ // Write vectors to storage
+ filename := path.Base(fpath)
+ for j, text := range nonEmptyBatch {
+ vector := models.VectorRow{
+ Embeddings: embeddings[j],
+ RawText: text,
+ Slug: fmt.Sprintf("%s_%d_%d", filename, batchCount, j),
+ FileName: filename,
}
- for leftI, lines := range linesMap {
+ if err := r.storage.WriteVector(&vector); err != nil {
+ r.logger.Error("failed to write vector to DB", "error", err, "slug", vector.Slug)
select {
- case <-doneCh:
- return
+ case LongJobStatusCh <- ErrRAGStatus:
default:
+ r.logger.Warn("LongJobStatusCh channel full, dropping message")
}
- if err := r.fetchEmb(lines, errCh, vectorCh, fmt.Sprintf("%s_%d", filename, leftI), filename); err != nil {
- r.logger.Error("error fetching embeddings", "error", err, "worker", id)
- return
- }
- }
- r.logger.Debug("processed batch", "worker#", id)
- statusMsg := fmt.Sprintf("converted to vector; worker#: %d", id)
- select {
- case LongJobStatusCh <- statusMsg:
- default:
- r.logger.Warn("LongJobStatusCh channel full or closed, dropping status message", "message", statusMsg)
+ return fmt.Errorf("failed to write vector: %w", err)
}
}
- }
-}
-
-func (r *RAG) fetchEmb(lines []string, errCh chan error, vectorCh chan<- []models.VectorRow, slug, filename string) error {
- // Filter out empty lines before sending to embedder
- nonEmptyLines := make([]string, 0, len(lines))
- for _, line := range lines {
- trimmed := strings.TrimSpace(line)
- if trimmed != "" {
- nonEmptyLines = append(nonEmptyLines, trimmed)
+ r.logger.Debug("wrote batch to db", "batch", batchCount, "size", len(nonEmptyBatch))
+ // Send progress status
+ statusMsg := fmt.Sprintf("processed batch %d/%d", batchCount, (len(paragraphs)+r.cfg.RAGBatchSize-1)/r.cfg.RAGBatchSize)
+ select {
+ case LongJobStatusCh <- statusMsg:
+ default:
+ r.logger.Warn("LongJobStatusCh channel full, dropping message")
}
}
-
- // Skip if no non-empty lines
- if len(nonEmptyLines) == 0 {
- // Send empty result but don't error
- vectorCh <- []models.VectorRow{}
- return nil
- }
-
- embeddings, err := r.embedder.EmbedSlice(nonEmptyLines)
- if err != nil {
- r.logger.Error("failed to embed lines", "err", err.Error())
- errCh <- err
- return err
- }
-
- if len(embeddings) == 0 {
- err := errors.New("no embeddings returned")
- r.logger.Error("empty embeddings")
- errCh <- err
- return err
- }
-
- if len(embeddings) != len(nonEmptyLines) {
- err := errors.New("mismatch between number of lines and embeddings returned")
- r.logger.Error("embedding mismatch", "err", err.Error())
- errCh <- err
- return err
- }
-
- // Create a VectorRow for each line in the batch
- vectors := make([]models.VectorRow, len(nonEmptyLines))
- for i, line := range nonEmptyLines {
- vectors[i] = models.VectorRow{
- Embeddings: embeddings[i],
- RawText: line,
- Slug: fmt.Sprintf("%s_%d", slug, i),
- FileName: filename,
- }
+ r.logger.Debug("finished writing vectors", "batches", batchCount)
+ select {
+ case LongJobStatusCh <- FinishedRAGStatus:
+ default:
+ r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", FinishedRAGStatus)
}
-
- vectorCh <- vectors
return nil
}
diff --git a/tables.go b/tables.go
index b1ec128..23000f4 100644
--- a/tables.go
+++ b/tables.go
@@ -387,8 +387,7 @@ func makeRAGTable(fileList []string) *tview.Flex {
if err := ragger.LoadRAG(fpath); err != nil {
logger.Error("failed to embed file", "chat", fpath, "error", err)
_ = notifyUser("RAG", "failed to embed file; error: "+err.Error())
- errCh <- err
- // pages.RemovePage(RAGPage)
+ pages.RemovePage(RAGPage)
return
}
}()