summaryrefslogtreecommitdiff
path: root/rag/rag.go
diff options
context:
space:
mode:
authorGrail Finder <wohilas@gmail.com>2026-03-09 07:07:36 +0300
committerGrail Finder <wohilas@gmail.com>2026-03-09 07:07:36 +0300
commit0e42a6f069ceea40485162c014c04cf718568cfe (patch)
tree583a6a6cb91b315e506990a03fdda1b32d0fe985 /rag/rag.go
parent2687f38d00ceaa4f61034e3e02b9b59d08efc017 (diff)
parenta1b5f9cdc59938901123650fc0900067ac3447ca (diff)
Merge branch 'master' into feat/agent-flow
Diffstat (limited to 'rag/rag.go')
-rw-r--r--rag/rag.go1001
1 files changed, 873 insertions, 128 deletions
diff --git a/rag/rag.go b/rag/rag.go
index b63cb08..3a771d4 100644
--- a/rag/rag.go
+++ b/rag/rag.go
@@ -1,6 +1,7 @@
package rag
import (
+ "context"
"errors"
"fmt"
"gf-lt/config"
@@ -9,51 +10,281 @@ import (
"log/slog"
"path"
"regexp"
+ "runtime"
"sort"
+ "strconv"
"strings"
"sync"
+ "time"
"github.com/neurosnap/sentences/english"
)
+const ()
+
var (
// Status messages for TUI integration
- LongJobStatusCh = make(chan string, 10) // Increased buffer size to prevent blocking
- FinishedRAGStatus = "finished loading RAG file; press Enter"
+ LongJobStatusCh = make(chan string, 100) // Increased buffer size for parallel batch updates
+ FinishedRAGStatus = "finished loading RAG file; press x to exit"
LoadedFileRAGStatus = "loaded file"
ErrRAGStatus = "some error occurred; failed to transfer data to vector db"
+
+ // stopWords are common words that can be removed from queries when not part of phrases
+ stopWords = []string{"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by", "from", "up", "down", "left", "right", "about", "like", "such", "than", "then", "also", "too"}
)
+// isStopWord checks if a word is in the stop words list
+func isStopWord(word string) bool {
+ for _, stop := range stopWords {
+ if strings.EqualFold(word, stop) {
+ return true
+ }
+ }
+ return false
+}
+
+// detectPhrases returns multi-word phrases from a query that should be treated as units
+func detectPhrases(query string) []string {
+ words := strings.Fields(strings.ToLower(query))
+ var phrases []string
+
+ for i := 0; i < len(words)-1; i++ {
+ word1 := strings.Trim(words[i], ".,!?;:'\"()[]{}")
+ word2 := strings.Trim(words[i+1], ".,!?;:'\"()[]{}")
+
+ // Skip if either word is a stop word or too short
+ if isStopWord(word1) || isStopWord(word2) || len(word1) < 2 || len(word2) < 2 {
+ continue
+ }
+
+ // Check if this pair appears to be a meaningful phrase
+ // Simple heuristic: consecutive non-stop words of reasonable length
+ phrase := word1 + " " + word2
+ phrases = append(phrases, phrase)
+
+ // Optionally check for 3-word phrases
+ if i < len(words)-2 {
+ word3 := strings.Trim(words[i+2], ".,!?;:'\"()[]{}")
+ if !isStopWord(word3) && len(word3) >= 2 {
+ phrases = append(phrases, word1+" "+word2+" "+word3)
+ }
+ }
+ }
+
+ return phrases
+}
+
+// countPhraseMatches returns the number of query phrases found in text
+func countPhraseMatches(text, query string) int {
+ phrases := detectPhrases(query)
+ if len(phrases) == 0 {
+ return 0
+ }
+ textLower := strings.ToLower(text)
+ count := 0
+ for _, phrase := range phrases {
+ if strings.Contains(textLower, phrase) {
+ count++
+ }
+ }
+ return count
+}
+
+// parseSlugIndices extracts batch and chunk indices from a slug
+// slug format: filename_batch_chunk (e.g., "kjv_bible.epub_1786_0")
+func parseSlugIndices(slug string) (batch, chunk int, ok bool) {
+ // Find the last two numbers separated by underscores
+ re := regexp.MustCompile(`_(\d+)_(\d+)$`)
+ matches := re.FindStringSubmatch(slug)
+ if matches == nil || len(matches) != 3 {
+ return 0, 0, false
+ }
+ batch, err1 := strconv.Atoi(matches[1])
+ chunk, err2 := strconv.Atoi(matches[2])
+ if err1 != nil || err2 != nil {
+ return 0, 0, false
+ }
+ return batch, chunk, true
+}
+
+// areSlugsAdjacent returns true if two slugs are from the same file and have sequential indices
+func areSlugsAdjacent(slug1, slug2 string) bool {
+ // Extract filename prefix (everything before the last underscore sequence)
+ parts1 := strings.Split(slug1, "_")
+ parts2 := strings.Split(slug2, "_")
+ if len(parts1) < 3 || len(parts2) < 3 {
+ return false
+ }
+
+ // Compare filename prefixes (all parts except last two)
+ prefix1 := strings.Join(parts1[:len(parts1)-2], "_")
+ prefix2 := strings.Join(parts2[:len(parts2)-2], "_")
+ if prefix1 != prefix2 {
+ return false
+ }
+
+ batch1, chunk1, ok1 := parseSlugIndices(slug1)
+ batch2, chunk2, ok2 := parseSlugIndices(slug2)
+ if !ok1 || !ok2 {
+ return false
+ }
+
+ // Check if they're in same batch and chunks are sequential
+ if batch1 == batch2 && (chunk1 == chunk2+1 || chunk2 == chunk1+1) {
+ return true
+ }
+
+ // Check if they're in sequential batches and chunk indices suggest continuity
+ // This is heuristic but useful for cross-batch adjacency
+ if (batch1 == batch2+1 && chunk1 == 0) || (batch2 == batch1+1 && chunk2 == 0) {
+ return true
+ }
+ return false
+}
+
type RAG struct {
- logger *slog.Logger
- store storage.FullRepo
- cfg *config.Config
- embedder Embedder
- storage *VectorStorage
- mu sync.Mutex
+ logger *slog.Logger
+ store storage.FullRepo
+ cfg *config.Config
+ embedder Embedder
+ storage *VectorStorage
+ mu sync.RWMutex
+ idleMu sync.Mutex
+ fallbackMsg string
+ idleTimer *time.Timer
+ idleTimeout time.Duration
+}
+
+// batchTask represents a single batch to be embedded
+type batchTask struct {
+ batchIndex int
+ paragraphs []string
+ filename string
+ totalBatches int
+}
+
+// batchResult represents the result of embedding a batch
+type batchResult struct {
+ batchIndex int
+ embeddings [][]float32
+ paragraphs []string
+ filename string
+}
+
+// sendStatusNonBlocking sends a status message without blocking
+func (r *RAG) sendStatusNonBlocking(status string) {
+ select {
+ case LongJobStatusCh <- status:
+ default:
+ r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", status)
+ }
}
-func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG {
- // Initialize with API embedder by default, could be configurable later
- embedder := NewAPIEmbedder(l, cfg)
+func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) {
+ var embedder Embedder
+ var fallbackMsg string
+ if cfg.EmbedModelPath != "" && cfg.EmbedTokenizerPath != "" {
+ emb, err := NewONNXEmbedder(cfg.EmbedModelPath, cfg.EmbedTokenizerPath, cfg.EmbedDims, l)
+ if err != nil {
+ l.Error("failed to create ONNX embedder, falling back to API", "error", err)
+ fallbackMsg = err.Error()
+ embedder = NewAPIEmbedder(l, cfg)
+ } else {
+ embedder = emb
+ l.Info("using ONNX embedder", "model", cfg.EmbedModelPath, "dims", cfg.EmbedDims)
+ }
+ } else {
+ embedder = NewAPIEmbedder(l, cfg)
+ l.Info("using API embedder", "url", cfg.EmbedURL)
+ }
rag := &RAG{
- logger: l,
- store: s,
- cfg: cfg,
- embedder: embedder,
- storage: NewVectorStorage(l, s),
+ logger: l,
+ store: s,
+ cfg: cfg,
+ embedder: embedder,
+ storage: NewVectorStorage(l, s),
+ fallbackMsg: fallbackMsg,
+ idleTimeout: 30 * time.Second,
}
// Note: Vector tables are created via database migrations, not at runtime
- return rag
+ return rag, nil
+}
+
+func createChunks(sentences []string, wordLimit, overlapWords uint32) []string {
+ if len(sentences) == 0 {
+ return nil
+ }
+ if overlapWords >= wordLimit {
+ overlapWords = wordLimit / 2
+ }
+ var chunks []string
+ i := 0
+ for i < len(sentences) {
+ var chunkWords []string
+ wordCount := 0
+ j := i
+ for j < len(sentences) && wordCount <= int(wordLimit) {
+ sentence := sentences[j]
+ words := strings.Fields(sentence)
+ chunkWords = append(chunkWords, sentence)
+ wordCount += len(words)
+ j++
+ // If this sentence alone exceeds limit, still include it and stop
+ if wordCount > int(wordLimit) {
+ break
+ }
+ }
+ if len(chunkWords) == 0 {
+ break
+ }
+ chunk := strings.Join(chunkWords, " ")
+ chunks = append(chunks, chunk)
+ if j >= len(sentences) {
+ break
+ }
+ // Move i forward by skipping overlap
+ if overlapWords == 0 {
+ i = j
+ continue
+ }
+ // Calculate how many sentences to skip to achieve overlapWords
+ overlapRemaining := int(overlapWords)
+ newI := i
+ for newI < j && overlapRemaining > 0 {
+ words := len(strings.Fields(sentences[newI]))
+ overlapRemaining -= words
+ if overlapRemaining >= 0 {
+ newI++
+ }
+ }
+ if newI == i {
+ newI = j
+ }
+ i = newI
+ }
+ return chunks
}
-func wordCounter(sentence string) int {
- return len(strings.Split(strings.TrimSpace(sentence), " "))
+func sanitizeFTSQuery(query string) string {
+ // Keep double quotes for FTS5 phrase matching
+ // Remove other problematic characters
+ query = strings.ReplaceAll(query, "'", " ")
+ query = strings.ReplaceAll(query, ";", " ")
+ query = strings.ReplaceAll(query, "\\", " ")
+ query = strings.TrimSpace(query)
+ if query == "" {
+ return "*" // match all
+ }
+ return query
}
func (r *RAG) LoadRAG(fpath string) error {
+ return r.LoadRAGWithContext(context.Background(), fpath)
+}
+
+func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error {
r.mu.Lock()
defer r.mu.Unlock()
fileText, err := ExtractText(fpath)
@@ -61,11 +292,9 @@ func (r *RAG) LoadRAG(fpath string) error {
return err
}
r.logger.Debug("rag: loaded file", "fp", fpath)
- select {
- case LongJobStatusCh <- LoadedFileRAGStatus:
- default:
- r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", LoadedFileRAGStatus)
- }
+
+ // Send initial status (non-blocking with retry)
+ r.sendStatusNonBlocking(LoadedFileRAGStatus)
tokenizer, err := english.NewSentenceTokenizer(nil)
if err != nil {
return err
@@ -75,31 +304,9 @@ func (r *RAG) LoadRAG(fpath string) error {
for i, s := range sentences {
sents[i] = s.Text
}
- // Group sentences into paragraphs based on word limit
- paragraphs := []string{}
- par := strings.Builder{}
- for i := 0; i < len(sents); i++ {
- if strings.TrimSpace(sents[i]) != "" {
- if par.Len() > 0 {
- par.WriteString(" ")
- }
- par.WriteString(sents[i])
- }
- if wordCounter(par.String()) > int(r.cfg.RAGWordLimit) {
- paragraph := strings.TrimSpace(par.String())
- if paragraph != "" {
- paragraphs = append(paragraphs, paragraph)
- }
- par.Reset()
- }
- }
- // Handle any remaining content in the paragraph buffer
- if par.Len() > 0 {
- paragraph := strings.TrimSpace(par.String())
- if paragraph != "" {
- paragraphs = append(paragraphs, paragraph)
- }
- }
+
+ // Create chunks with overlap
+ paragraphs := createChunks(sents, r.cfg.RAGWordLimit, r.cfg.RAGOverlapWords)
// Adjust batch size if needed
if len(paragraphs) < r.cfg.RAGBatchSize && len(paragraphs) > 0 {
r.cfg.RAGBatchSize = len(paragraphs)
@@ -107,98 +314,354 @@ func (r *RAG) LoadRAG(fpath string) error {
if len(paragraphs) == 0 {
return errors.New("no valid paragraphs found in file")
}
- // Process paragraphs in batches synchronously
- batchCount := 0
- for i := 0; i < len(paragraphs); i += r.cfg.RAGBatchSize {
- end := i + r.cfg.RAGBatchSize
- if end > len(paragraphs) {
- end = len(paragraphs)
+ totalBatches := (len(paragraphs) + r.cfg.RAGBatchSize - 1) / r.cfg.RAGBatchSize
+ r.logger.Debug("starting parallel embedding", "total_batches", totalBatches, "batch_size", r.cfg.RAGBatchSize)
+
+ // Determine concurrency level
+ concurrency := runtime.NumCPU()
+ if concurrency > totalBatches {
+ concurrency = totalBatches
+ }
+ if concurrency < 1 {
+ concurrency = 1
+ }
+ // If using ONNX embedder, limit concurrency to 1 due to mutex serialization
+ var isONNX bool
+ if _, isONNX = r.embedder.(*ONNXEmbedder); isONNX {
+ concurrency = 1
+ }
+ embedderType := "API"
+ if isONNX {
+ embedderType = "ONNX"
+ }
+ r.logger.Debug("parallel embedding setup",
+ "total_batches", totalBatches,
+ "concurrency", concurrency,
+ "embedder", embedderType,
+ "batch_size", r.cfg.RAGBatchSize)
+
+ // Create context with timeout (30 minutes) and cancellation for error handling
+ ctx, cancel := context.WithTimeout(ctx, 30*time.Minute)
+ defer cancel()
+
+ // Channels for task distribution and results
+ taskCh := make(chan batchTask, totalBatches)
+ resultCh := make(chan batchResult, totalBatches)
+ errorCh := make(chan error, totalBatches)
+
+ // Start worker goroutines
+ var wg sync.WaitGroup
+ for w := 0; w < concurrency; w++ {
+ wg.Add(1)
+ go r.embeddingWorker(ctx, w, taskCh, resultCh, errorCh, &wg)
+ }
+
+ // Close task channel after all tasks are sent (by separate goroutine)
+ go func() {
+ // Ensure task channel is closed when this goroutine exits
+ defer close(taskCh)
+ r.logger.Debug("task distributor started", "total_batches", totalBatches)
+ for i := 0; i < totalBatches; i++ {
+ start := i * r.cfg.RAGBatchSize
+ end := start + r.cfg.RAGBatchSize
+ if end > len(paragraphs) {
+ end = len(paragraphs)
+ }
+ batch := paragraphs[start:end]
+
+ // Filter empty paragraphs
+ nonEmptyBatch := make([]string, 0, len(batch))
+ for _, p := range batch {
+ if strings.TrimSpace(p) != "" {
+ nonEmptyBatch = append(nonEmptyBatch, strings.TrimSpace(p))
+ }
+ }
+
+ task := batchTask{
+ batchIndex: i,
+ paragraphs: nonEmptyBatch,
+ filename: path.Base(fpath),
+ totalBatches: totalBatches,
+ }
+
+ select {
+ case taskCh <- task:
+ r.logger.Debug("task distributor sent batch", "batch", i, "paragraphs", len(nonEmptyBatch))
+ case <-ctx.Done():
+ r.logger.Debug("task distributor cancelled", "batches_sent", i+1, "total_batches", totalBatches)
+ return
+ }
}
- batch := paragraphs[i:end]
- batchCount++
- // Filter empty paragraphs
- nonEmptyBatch := make([]string, 0, len(batch))
- for _, p := range batch {
- if strings.TrimSpace(p) != "" {
- nonEmptyBatch = append(nonEmptyBatch, strings.TrimSpace(p))
+ r.logger.Debug("task distributor finished", "batches_sent", totalBatches)
+ }()
+
+ // Wait for workers to finish and close result channel
+ go func() {
+ wg.Wait()
+ close(resultCh)
+ }()
+
+ // Process results in order and write to database
+ nextExpectedBatch := 0
+ resultsBuffer := make(map[int]batchResult)
+ filename := path.Base(fpath)
+ batchesProcessed := 0
+ for {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+
+ case err := <-errorCh:
+ // First error from any worker, cancel everything
+ cancel()
+ r.logger.Error("embedding worker failed", "error", err)
+ r.sendStatusNonBlocking(ErrRAGStatus)
+ return fmt.Errorf("embedding failed: %w", err)
+
+ case result, ok := <-resultCh:
+ if !ok {
+ // All results processed
+ resultCh = nil
+ r.logger.Debug("result channel closed", "batches_processed", batchesProcessed, "total_batches", totalBatches)
+ continue
+ }
+
+ // Store result in buffer
+ resultsBuffer[result.batchIndex] = result
+
+ // Process buffered results in order
+ for {
+ if res, exists := resultsBuffer[nextExpectedBatch]; exists {
+ // Write this batch to database
+ if err := r.writeBatchToStorage(ctx, res, filename); err != nil {
+ cancel()
+ return err
+ }
+
+ batchesProcessed++
+ // Send progress update
+ statusMsg := fmt.Sprintf("processed batch %d/%d", batchesProcessed, totalBatches)
+ r.sendStatusNonBlocking(statusMsg)
+
+ delete(resultsBuffer, nextExpectedBatch)
+ nextExpectedBatch++
+ } else {
+ break
+ }
}
+
+ default:
+ // No channels ready, check for deadlock conditions
+ if resultCh == nil && nextExpectedBatch < totalBatches {
+ // Missing batch results after result channel closed
+ r.logger.Error("missing batch results",
+ "expected", totalBatches,
+ "received", nextExpectedBatch,
+ "missing", totalBatches-nextExpectedBatch)
+
+ // Wait a short time for any delayed errors, then cancel
+ select {
+ case <-time.After(5 * time.Second):
+ cancel()
+ return fmt.Errorf("missing batch results: expected %d, got %d", totalBatches, nextExpectedBatch)
+ case <-ctx.Done():
+ return ctx.Err()
+ case err := <-errorCh:
+ cancel()
+ r.logger.Error("embedding worker failed after result channel closed", "error", err)
+ r.sendStatusNonBlocking(ErrRAGStatus)
+ return fmt.Errorf("embedding failed: %w", err)
+ }
+ }
+ // If we reach here, no deadlock yet, just busy loop prevention
+ time.Sleep(100 * time.Millisecond)
}
- if len(nonEmptyBatch) == 0 {
- continue
+
+ // Check if we're done
+ if resultCh == nil && nextExpectedBatch >= totalBatches {
+ r.logger.Debug("all batches processed successfully", "total", totalBatches)
+ break
}
- // Embed the batch
- embeddings, err := r.embedder.EmbedSlice(nonEmptyBatch)
- if err != nil {
- r.logger.Error("failed to embed batch", "error", err, "batch", batchCount)
+ }
+ r.logger.Debug("finished writing vectors", "batches", batchesProcessed)
+ r.resetIdleTimer()
+ r.sendStatusNonBlocking(FinishedRAGStatus)
+ return nil
+}
+
+// embeddingWorker processes batch embedding tasks
+func (r *RAG) embeddingWorker(ctx context.Context, workerID int, taskCh <-chan batchTask, resultCh chan<- batchResult, errorCh chan<- error, wg *sync.WaitGroup) {
+ defer wg.Done()
+ r.logger.Debug("embedding worker started", "worker", workerID)
+
+ // Panic recovery to ensure worker doesn't crash silently
+ defer func() {
+ if rec := recover(); rec != nil {
+ r.logger.Error("embedding worker panicked", "worker", workerID, "panic", rec)
+ // Try to send error, but don't block if channel is full
select {
- case LongJobStatusCh <- ErrRAGStatus:
+ case errorCh <- fmt.Errorf("worker %d panicked: %v", workerID, rec):
default:
- r.logger.Warn("LongJobStatusCh channel full, dropping message")
- }
- return fmt.Errorf("failed to embed batch %d: %w", batchCount, err)
- }
- if len(embeddings) != len(nonEmptyBatch) {
- err := errors.New("embedding count mismatch")
- r.logger.Error("embedding mismatch", "expected", len(nonEmptyBatch), "got", len(embeddings))
- return err
- }
- // Write vectors to storage
- filename := path.Base(fpath)
- for j, text := range nonEmptyBatch {
- vector := models.VectorRow{
- Embeddings: embeddings[j],
- RawText: text,
- Slug: fmt.Sprintf("%s_%d_%d", filename, batchCount, j),
- FileName: filename,
- }
- if err := r.storage.WriteVector(&vector); err != nil {
- r.logger.Error("failed to write vector to DB", "error", err, "slug", vector.Slug)
- select {
- case LongJobStatusCh <- ErrRAGStatus:
- default:
- r.logger.Warn("LongJobStatusCh channel full, dropping message")
- }
- return fmt.Errorf("failed to write vector: %w", err)
+ r.logger.Warn("error channel full, dropping panic error", "worker", workerID)
}
}
- r.logger.Debug("wrote batch to db", "batch", batchCount, "size", len(nonEmptyBatch))
- // Send progress status
- statusMsg := fmt.Sprintf("processed batch %d/%d", batchCount, (len(paragraphs)+r.cfg.RAGBatchSize-1)/r.cfg.RAGBatchSize)
+ }()
+ for task := range taskCh {
select {
- case LongJobStatusCh <- statusMsg:
+ case <-ctx.Done():
+ r.logger.Debug("embedding worker cancelled", "worker", workerID)
+ return
default:
- r.logger.Warn("LongJobStatusCh channel full, dropping message")
}
+ r.logger.Debug("worker processing batch", "worker", workerID, "batch", task.batchIndex, "paragraphs", len(task.paragraphs), "total_batches", task.totalBatches)
+
+ // Skip empty batches
+ if len(task.paragraphs) == 0 {
+ select {
+ case resultCh <- batchResult{
+ batchIndex: task.batchIndex,
+ embeddings: nil,
+ paragraphs: nil,
+ filename: task.filename,
+ }:
+ case <-ctx.Done():
+ r.logger.Debug("embedding worker cancelled while sending empty batch", "worker", workerID)
+ return
+ }
+ r.logger.Debug("worker sent empty batch", "worker", workerID, "batch", task.batchIndex)
+ continue
+ }
+ // Embed with retry for API embedder
+ embeddings, err := r.embedWithRetry(ctx, task.paragraphs, 3)
+ if err != nil {
+ // Try to send error, but don't block indefinitely
+ select {
+ case errorCh <- fmt.Errorf("worker %d batch %d: %w", workerID, task.batchIndex, err):
+ case <-ctx.Done():
+ r.logger.Debug("embedding worker cancelled while sending error", "worker", workerID)
+ }
+ return
+ }
+ // Send result with context awareness
+ select {
+ case resultCh <- batchResult{
+ batchIndex: task.batchIndex,
+ embeddings: embeddings,
+ paragraphs: task.paragraphs,
+ filename: task.filename,
+ }:
+ case <-ctx.Done():
+ r.logger.Debug("embedding worker cancelled while sending result", "worker", workerID)
+ return
+ }
+ r.logger.Debug("worker completed batch", "worker", workerID, "batch", task.batchIndex, "embeddings", len(embeddings))
}
- r.logger.Debug("finished writing vectors", "batches", batchCount)
+ r.logger.Debug("embedding worker finished", "worker", workerID)
+}
+
+// embedWithRetry attempts embedding with exponential backoff for API embedder
+func (r *RAG) embedWithRetry(ctx context.Context, paragraphs []string, maxRetries int) ([][]float32, error) {
+ var lastErr error
+ for attempt := 0; attempt < maxRetries; attempt++ {
+ if attempt > 0 {
+ // Exponential backoff
+ backoff := time.Duration(attempt*attempt) * time.Second
+ if backoff > 10*time.Second {
+ backoff = 10 * time.Second
+ }
+ select {
+ case <-time.After(backoff):
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ }
+ r.logger.Debug("retrying embedding", "attempt", attempt, "max_retries", maxRetries)
+ }
+
+ embeddings, err := r.embedder.EmbedSlice(paragraphs)
+ if err == nil {
+ // Validate embedding count
+ if len(embeddings) != len(paragraphs) {
+ return nil, fmt.Errorf("embedding count mismatch: expected %d, got %d", len(paragraphs), len(embeddings))
+ }
+ return embeddings, nil
+ }
+
+ lastErr = err
+ // Only retry for API embedder errors (network/timeout)
+ // For ONNX embedder, fail fast
+ if _, isAPI := r.embedder.(*APIEmbedder); !isAPI {
+ break
+ }
+ }
+ return nil, fmt.Errorf("embedding failed after %d attempts: %w", maxRetries, lastErr)
+}
+
+// writeBatchToStorage writes a single batch of vectors to the database
+func (r *RAG) writeBatchToStorage(ctx context.Context, result batchResult, filename string) error {
+ if len(result.embeddings) == 0 {
+ // Empty batch, skip
+ return nil
+ }
+ // Check context before starting
select {
- case LongJobStatusCh <- FinishedRAGStatus:
+ case <-ctx.Done():
+ return ctx.Err()
default:
- r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", FinishedRAGStatus)
}
+
+ // Build all vectors for batch write
+ vectors := make([]*models.VectorRow, 0, len(result.paragraphs))
+ for j, text := range result.paragraphs {
+ vectors = append(vectors, &models.VectorRow{
+ Embeddings: result.embeddings[j],
+ RawText: text,
+ Slug: fmt.Sprintf("%s_%d_%d", filename, result.batchIndex+1, j),
+ FileName: filename,
+ })
+ }
+
+ // Write all vectors in a single transaction
+ if err := r.storage.WriteVectors(vectors); err != nil {
+ r.logger.Error("failed to write vectors batch to DB", "error", err, "batch", result.batchIndex+1, "size", len(vectors))
+ r.sendStatusNonBlocking(ErrRAGStatus)
+ return fmt.Errorf("failed to write vectors batch: %w", err)
+ }
+ r.logger.Debug("wrote batch to db", "batch", result.batchIndex+1, "size", len(result.paragraphs))
return nil
}
func (r *RAG) LineToVector(line string) ([]float32, error) {
+ r.resetIdleTimer()
return r.embedder.Embed(line)
}
-func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) {
- return r.storage.SearchClosest(emb.Embedding)
+func (r *RAG) searchEmb(emb *models.EmbeddingResp, limit int) ([]models.VectorRow, error) {
+ r.resetIdleTimer()
+ return r.storage.SearchClosest(emb.Embedding, limit)
+}
+
+func (r *RAG) searchKeyword(query string, limit int) ([]models.VectorRow, error) {
+ r.resetIdleTimer()
+ sanitized := sanitizeFTSQuery(query)
+ return r.storage.SearchKeyword(sanitized, limit)
}
func (r *RAG) ListLoaded() ([]string, error) {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
return r.storage.ListFiles()
}
func (r *RAG) RemoveFile(filename string) error {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.resetIdleTimer()
return r.storage.RemoveEmbByFileName(filename)
}
var (
queryRefinementPattern = regexp.MustCompile(`(?i)(based on my (vector db|vector db|vector database|rags?|past (conversations?|chat|messages?))|from my (files?|documents?|data|information|memory)|search (in|my) (vector db|database|rags?)|rag search for)`)
importantKeywords = []string{"project", "architecture", "code", "file", "chat", "conversation", "topic", "summary", "details", "history", "previous", "my", "user", "me"}
- stopWords = []string{"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by", "from", "up", "down", "left", "right"}
)
func (r *RAG) RefineQuery(query string) string {
@@ -210,11 +673,31 @@ func (r *RAG) RefineQuery(query string) string {
if len(query) <= 3 {
return original
}
+ // If query already contains double quotes, assume it's a phrase query and skip refinement
+ if strings.Contains(query, "\"") {
+ return original
+ }
query = strings.ToLower(query)
- for _, stopWord := range stopWords {
- wordPattern := `\b` + stopWord + `\b`
- re := regexp.MustCompile(wordPattern)
- query = re.ReplaceAllString(query, "")
+ words := strings.Fields(query)
+ if len(words) >= 3 {
+ // Detect phrases and protect words that are part of phrases
+ phrases := detectPhrases(query)
+ protectedWords := make(map[string]bool)
+ for _, phrase := range phrases {
+ for _, word := range strings.Fields(phrase) {
+ protectedWords[word] = true
+ }
+ }
+
+ // Remove stop words that are not protected
+ for _, stopWord := range stopWords {
+ if protectedWords[stopWord] {
+ continue
+ }
+ wordPattern := `\b` + stopWord + `\b`
+ re := regexp.MustCompile(wordPattern)
+ query = re.ReplaceAllString(query, "")
+ }
}
query = strings.TrimSpace(query)
if len(query) < 5 {
@@ -246,7 +729,7 @@ func (r *RAG) extractImportantPhrases(query string) string {
break
}
}
- if isImportant || len(word) > 3 {
+ if isImportant || len(word) >= 3 {
important = append(important, word)
}
}
@@ -265,6 +748,36 @@ func (r *RAG) GenerateQueryVariations(query string) []string {
if len(parts) == 0 {
return variations
}
+ // Get loaded filenames to filter out filename terms
+ filenames, err := r.storage.ListFiles()
+ if err == nil && len(filenames) > 0 {
+ // Convert to lowercase for case-insensitive matching
+ lowerFilenames := make([]string, len(filenames))
+ for i, f := range filenames {
+ lowerFilenames[i] = strings.ToLower(f)
+ }
+ filteredParts := make([]string, 0, len(parts))
+ for _, part := range parts {
+ partLower := strings.ToLower(part)
+ skip := false
+ for _, fn := range lowerFilenames {
+ if strings.Contains(fn, partLower) || strings.Contains(partLower, fn) {
+ skip = true
+ break
+ }
+ }
+ if !skip {
+ filteredParts = append(filteredParts, part)
+ }
+ }
+ // If filteredParts not empty and different from original, add filtered query
+ if len(filteredParts) > 0 && len(filteredParts) != len(parts) {
+ filteredQuery := strings.Join(filteredParts, " ")
+ if len(filteredQuery) >= 5 {
+ variations = append(variations, filteredQuery)
+ }
+ }
+ }
if len(parts) >= 2 {
trimmed := strings.Join(parts[:len(parts)-1], " ")
if len(trimmed) >= 5 {
@@ -289,13 +802,57 @@ func (r *RAG) GenerateQueryVariations(query string) []string {
if !strings.HasSuffix(query, " summary") {
variations = append(variations, query+" summary")
}
+
+ // Add phrase-quoted variations for better FTS5 matching
+ phrases := detectPhrases(query)
+ if len(phrases) > 0 {
+ // Sort phrases by length descending to prioritize longer phrases
+ sort.Slice(phrases, func(i, j int) bool {
+ return len(phrases[i]) > len(phrases[j])
+ })
+
+ // Create a version with all phrases quoted
+ quotedQuery := query
+ for _, phrase := range phrases {
+ // Only quote if not already quoted
+ quotedPhrase := "\"" + phrase + "\""
+ if !strings.Contains(strings.ToLower(quotedQuery), strings.ToLower(quotedPhrase)) {
+ // Case-insensitive replacement of phrase with quoted version
+ re := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(phrase) + `\b`)
+ quotedQuery = re.ReplaceAllString(quotedQuery, quotedPhrase)
+ }
+ }
+ // Disabled malformed quoted query for now
+ // if quotedQuery != query {
+ // variations = append(variations, quotedQuery)
+ // }
+
+ // Also add individual phrase variations for short queries
+ if len(phrases) <= 5 {
+ for _, phrase := range phrases {
+ // Create a focused query with just this phrase quoted
+ // Keep original context but emphasize this phrase
+ quotedPhrase := "\"" + phrase + "\""
+ re := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(phrase) + `\b`)
+ focusedQuery := re.ReplaceAllString(query, quotedPhrase)
+ if focusedQuery != query && focusedQuery != quotedQuery {
+ variations = append(variations, focusedQuery)
+ }
+ // Add the phrase alone (quoted) as a separate variation
+ variations = append(variations, quotedPhrase)
+ }
+ }
+ }
+
return variations
}
func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.VectorRow {
+ phraseCount := len(detectPhrases(query))
type scoredResult struct {
- row models.VectorRow
- distance float32
+ row models.VectorRow
+ distance float32
+ phraseMatches int
}
scored := make([]scoredResult, 0, len(results))
for i := range results {
@@ -320,27 +877,69 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V
if row.FileName == "chat" || strings.Contains(strings.ToLower(row.FileName), "conversation") {
score += 3
}
+
+ // Phrase match bonus: extra points for containing detected phrases
+ phraseMatches := countPhraseMatches(row.RawText, query)
+ if phraseMatches > 0 {
+ // Significant bonus per phrase to prioritize exact phrase matches
+ r.logger.Debug("phrase match bonus", "slug", row.Slug, "phraseMatches", phraseMatches, "score", score)
+ score += float32(phraseMatches) * 100
+ }
+
+ // Cross-chunk adjacency bonus: if this chunk has adjacent siblings in results,
+ // boost score to promote narrative continuity
+ adjacentCount := 0
+ for _, other := range results {
+ if other.Slug == row.Slug {
+ continue
+ }
+ if areSlugsAdjacent(row.Slug, other.Slug) {
+ adjacentCount++
+ }
+ }
+ if adjacentCount > 0 {
+ // Bonus per adjacent chunk, but diminishing returns
+ score += float32(adjacentCount) * 4
+ }
distance := row.Distance - score/100
- scored = append(scored, scoredResult{row: row, distance: distance})
+ scored = append(scored, scoredResult{row: row, distance: distance, phraseMatches: phraseMatches})
}
sort.Slice(scored, func(i, j int) bool {
return scored[i].distance < scored[j].distance
})
unique := make([]models.VectorRow, 0)
seen := make(map[string]bool)
+ maxPerFile := 2
+ if phraseCount > 0 {
+ maxPerFile = 10
+ }
+ fileCounts := make(map[string]int)
for i := range scored {
if !seen[scored[i].row.Slug] {
+ // Allow phrase-matching chunks to bypass per-file limit (up to +5 extra)
+ allowed := fileCounts[scored[i].row.FileName] < maxPerFile
+ if !allowed && scored[i].phraseMatches > 0 {
+ // If chunk has phrase matches, allow extra slots (up to maxPerFile + 5)
+ allowed = fileCounts[scored[i].row.FileName] < maxPerFile+5
+ }
+ if !allowed {
+ continue
+ }
seen[scored[i].row.Slug] = true
+ fileCounts[scored[i].row.FileName]++
unique = append(unique, scored[i].row)
}
}
- if len(unique) > 10 {
- unique = unique[:10]
+ if len(unique) > 30 {
+ unique = unique[:30]
}
return unique
}
func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string, error) {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ r.resetIdleTimer()
if len(results) == 0 {
return "No relevant information found in the vector database.", nil
}
@@ -369,7 +968,7 @@ func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string
Embedding: emb,
Index: 0,
}
- topResults, err := r.SearchEmb(embResp)
+ topResults, err := r.searchEmb(embResp, 1)
if err != nil {
r.logger.Error("failed to search for synthesis context", "error", err)
return "", err
@@ -396,9 +995,15 @@ func truncateString(s string, maxLen int) string {
}
func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ r.resetIdleTimer()
refined := r.RefineQuery(query)
variations := r.GenerateQueryVariations(refined)
- allResults := make([]models.VectorRow, 0)
+ r.logger.Debug("query variations", "original", query, "refined", refined, "variations", variations)
+
+ // Collect embedding search results from all variations
+ var embResults []models.VectorRow
seen := make(map[string]bool)
for _, q := range variations {
emb, err := r.LineToVector(q)
@@ -406,29 +1011,119 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
r.logger.Error("failed to embed query variation", "error", err, "query", q)
continue
}
-
embResp := &models.EmbeddingResp{
Embedding: emb,
Index: 0,
}
-
- results, err := r.SearchEmb(embResp)
+ results, err := r.searchEmb(embResp, limit*2) // Get more candidates
if err != nil {
r.logger.Error("failed to search embeddings", "error", err, "query", q)
continue
}
-
for _, row := range results {
if !seen[row.Slug] {
seen[row.Slug] = true
- allResults = append(allResults, row)
+ embResults = append(embResults, row)
}
}
}
- reranked := r.RerankResults(allResults, query)
- if len(reranked) > limit {
- reranked = reranked[:limit]
+ // Sort embedding results by distance (lower is better)
+ sort.Slice(embResults, func(i, j int) bool {
+ return embResults[i].Distance < embResults[j].Distance
+ })
+
+ // Perform keyword search on all variations
+ var kwResults []models.VectorRow
+ seenKw := make(map[string]bool)
+ for _, q := range variations {
+ results, err := r.searchKeyword(q, limit)
+ if err != nil {
+ r.logger.Debug("keyword search failed for variation", "error", err, "query", q)
+ continue
+ }
+ for _, row := range results {
+ if !seenKw[row.Slug] {
+ seenKw[row.Slug] = true
+ kwResults = append(kwResults, row)
+ }
+ }
+ }
+ // Sort keyword results by distance (lower is better)
+ sort.Slice(kwResults, func(i, j int) bool {
+ return kwResults[i].Distance < kwResults[j].Distance
+ })
+
+ // Combine using Reciprocal Rank Fusion (RRF)
+ // Use smaller K for phrase-heavy queries to give more weight to top ranks
+ phraseCount := len(detectPhrases(query))
+ rrfK := 60.0
+ if phraseCount > 0 {
+ rrfK = 30.0
+ }
+ r.logger.Debug("RRF parameters", "phraseCount", phraseCount, "rrfK", rrfK, "query", query)
+ type scoredRow struct {
+ row models.VectorRow
+ score float64
+ }
+ scoreMap := make(map[string]float64)
+ // Add embedding results
+ for rank, row := range embResults {
+ score := 1.0 / (float64(rank) + rrfK)
+ scoreMap[row.Slug] += score
+ if row.Slug == "kjv_bible.epub_1786_0" {
+ r.logger.Debug("target chunk embedding rank", "rank", rank, "score", score)
+ }
+ }
+ // Add keyword results with weight boost when phrases are present
+ kwWeight := 1.0
+ if phraseCount > 0 {
+ kwWeight = 100.0
+ }
+ r.logger.Debug("keyword weight", "kwWeight", kwWeight, "phraseCount", phraseCount)
+ for rank, row := range kwResults {
+ score := kwWeight * (1.0 / (float64(rank) + rrfK))
+ scoreMap[row.Slug] += score
+ if row.Slug == "kjv_bible.epub_1786_0" {
+ r.logger.Debug("target chunk keyword rank", "rank", rank, "score", score, "kwWeight", kwWeight, "rrfK", rrfK)
+ }
+ // Ensure row exists in combined results
+ if _, exists := seen[row.Slug]; !exists {
+ embResults = append(embResults, row)
+ }
}
+ // Create slice of scored rows
+ scoredRows := make([]scoredRow, 0, len(embResults))
+ for _, row := range embResults {
+ score := scoreMap[row.Slug]
+ scoredRows = append(scoredRows, scoredRow{row: row, score: score})
+ }
+ // Debug: log scores for target chunk and top chunks
+ if strings.Contains(strings.ToLower(query), "bald") || strings.Contains(strings.ToLower(query), "she bears") {
+ for _, sr := range scoredRows {
+ if sr.row.Slug == "kjv_bible.epub_1786_0" {
+ r.logger.Debug("target chunk score", "slug", sr.row.Slug, "score", sr.score, "distance", sr.row.Distance)
+ }
+ }
+ // Log top 5 scores
+ for i := 0; i < len(scoredRows) && i < 5; i++ {
+ r.logger.Debug("top scored row", "rank", i+1, "slug", scoredRows[i].row.Slug, "score", scoredRows[i].score, "distance", scoredRows[i].row.Distance)
+ }
+ }
+ // Sort by descending RRF score
+ sort.Slice(scoredRows, func(i, j int) bool {
+ return scoredRows[i].score > scoredRows[j].score
+ })
+ // Take top limit
+ if len(scoredRows) > limit {
+ scoredRows = scoredRows[:limit]
+ }
+ // Convert back to VectorRow
+ finalResults := make([]models.VectorRow, len(scoredRows))
+ for i, sr := range scoredRows {
+ finalResults[i] = sr.row
+ }
+ // Apply reranking heuristics
+ reranked := r.RerankResults(finalResults, query)
return reranked, nil
}
@@ -437,16 +1132,66 @@ var (
ragOnce sync.Once
)
+func (r *RAG) FallbackMessage() string {
+ return r.fallbackMsg
+}
+
func Init(c *config.Config, l *slog.Logger, s storage.FullRepo) error {
+ var err error
ragOnce.Do(func() {
if c == nil || l == nil || s == nil {
return
}
- ragInstance = New(l, s, c)
+ ragInstance, err = New(l, s, c)
})
- return nil
+ return err
}
func GetInstance() *RAG {
return ragInstance
}
+
+func (r *RAG) resetIdleTimer() {
+ r.idleMu.Lock()
+ defer r.idleMu.Unlock()
+ if r.idleTimer != nil {
+ r.idleTimer.Stop()
+ }
+ r.idleTimer = time.AfterFunc(r.idleTimeout, func() {
+ r.freeONNXMemory()
+ })
+}
+
+func (r *RAG) freeONNXMemory() {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ if onnx, ok := r.embedder.(*ONNXEmbedder); ok {
+ if err := onnx.Destroy(); err != nil {
+ r.logger.Error("failed to free ONNX memory", "error", err)
+ } else {
+ r.logger.Info("freed ONNX VRAM after idle timeout")
+ }
+ }
+}
+
+func (r *RAG) Destroy() {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ if r.idleTimer != nil {
+ r.idleTimer.Stop()
+ r.idleTimer = nil
+ }
+ if onnx, ok := r.embedder.(*ONNXEmbedder); ok {
+ if err := onnx.Destroy(); err != nil {
+ r.logger.Error("failed to destroy ONNX embedder", "error", err)
+ }
+ }
+}
+
+// SetEmbedderForTesting replaces the internal embedder with a mock.
+// This function is only available when compiling with the "test" build tag.
+func (r *RAG) SetEmbedderForTesting(e Embedder) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.embedder = e
+}