summaryrefslogtreecommitdiff
path: root/rag/rag.go
diff options
context:
space:
mode:
authorGrail Finder <wohilas@gmail.com>2026-03-06 18:58:23 +0300
committerGrail Finder <wohilas@gmail.com>2026-03-06 18:58:23 +0300
commit17b68bc21fae99c17ec48e046e67a643b9d159bb (patch)
tree00b2da2f55876e720aecccc10dbc59232da768db /rag/rag.go
parentedfd43c52ae3f2fa16f6ab5d64cb48218a2c0a64 (diff)
Enha (rag): async writes
Diffstat (limited to 'rag/rag.go')
-rw-r--r--rag/rag.go440
1 files changed, 373 insertions, 67 deletions
diff --git a/rag/rag.go b/rag/rag.go
index 9271b60..180ad50 100644
--- a/rag/rag.go
+++ b/rag/rag.go
@@ -1,6 +1,7 @@
package rag
import (
+ "context"
"errors"
"fmt"
"gf-lt/config"
@@ -9,6 +10,7 @@ import (
"log/slog"
"path"
"regexp"
+ "runtime"
"sort"
"strings"
"sync"
@@ -17,9 +19,14 @@ import (
"github.com/neurosnap/sentences/english"
)
+const (
+ // batchTimeout is the maximum time allowed for embedding a single batch
+ batchTimeout = 2 * time.Minute
+)
+
var (
// Status messages for TUI integration
- LongJobStatusCh = make(chan string, 10) // Increased buffer size to prevent blocking
+ LongJobStatusCh = make(chan string, 100) // Increased buffer size for parallel batch updates
FinishedRAGStatus = "finished loading RAG file; press Enter"
LoadedFileRAGStatus = "loaded file"
ErrRAGStatus = "some error occurred; failed to transfer data to vector db"
@@ -31,12 +38,38 @@ type RAG struct {
cfg *config.Config
embedder Embedder
storage *VectorStorage
- mu sync.Mutex
+ mu sync.RWMutex
+ idleMu sync.Mutex
fallbackMsg string
idleTimer *time.Timer
idleTimeout time.Duration
}
+// batchTask represents a single batch to be embedded
+type batchTask struct {
+ batchIndex int
+ paragraphs []string
+ filename string
+ totalBatches int
+}
+
+// batchResult represents the result of embedding a batch
+type batchResult struct {
+ batchIndex int
+ embeddings [][]float32
+ paragraphs []string
+ filename string
+}
+
+// sendStatusNonBlocking sends a status message without blocking
+func (r *RAG) sendStatusNonBlocking(status string) {
+ select {
+ case LongJobStatusCh <- status:
+ default:
+ r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", status)
+ }
+}
+
func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) {
var embedder Embedder
var fallbackMsg string
@@ -142,18 +175,22 @@ func sanitizeFTSQuery(query string) string {
}
func (r *RAG) LoadRAG(fpath string) error {
+ return r.LoadRAGWithContext(context.Background(), fpath)
+}
+
+func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error {
r.mu.Lock()
defer r.mu.Unlock()
+
fileText, err := ExtractText(fpath)
if err != nil {
return err
}
r.logger.Debug("rag: loaded file", "fp", fpath)
- select {
- case LongJobStatusCh <- LoadedFileRAGStatus:
- default:
- r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", LoadedFileRAGStatus)
- }
+
+ // Send initial status (non-blocking with retry)
+ r.sendStatusNonBlocking(LoadedFileRAGStatus)
+
tokenizer, err := english.NewSentenceTokenizer(nil)
if err != nil {
return err
@@ -163,6 +200,7 @@ func (r *RAG) LoadRAG(fpath string) error {
for i, s := range sentences {
sents[i] = s.Text
}
+
// Create chunks with overlap
paragraphs := createChunks(sents, r.cfg.RAGWordLimit, r.cfg.RAGOverlapWords)
// Adjust batch size if needed
@@ -172,76 +210,332 @@ func (r *RAG) LoadRAG(fpath string) error {
if len(paragraphs) == 0 {
return errors.New("no valid paragraphs found in file")
}
- // Process paragraphs in batches synchronously
- batchCount := 0
- for i := 0; i < len(paragraphs); i += r.cfg.RAGBatchSize {
- end := i + r.cfg.RAGBatchSize
- if end > len(paragraphs) {
- end = len(paragraphs)
- }
- batch := paragraphs[i:end]
- batchCount++
- // Filter empty paragraphs
- nonEmptyBatch := make([]string, 0, len(batch))
- for _, p := range batch {
- if strings.TrimSpace(p) != "" {
- nonEmptyBatch = append(nonEmptyBatch, strings.TrimSpace(p))
+
+ totalBatches := (len(paragraphs) + r.cfg.RAGBatchSize - 1) / r.cfg.RAGBatchSize
+ r.logger.Debug("starting parallel embedding", "total_batches", totalBatches, "batch_size", r.cfg.RAGBatchSize)
+
+ // Determine concurrency level
+ concurrency := runtime.NumCPU()
+ if concurrency > totalBatches {
+ concurrency = totalBatches
+ }
+ if concurrency < 1 {
+ concurrency = 1
+ }
+ // If using ONNX embedder, limit concurrency to 1 due to mutex serialization
+ isONNX := false
+ if _, isONNX = r.embedder.(*ONNXEmbedder); isONNX {
+ concurrency = 1
+ }
+ embedderType := "API"
+ if isONNX {
+ embedderType = "ONNX"
+ }
+ r.logger.Debug("parallel embedding setup",
+ "total_batches", totalBatches,
+ "concurrency", concurrency,
+ "embedder", embedderType,
+ "batch_size", r.cfg.RAGBatchSize)
+
+ // Create context with timeout (30 minutes) and cancellation for error handling
+ ctx, cancel := context.WithTimeout(ctx, 30*time.Minute)
+ defer cancel()
+
+ // Channels for task distribution and results
+ taskCh := make(chan batchTask, totalBatches)
+ resultCh := make(chan batchResult, totalBatches)
+ errorCh := make(chan error, totalBatches)
+
+ // Start worker goroutines
+ var wg sync.WaitGroup
+ for w := 0; w < concurrency; w++ {
+ wg.Add(1)
+ go r.embeddingWorker(ctx, w, taskCh, resultCh, errorCh, &wg)
+ }
+
+ // Close task channel after all tasks are sent (by separate goroutine)
+ go func() {
+ // Ensure task channel is closed when this goroutine exits
+ defer close(taskCh)
+ r.logger.Debug("task distributor started", "total_batches", totalBatches)
+
+ for i := 0; i < totalBatches; i++ {
+ start := i * r.cfg.RAGBatchSize
+ end := start + r.cfg.RAGBatchSize
+ if end > len(paragraphs) {
+ end = len(paragraphs)
}
- }
- if len(nonEmptyBatch) == 0 {
- continue
- }
- // Embed the batch
- embeddings, err := r.embedder.EmbedSlice(nonEmptyBatch)
- if err != nil {
- r.logger.Error("failed to embed batch", "error", err, "batch", batchCount)
+ batch := paragraphs[start:end]
+
+ // Filter empty paragraphs
+ nonEmptyBatch := make([]string, 0, len(batch))
+ for _, p := range batch {
+ if strings.TrimSpace(p) != "" {
+ nonEmptyBatch = append(nonEmptyBatch, strings.TrimSpace(p))
+ }
+ }
+
+ task := batchTask{
+ batchIndex: i,
+ paragraphs: nonEmptyBatch,
+ filename: path.Base(fpath),
+ totalBatches: totalBatches,
+ }
+
select {
- case LongJobStatusCh <- ErrRAGStatus:
- default:
- r.logger.Warn("LongJobStatusCh channel full, dropping message")
+ case taskCh <- task:
+ r.logger.Debug("task distributor sent batch", "batch", i, "paragraphs", len(nonEmptyBatch))
+ case <-ctx.Done():
+ r.logger.Debug("task distributor cancelled", "batches_sent", i+1, "total_batches", totalBatches)
+ return
}
- return fmt.Errorf("failed to embed batch %d: %w", batchCount, err)
- }
- if len(embeddings) != len(nonEmptyBatch) {
- err := errors.New("embedding count mismatch")
- r.logger.Error("embedding mismatch", "expected", len(nonEmptyBatch), "got", len(embeddings))
- return err
- }
- // Write vectors to storage
- filename := path.Base(fpath)
- for j, text := range nonEmptyBatch {
- vector := models.VectorRow{
- Embeddings: embeddings[j],
- RawText: text,
- Slug: fmt.Sprintf("%s_%d_%d", filename, batchCount, j),
- FileName: filename,
+ }
+ r.logger.Debug("task distributor finished", "batches_sent", totalBatches)
+ }()
+
+ // Wait for workers to finish and close result channel
+ go func() {
+ wg.Wait()
+ close(resultCh)
+ }()
+
+ // Process results in order and write to database
+ nextExpectedBatch := 0
+ resultsBuffer := make(map[int]batchResult)
+ filename := path.Base(fpath)
+ batchesProcessed := 0
+
+ for {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+
+ case err := <-errorCh:
+ // First error from any worker, cancel everything
+ cancel()
+ r.logger.Error("embedding worker failed", "error", err)
+ r.sendStatusNonBlocking(ErrRAGStatus)
+ return fmt.Errorf("embedding failed: %w", err)
+
+ case result, ok := <-resultCh:
+ if !ok {
+ // All results processed
+ resultCh = nil
+ r.logger.Debug("result channel closed", "batches_processed", batchesProcessed, "total_batches", totalBatches)
+ continue
+ }
+
+ // Store result in buffer
+ resultsBuffer[result.batchIndex] = result
+
+ // Process buffered results in order
+ for {
+ if res, exists := resultsBuffer[nextExpectedBatch]; exists {
+ // Write this batch to database
+ if err := r.writeBatchToStorage(ctx, res, filename); err != nil {
+ cancel()
+ return err
+ }
+
+ batchesProcessed++
+ // Send progress update
+ statusMsg := fmt.Sprintf("processed batch %d/%d", batchesProcessed, totalBatches)
+ r.sendStatusNonBlocking(statusMsg)
+
+ delete(resultsBuffer, nextExpectedBatch)
+ nextExpectedBatch++
+ } else {
+ break
+ }
}
- if err := r.storage.WriteVector(&vector); err != nil {
- r.logger.Error("failed to write vector to DB", "error", err, "slug", vector.Slug)
+
+ default:
+ // No channels ready, check for deadlock conditions
+ if resultCh == nil && nextExpectedBatch < totalBatches {
+ // Missing batch results after result channel closed
+ r.logger.Error("missing batch results",
+ "expected", totalBatches,
+ "received", nextExpectedBatch,
+ "missing", totalBatches-nextExpectedBatch)
+
+ // Wait a short time for any delayed errors, then cancel
select {
- case LongJobStatusCh <- ErrRAGStatus:
- default:
- r.logger.Warn("LongJobStatusCh channel full, dropping message")
+ case <-time.After(5 * time.Second):
+ cancel()
+ return fmt.Errorf("missing batch results: expected %d, got %d", totalBatches, nextExpectedBatch)
+ case <-ctx.Done():
+ return ctx.Err()
+ case err := <-errorCh:
+ cancel()
+ r.logger.Error("embedding worker failed after result channel closed", "error", err)
+ r.sendStatusNonBlocking(ErrRAGStatus)
+ return fmt.Errorf("embedding failed: %w", err)
}
- return fmt.Errorf("failed to write vector: %w", err)
+ }
+ // If we reach here, no deadlock yet, just busy loop prevention
+ time.Sleep(100 * time.Millisecond)
+ }
+
+ // Check if we're done
+ if resultCh == nil && nextExpectedBatch >= totalBatches {
+ r.logger.Debug("all batches processed successfully", "total", totalBatches)
+ break
+ }
+ }
+
+ r.logger.Debug("finished writing vectors", "batches", batchesProcessed)
+ r.resetIdleTimer()
+ r.sendStatusNonBlocking(FinishedRAGStatus)
+ return nil
+}
+
+// embeddingWorker processes batch embedding tasks
+func (r *RAG) embeddingWorker(ctx context.Context, workerID int, taskCh <-chan batchTask, resultCh chan<- batchResult, errorCh chan<- error, wg *sync.WaitGroup) {
+ defer wg.Done()
+ r.logger.Debug("embedding worker started", "worker", workerID)
+
+ // Panic recovery to ensure worker doesn't crash silently
+ defer func() {
+ if rec := recover(); rec != nil {
+ r.logger.Error("embedding worker panicked", "worker", workerID, "panic", rec)
+ // Try to send error, but don't block if channel is full
+ select {
+ case errorCh <- fmt.Errorf("worker %d panicked: %v", workerID, rec):
+ default:
+ r.logger.Warn("error channel full, dropping panic error", "worker", workerID)
}
}
- r.logger.Debug("wrote batch to db", "batch", batchCount, "size", len(nonEmptyBatch))
- // Send progress status
- statusMsg := fmt.Sprintf("processed batch %d/%d", batchCount, (len(paragraphs)+r.cfg.RAGBatchSize-1)/r.cfg.RAGBatchSize)
+ }()
+
+ for task := range taskCh {
select {
- case LongJobStatusCh <- statusMsg:
+ case <-ctx.Done():
+ r.logger.Debug("embedding worker cancelled", "worker", workerID)
+ return
default:
- r.logger.Warn("LongJobStatusCh channel full, dropping message")
}
+ r.logger.Debug("worker processing batch", "worker", workerID, "batch", task.batchIndex, "paragraphs", len(task.paragraphs), "total_batches", task.totalBatches)
+
+ // Skip empty batches
+ if len(task.paragraphs) == 0 {
+ select {
+ case resultCh <- batchResult{
+ batchIndex: task.batchIndex,
+ embeddings: nil,
+ paragraphs: nil,
+ filename: task.filename,
+ }:
+ case <-ctx.Done():
+ r.logger.Debug("embedding worker cancelled while sending empty batch", "worker", workerID)
+ return
+ }
+ r.logger.Debug("worker sent empty batch", "worker", workerID, "batch", task.batchIndex)
+ continue
+ }
+
+ // Embed with retry for API embedder
+ embeddings, err := r.embedWithRetry(ctx, task.paragraphs, 3)
+ if err != nil {
+ // Try to send error, but don't block indefinitely
+ select {
+ case errorCh <- fmt.Errorf("worker %d batch %d: %w", workerID, task.batchIndex, err):
+ case <-ctx.Done():
+ r.logger.Debug("embedding worker cancelled while sending error", "worker", workerID)
+ }
+ return
+ }
+
+ // Send result with context awareness
+ select {
+ case resultCh <- batchResult{
+ batchIndex: task.batchIndex,
+ embeddings: embeddings,
+ paragraphs: task.paragraphs,
+ filename: task.filename,
+ }:
+ case <-ctx.Done():
+ r.logger.Debug("embedding worker cancelled while sending result", "worker", workerID)
+ return
+ }
+ r.logger.Debug("worker completed batch", "worker", workerID, "batch", task.batchIndex, "embeddings", len(embeddings))
}
- r.logger.Debug("finished writing vectors", "batches", batchCount)
- r.resetIdleTimer()
+ r.logger.Debug("embedding worker finished", "worker", workerID)
+}
+
+// embedWithRetry attempts embedding with exponential backoff for API embedder
+func (r *RAG) embedWithRetry(ctx context.Context, paragraphs []string, maxRetries int) ([][]float32, error) {
+ var lastErr error
+
+ for attempt := 0; attempt < maxRetries; attempt++ {
+ if attempt > 0 {
+ // Exponential backoff
+ backoff := time.Duration(attempt*attempt) * time.Second
+ if backoff > 10*time.Second {
+ backoff = 10 * time.Second
+ }
+
+ select {
+ case <-time.After(backoff):
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ }
+
+ r.logger.Debug("retrying embedding", "attempt", attempt, "max_retries", maxRetries)
+ }
+
+ embeddings, err := r.embedder.EmbedSlice(paragraphs)
+ if err == nil {
+ // Validate embedding count
+ if len(embeddings) != len(paragraphs) {
+ return nil, fmt.Errorf("embedding count mismatch: expected %d, got %d", len(paragraphs), len(embeddings))
+ }
+ return embeddings, nil
+ }
+
+ lastErr = err
+ // Only retry for API embedder errors (network/timeout)
+ // For ONNX embedder, fail fast
+ if _, isAPI := r.embedder.(*APIEmbedder); !isAPI {
+ break
+ }
+ }
+
+ return nil, fmt.Errorf("embedding failed after %d attempts: %w", maxRetries, lastErr)
+}
+
+// writeBatchToStorage writes a single batch of vectors to the database
+func (r *RAG) writeBatchToStorage(ctx context.Context, result batchResult, filename string) error {
+ if len(result.embeddings) == 0 {
+ // Empty batch, skip
+ return nil
+ }
+
+ // Check context before starting
select {
- case LongJobStatusCh <- FinishedRAGStatus:
+ case <-ctx.Done():
+ return ctx.Err()
default:
- r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", FinishedRAGStatus)
}
+
+ // Build all vectors for batch write
+ vectors := make([]*models.VectorRow, 0, len(result.paragraphs))
+ for j, text := range result.paragraphs {
+ vectors = append(vectors, &models.VectorRow{
+ Embeddings: result.embeddings[j],
+ RawText: text,
+ Slug: fmt.Sprintf("%s_%d_%d", filename, result.batchIndex+1, j),
+ FileName: filename,
+ })
+ }
+
+ // Write all vectors in a single transaction
+ if err := r.storage.WriteVectors(vectors); err != nil {
+ r.logger.Error("failed to write vectors batch to DB", "error", err, "batch", result.batchIndex+1, "size", len(vectors))
+ r.sendStatusNonBlocking(ErrRAGStatus)
+ return fmt.Errorf("failed to write vectors batch: %w", err)
+ }
+
+ r.logger.Debug("wrote batch to db", "batch", result.batchIndex+1, "size", len(result.paragraphs))
return nil
}
@@ -250,22 +544,26 @@ func (r *RAG) LineToVector(line string) ([]float32, error) {
return r.embedder.Embed(line)
}
-func (r *RAG) SearchEmb(emb *models.EmbeddingResp, limit int) ([]models.VectorRow, error) {
+func (r *RAG) searchEmb(emb *models.EmbeddingResp, limit int) ([]models.VectorRow, error) {
r.resetIdleTimer()
return r.storage.SearchClosest(emb.Embedding, limit)
}
-func (r *RAG) SearchKeyword(query string, limit int) ([]models.VectorRow, error) {
+func (r *RAG) searchKeyword(query string, limit int) ([]models.VectorRow, error) {
r.resetIdleTimer()
sanitized := sanitizeFTSQuery(query)
return r.storage.SearchKeyword(sanitized, limit)
}
func (r *RAG) ListLoaded() ([]string, error) {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
return r.storage.ListFiles()
}
func (r *RAG) RemoveFile(filename string) error {
+ r.mu.Lock()
+ defer r.mu.Unlock()
r.resetIdleTimer()
return r.storage.RemoveEmbByFileName(filename)
}
@@ -454,6 +752,9 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V
}
func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string, error) {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ r.resetIdleTimer()
if len(results) == 0 {
return "No relevant information found in the vector database.", nil
}
@@ -482,7 +783,7 @@ func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string
Embedding: emb,
Index: 0,
}
- topResults, err := r.SearchEmb(embResp, 1)
+ topResults, err := r.searchEmb(embResp, 1)
if err != nil {
r.logger.Error("failed to search for synthesis context", "error", err)
return "", err
@@ -509,6 +810,9 @@ func truncateString(s string, maxLen int) string {
}
func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ r.resetIdleTimer()
refined := r.RefineQuery(query)
variations := r.GenerateQueryVariations(refined)
@@ -525,7 +829,7 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
Embedding: emb,
Index: 0,
}
- results, err := r.SearchEmb(embResp, limit*2) // Get more candidates
+ results, err := r.searchEmb(embResp, limit*2) // Get more candidates
if err != nil {
r.logger.Error("failed to search embeddings", "error", err, "query", q)
continue
@@ -543,7 +847,7 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
})
// Perform keyword search
- kwResults, err := r.SearchKeyword(refined, limit*2)
+ kwResults, err := r.searchKeyword(refined, limit*2)
if err != nil {
r.logger.Warn("keyword search failed, using only embeddings", "error", err)
kwResults = nil
@@ -621,6 +925,8 @@ func GetInstance() *RAG {
}
func (r *RAG) resetIdleTimer() {
+ r.idleMu.Lock()
+ defer r.idleMu.Unlock()
if r.idleTimer != nil {
r.idleTimer.Stop()
}