diff options
Diffstat (limited to 'rag/rag.go')
| -rw-r--r-- | rag/rag.go | 1001 |
1 files changed, 873 insertions, 128 deletions
@@ -1,6 +1,7 @@ package rag import ( + "context" "errors" "fmt" "gf-lt/config" @@ -9,51 +10,281 @@ import ( "log/slog" "path" "regexp" + "runtime" "sort" + "strconv" "strings" "sync" + "time" "github.com/neurosnap/sentences/english" ) +const () + var ( // Status messages for TUI integration - LongJobStatusCh = make(chan string, 10) // Increased buffer size to prevent blocking - FinishedRAGStatus = "finished loading RAG file; press Enter" + LongJobStatusCh = make(chan string, 100) // Increased buffer size for parallel batch updates + FinishedRAGStatus = "finished loading RAG file; press x to exit" LoadedFileRAGStatus = "loaded file" ErrRAGStatus = "some error occurred; failed to transfer data to vector db" + + // stopWords are common words that can be removed from queries when not part of phrases + stopWords = []string{"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by", "from", "up", "down", "left", "right", "about", "like", "such", "than", "then", "also", "too"} ) +// isStopWord checks if a word is in the stop words list +func isStopWord(word string) bool { + for _, stop := range stopWords { + if strings.EqualFold(word, stop) { + return true + } + } + return false +} + +// detectPhrases returns multi-word phrases from a query that should be treated as units +func detectPhrases(query string) []string { + words := strings.Fields(strings.ToLower(query)) + var phrases []string + + for i := 0; i < len(words)-1; i++ { + word1 := strings.Trim(words[i], ".,!?;:'\"()[]{}") + word2 := strings.Trim(words[i+1], ".,!?;:'\"()[]{}") + + // Skip if either word is a stop word or too short + if isStopWord(word1) || isStopWord(word2) || len(word1) < 2 || len(word2) < 2 { + continue + } + + // Check if this pair appears to be a meaningful phrase + // Simple heuristic: consecutive non-stop words of reasonable length + phrase := word1 + " " + word2 + phrases = append(phrases, phrase) + + // Optionally check for 3-word phrases + if i < len(words)-2 { + word3 := strings.Trim(words[i+2], ".,!?;:'\"()[]{}") + if !isStopWord(word3) && len(word3) >= 2 { + phrases = append(phrases, word1+" "+word2+" "+word3) + } + } + } + + return phrases +} + +// countPhraseMatches returns the number of query phrases found in text +func countPhraseMatches(text, query string) int { + phrases := detectPhrases(query) + if len(phrases) == 0 { + return 0 + } + textLower := strings.ToLower(text) + count := 0 + for _, phrase := range phrases { + if strings.Contains(textLower, phrase) { + count++ + } + } + return count +} + +// parseSlugIndices extracts batch and chunk indices from a slug +// slug format: filename_batch_chunk (e.g., "kjv_bible.epub_1786_0") +func parseSlugIndices(slug string) (batch, chunk int, ok bool) { + // Find the last two numbers separated by underscores + re := regexp.MustCompile(`_(\d+)_(\d+)$`) + matches := re.FindStringSubmatch(slug) + if matches == nil || len(matches) != 3 { + return 0, 0, false + } + batch, err1 := strconv.Atoi(matches[1]) + chunk, err2 := strconv.Atoi(matches[2]) + if err1 != nil || err2 != nil { + return 0, 0, false + } + return batch, chunk, true +} + +// areSlugsAdjacent returns true if two slugs are from the same file and have sequential indices +func areSlugsAdjacent(slug1, slug2 string) bool { + // Extract filename prefix (everything before the last underscore sequence) + parts1 := strings.Split(slug1, "_") + parts2 := strings.Split(slug2, "_") + if len(parts1) < 3 || len(parts2) < 3 { + return false + } + + // Compare filename prefixes (all parts except last two) + prefix1 := strings.Join(parts1[:len(parts1)-2], "_") + prefix2 := strings.Join(parts2[:len(parts2)-2], "_") + if prefix1 != prefix2 { + return false + } + + batch1, chunk1, ok1 := parseSlugIndices(slug1) + batch2, chunk2, ok2 := parseSlugIndices(slug2) + if !ok1 || !ok2 { + return false + } + + // Check if they're in same batch and chunks are sequential + if batch1 == batch2 && (chunk1 == chunk2+1 || chunk2 == chunk1+1) { + return true + } + + // Check if they're in sequential batches and chunk indices suggest continuity + // This is heuristic but useful for cross-batch adjacency + if (batch1 == batch2+1 && chunk1 == 0) || (batch2 == batch1+1 && chunk2 == 0) { + return true + } + return false +} + type RAG struct { - logger *slog.Logger - store storage.FullRepo - cfg *config.Config - embedder Embedder - storage *VectorStorage - mu sync.Mutex + logger *slog.Logger + store storage.FullRepo + cfg *config.Config + embedder Embedder + storage *VectorStorage + mu sync.RWMutex + idleMu sync.Mutex + fallbackMsg string + idleTimer *time.Timer + idleTimeout time.Duration +} + +// batchTask represents a single batch to be embedded +type batchTask struct { + batchIndex int + paragraphs []string + filename string + totalBatches int +} + +// batchResult represents the result of embedding a batch +type batchResult struct { + batchIndex int + embeddings [][]float32 + paragraphs []string + filename string +} + +// sendStatusNonBlocking sends a status message without blocking +func (r *RAG) sendStatusNonBlocking(status string) { + select { + case LongJobStatusCh <- status: + default: + r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", status) + } } -func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG { - // Initialize with API embedder by default, could be configurable later - embedder := NewAPIEmbedder(l, cfg) +func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) (*RAG, error) { + var embedder Embedder + var fallbackMsg string + if cfg.EmbedModelPath != "" && cfg.EmbedTokenizerPath != "" { + emb, err := NewONNXEmbedder(cfg.EmbedModelPath, cfg.EmbedTokenizerPath, cfg.EmbedDims, l) + if err != nil { + l.Error("failed to create ONNX embedder, falling back to API", "error", err) + fallbackMsg = err.Error() + embedder = NewAPIEmbedder(l, cfg) + } else { + embedder = emb + l.Info("using ONNX embedder", "model", cfg.EmbedModelPath, "dims", cfg.EmbedDims) + } + } else { + embedder = NewAPIEmbedder(l, cfg) + l.Info("using API embedder", "url", cfg.EmbedURL) + } rag := &RAG{ - logger: l, - store: s, - cfg: cfg, - embedder: embedder, - storage: NewVectorStorage(l, s), + logger: l, + store: s, + cfg: cfg, + embedder: embedder, + storage: NewVectorStorage(l, s), + fallbackMsg: fallbackMsg, + idleTimeout: 30 * time.Second, } // Note: Vector tables are created via database migrations, not at runtime - return rag + return rag, nil +} + +func createChunks(sentences []string, wordLimit, overlapWords uint32) []string { + if len(sentences) == 0 { + return nil + } + if overlapWords >= wordLimit { + overlapWords = wordLimit / 2 + } + var chunks []string + i := 0 + for i < len(sentences) { + var chunkWords []string + wordCount := 0 + j := i + for j < len(sentences) && wordCount <= int(wordLimit) { + sentence := sentences[j] + words := strings.Fields(sentence) + chunkWords = append(chunkWords, sentence) + wordCount += len(words) + j++ + // If this sentence alone exceeds limit, still include it and stop + if wordCount > int(wordLimit) { + break + } + } + if len(chunkWords) == 0 { + break + } + chunk := strings.Join(chunkWords, " ") + chunks = append(chunks, chunk) + if j >= len(sentences) { + break + } + // Move i forward by skipping overlap + if overlapWords == 0 { + i = j + continue + } + // Calculate how many sentences to skip to achieve overlapWords + overlapRemaining := int(overlapWords) + newI := i + for newI < j && overlapRemaining > 0 { + words := len(strings.Fields(sentences[newI])) + overlapRemaining -= words + if overlapRemaining >= 0 { + newI++ + } + } + if newI == i { + newI = j + } + i = newI + } + return chunks } -func wordCounter(sentence string) int { - return len(strings.Split(strings.TrimSpace(sentence), " ")) +func sanitizeFTSQuery(query string) string { + // Keep double quotes for FTS5 phrase matching + // Remove other problematic characters + query = strings.ReplaceAll(query, "'", " ") + query = strings.ReplaceAll(query, ";", " ") + query = strings.ReplaceAll(query, "\\", " ") + query = strings.TrimSpace(query) + if query == "" { + return "*" // match all + } + return query } func (r *RAG) LoadRAG(fpath string) error { + return r.LoadRAGWithContext(context.Background(), fpath) +} + +func (r *RAG) LoadRAGWithContext(ctx context.Context, fpath string) error { r.mu.Lock() defer r.mu.Unlock() fileText, err := ExtractText(fpath) @@ -61,11 +292,9 @@ func (r *RAG) LoadRAG(fpath string) error { return err } r.logger.Debug("rag: loaded file", "fp", fpath) - select { - case LongJobStatusCh <- LoadedFileRAGStatus: - default: - r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", LoadedFileRAGStatus) - } + + // Send initial status (non-blocking with retry) + r.sendStatusNonBlocking(LoadedFileRAGStatus) tokenizer, err := english.NewSentenceTokenizer(nil) if err != nil { return err @@ -75,31 +304,9 @@ func (r *RAG) LoadRAG(fpath string) error { for i, s := range sentences { sents[i] = s.Text } - // Group sentences into paragraphs based on word limit - paragraphs := []string{} - par := strings.Builder{} - for i := 0; i < len(sents); i++ { - if strings.TrimSpace(sents[i]) != "" { - if par.Len() > 0 { - par.WriteString(" ") - } - par.WriteString(sents[i]) - } - if wordCounter(par.String()) > int(r.cfg.RAGWordLimit) { - paragraph := strings.TrimSpace(par.String()) - if paragraph != "" { - paragraphs = append(paragraphs, paragraph) - } - par.Reset() - } - } - // Handle any remaining content in the paragraph buffer - if par.Len() > 0 { - paragraph := strings.TrimSpace(par.String()) - if paragraph != "" { - paragraphs = append(paragraphs, paragraph) - } - } + + // Create chunks with overlap + paragraphs := createChunks(sents, r.cfg.RAGWordLimit, r.cfg.RAGOverlapWords) // Adjust batch size if needed if len(paragraphs) < r.cfg.RAGBatchSize && len(paragraphs) > 0 { r.cfg.RAGBatchSize = len(paragraphs) @@ -107,98 +314,354 @@ func (r *RAG) LoadRAG(fpath string) error { if len(paragraphs) == 0 { return errors.New("no valid paragraphs found in file") } - // Process paragraphs in batches synchronously - batchCount := 0 - for i := 0; i < len(paragraphs); i += r.cfg.RAGBatchSize { - end := i + r.cfg.RAGBatchSize - if end > len(paragraphs) { - end = len(paragraphs) + totalBatches := (len(paragraphs) + r.cfg.RAGBatchSize - 1) / r.cfg.RAGBatchSize + r.logger.Debug("starting parallel embedding", "total_batches", totalBatches, "batch_size", r.cfg.RAGBatchSize) + + // Determine concurrency level + concurrency := runtime.NumCPU() + if concurrency > totalBatches { + concurrency = totalBatches + } + if concurrency < 1 { + concurrency = 1 + } + // If using ONNX embedder, limit concurrency to 1 due to mutex serialization + var isONNX bool + if _, isONNX = r.embedder.(*ONNXEmbedder); isONNX { + concurrency = 1 + } + embedderType := "API" + if isONNX { + embedderType = "ONNX" + } + r.logger.Debug("parallel embedding setup", + "total_batches", totalBatches, + "concurrency", concurrency, + "embedder", embedderType, + "batch_size", r.cfg.RAGBatchSize) + + // Create context with timeout (30 minutes) and cancellation for error handling + ctx, cancel := context.WithTimeout(ctx, 30*time.Minute) + defer cancel() + + // Channels for task distribution and results + taskCh := make(chan batchTask, totalBatches) + resultCh := make(chan batchResult, totalBatches) + errorCh := make(chan error, totalBatches) + + // Start worker goroutines + var wg sync.WaitGroup + for w := 0; w < concurrency; w++ { + wg.Add(1) + go r.embeddingWorker(ctx, w, taskCh, resultCh, errorCh, &wg) + } + + // Close task channel after all tasks are sent (by separate goroutine) + go func() { + // Ensure task channel is closed when this goroutine exits + defer close(taskCh) + r.logger.Debug("task distributor started", "total_batches", totalBatches) + for i := 0; i < totalBatches; i++ { + start := i * r.cfg.RAGBatchSize + end := start + r.cfg.RAGBatchSize + if end > len(paragraphs) { + end = len(paragraphs) + } + batch := paragraphs[start:end] + + // Filter empty paragraphs + nonEmptyBatch := make([]string, 0, len(batch)) + for _, p := range batch { + if strings.TrimSpace(p) != "" { + nonEmptyBatch = append(nonEmptyBatch, strings.TrimSpace(p)) + } + } + + task := batchTask{ + batchIndex: i, + paragraphs: nonEmptyBatch, + filename: path.Base(fpath), + totalBatches: totalBatches, + } + + select { + case taskCh <- task: + r.logger.Debug("task distributor sent batch", "batch", i, "paragraphs", len(nonEmptyBatch)) + case <-ctx.Done(): + r.logger.Debug("task distributor cancelled", "batches_sent", i+1, "total_batches", totalBatches) + return + } } - batch := paragraphs[i:end] - batchCount++ - // Filter empty paragraphs - nonEmptyBatch := make([]string, 0, len(batch)) - for _, p := range batch { - if strings.TrimSpace(p) != "" { - nonEmptyBatch = append(nonEmptyBatch, strings.TrimSpace(p)) + r.logger.Debug("task distributor finished", "batches_sent", totalBatches) + }() + + // Wait for workers to finish and close result channel + go func() { + wg.Wait() + close(resultCh) + }() + + // Process results in order and write to database + nextExpectedBatch := 0 + resultsBuffer := make(map[int]batchResult) + filename := path.Base(fpath) + batchesProcessed := 0 + for { + select { + case <-ctx.Done(): + return ctx.Err() + + case err := <-errorCh: + // First error from any worker, cancel everything + cancel() + r.logger.Error("embedding worker failed", "error", err) + r.sendStatusNonBlocking(ErrRAGStatus) + return fmt.Errorf("embedding failed: %w", err) + + case result, ok := <-resultCh: + if !ok { + // All results processed + resultCh = nil + r.logger.Debug("result channel closed", "batches_processed", batchesProcessed, "total_batches", totalBatches) + continue + } + + // Store result in buffer + resultsBuffer[result.batchIndex] = result + + // Process buffered results in order + for { + if res, exists := resultsBuffer[nextExpectedBatch]; exists { + // Write this batch to database + if err := r.writeBatchToStorage(ctx, res, filename); err != nil { + cancel() + return err + } + + batchesProcessed++ + // Send progress update + statusMsg := fmt.Sprintf("processed batch %d/%d", batchesProcessed, totalBatches) + r.sendStatusNonBlocking(statusMsg) + + delete(resultsBuffer, nextExpectedBatch) + nextExpectedBatch++ + } else { + break + } } + + default: + // No channels ready, check for deadlock conditions + if resultCh == nil && nextExpectedBatch < totalBatches { + // Missing batch results after result channel closed + r.logger.Error("missing batch results", + "expected", totalBatches, + "received", nextExpectedBatch, + "missing", totalBatches-nextExpectedBatch) + + // Wait a short time for any delayed errors, then cancel + select { + case <-time.After(5 * time.Second): + cancel() + return fmt.Errorf("missing batch results: expected %d, got %d", totalBatches, nextExpectedBatch) + case <-ctx.Done(): + return ctx.Err() + case err := <-errorCh: + cancel() + r.logger.Error("embedding worker failed after result channel closed", "error", err) + r.sendStatusNonBlocking(ErrRAGStatus) + return fmt.Errorf("embedding failed: %w", err) + } + } + // If we reach here, no deadlock yet, just busy loop prevention + time.Sleep(100 * time.Millisecond) } - if len(nonEmptyBatch) == 0 { - continue + + // Check if we're done + if resultCh == nil && nextExpectedBatch >= totalBatches { + r.logger.Debug("all batches processed successfully", "total", totalBatches) + break } - // Embed the batch - embeddings, err := r.embedder.EmbedSlice(nonEmptyBatch) - if err != nil { - r.logger.Error("failed to embed batch", "error", err, "batch", batchCount) + } + r.logger.Debug("finished writing vectors", "batches", batchesProcessed) + r.resetIdleTimer() + r.sendStatusNonBlocking(FinishedRAGStatus) + return nil +} + +// embeddingWorker processes batch embedding tasks +func (r *RAG) embeddingWorker(ctx context.Context, workerID int, taskCh <-chan batchTask, resultCh chan<- batchResult, errorCh chan<- error, wg *sync.WaitGroup) { + defer wg.Done() + r.logger.Debug("embedding worker started", "worker", workerID) + + // Panic recovery to ensure worker doesn't crash silently + defer func() { + if rec := recover(); rec != nil { + r.logger.Error("embedding worker panicked", "worker", workerID, "panic", rec) + // Try to send error, but don't block if channel is full select { - case LongJobStatusCh <- ErrRAGStatus: + case errorCh <- fmt.Errorf("worker %d panicked: %v", workerID, rec): default: - r.logger.Warn("LongJobStatusCh channel full, dropping message") - } - return fmt.Errorf("failed to embed batch %d: %w", batchCount, err) - } - if len(embeddings) != len(nonEmptyBatch) { - err := errors.New("embedding count mismatch") - r.logger.Error("embedding mismatch", "expected", len(nonEmptyBatch), "got", len(embeddings)) - return err - } - // Write vectors to storage - filename := path.Base(fpath) - for j, text := range nonEmptyBatch { - vector := models.VectorRow{ - Embeddings: embeddings[j], - RawText: text, - Slug: fmt.Sprintf("%s_%d_%d", filename, batchCount, j), - FileName: filename, - } - if err := r.storage.WriteVector(&vector); err != nil { - r.logger.Error("failed to write vector to DB", "error", err, "slug", vector.Slug) - select { - case LongJobStatusCh <- ErrRAGStatus: - default: - r.logger.Warn("LongJobStatusCh channel full, dropping message") - } - return fmt.Errorf("failed to write vector: %w", err) + r.logger.Warn("error channel full, dropping panic error", "worker", workerID) } } - r.logger.Debug("wrote batch to db", "batch", batchCount, "size", len(nonEmptyBatch)) - // Send progress status - statusMsg := fmt.Sprintf("processed batch %d/%d", batchCount, (len(paragraphs)+r.cfg.RAGBatchSize-1)/r.cfg.RAGBatchSize) + }() + for task := range taskCh { select { - case LongJobStatusCh <- statusMsg: + case <-ctx.Done(): + r.logger.Debug("embedding worker cancelled", "worker", workerID) + return default: - r.logger.Warn("LongJobStatusCh channel full, dropping message") } + r.logger.Debug("worker processing batch", "worker", workerID, "batch", task.batchIndex, "paragraphs", len(task.paragraphs), "total_batches", task.totalBatches) + + // Skip empty batches + if len(task.paragraphs) == 0 { + select { + case resultCh <- batchResult{ + batchIndex: task.batchIndex, + embeddings: nil, + paragraphs: nil, + filename: task.filename, + }: + case <-ctx.Done(): + r.logger.Debug("embedding worker cancelled while sending empty batch", "worker", workerID) + return + } + r.logger.Debug("worker sent empty batch", "worker", workerID, "batch", task.batchIndex) + continue + } + // Embed with retry for API embedder + embeddings, err := r.embedWithRetry(ctx, task.paragraphs, 3) + if err != nil { + // Try to send error, but don't block indefinitely + select { + case errorCh <- fmt.Errorf("worker %d batch %d: %w", workerID, task.batchIndex, err): + case <-ctx.Done(): + r.logger.Debug("embedding worker cancelled while sending error", "worker", workerID) + } + return + } + // Send result with context awareness + select { + case resultCh <- batchResult{ + batchIndex: task.batchIndex, + embeddings: embeddings, + paragraphs: task.paragraphs, + filename: task.filename, + }: + case <-ctx.Done(): + r.logger.Debug("embedding worker cancelled while sending result", "worker", workerID) + return + } + r.logger.Debug("worker completed batch", "worker", workerID, "batch", task.batchIndex, "embeddings", len(embeddings)) } - r.logger.Debug("finished writing vectors", "batches", batchCount) + r.logger.Debug("embedding worker finished", "worker", workerID) +} + +// embedWithRetry attempts embedding with exponential backoff for API embedder +func (r *RAG) embedWithRetry(ctx context.Context, paragraphs []string, maxRetries int) ([][]float32, error) { + var lastErr error + for attempt := 0; attempt < maxRetries; attempt++ { + if attempt > 0 { + // Exponential backoff + backoff := time.Duration(attempt*attempt) * time.Second + if backoff > 10*time.Second { + backoff = 10 * time.Second + } + select { + case <-time.After(backoff): + case <-ctx.Done(): + return nil, ctx.Err() + } + r.logger.Debug("retrying embedding", "attempt", attempt, "max_retries", maxRetries) + } + + embeddings, err := r.embedder.EmbedSlice(paragraphs) + if err == nil { + // Validate embedding count + if len(embeddings) != len(paragraphs) { + return nil, fmt.Errorf("embedding count mismatch: expected %d, got %d", len(paragraphs), len(embeddings)) + } + return embeddings, nil + } + + lastErr = err + // Only retry for API embedder errors (network/timeout) + // For ONNX embedder, fail fast + if _, isAPI := r.embedder.(*APIEmbedder); !isAPI { + break + } + } + return nil, fmt.Errorf("embedding failed after %d attempts: %w", maxRetries, lastErr) +} + +// writeBatchToStorage writes a single batch of vectors to the database +func (r *RAG) writeBatchToStorage(ctx context.Context, result batchResult, filename string) error { + if len(result.embeddings) == 0 { + // Empty batch, skip + return nil + } + // Check context before starting select { - case LongJobStatusCh <- FinishedRAGStatus: + case <-ctx.Done(): + return ctx.Err() default: - r.logger.Warn("LongJobStatusCh channel is full or closed, dropping status message", "message", FinishedRAGStatus) } + + // Build all vectors for batch write + vectors := make([]*models.VectorRow, 0, len(result.paragraphs)) + for j, text := range result.paragraphs { + vectors = append(vectors, &models.VectorRow{ + Embeddings: result.embeddings[j], + RawText: text, + Slug: fmt.Sprintf("%s_%d_%d", filename, result.batchIndex+1, j), + FileName: filename, + }) + } + + // Write all vectors in a single transaction + if err := r.storage.WriteVectors(vectors); err != nil { + r.logger.Error("failed to write vectors batch to DB", "error", err, "batch", result.batchIndex+1, "size", len(vectors)) + r.sendStatusNonBlocking(ErrRAGStatus) + return fmt.Errorf("failed to write vectors batch: %w", err) + } + r.logger.Debug("wrote batch to db", "batch", result.batchIndex+1, "size", len(result.paragraphs)) return nil } func (r *RAG) LineToVector(line string) ([]float32, error) { + r.resetIdleTimer() return r.embedder.Embed(line) } -func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) { - return r.storage.SearchClosest(emb.Embedding) +func (r *RAG) searchEmb(emb *models.EmbeddingResp, limit int) ([]models.VectorRow, error) { + r.resetIdleTimer() + return r.storage.SearchClosest(emb.Embedding, limit) +} + +func (r *RAG) searchKeyword(query string, limit int) ([]models.VectorRow, error) { + r.resetIdleTimer() + sanitized := sanitizeFTSQuery(query) + return r.storage.SearchKeyword(sanitized, limit) } func (r *RAG) ListLoaded() ([]string, error) { + r.mu.RLock() + defer r.mu.RUnlock() return r.storage.ListFiles() } func (r *RAG) RemoveFile(filename string) error { + r.mu.Lock() + defer r.mu.Unlock() + r.resetIdleTimer() return r.storage.RemoveEmbByFileName(filename) } var ( queryRefinementPattern = regexp.MustCompile(`(?i)(based on my (vector db|vector db|vector database|rags?|past (conversations?|chat|messages?))|from my (files?|documents?|data|information|memory)|search (in|my) (vector db|database|rags?)|rag search for)`) importantKeywords = []string{"project", "architecture", "code", "file", "chat", "conversation", "topic", "summary", "details", "history", "previous", "my", "user", "me"} - stopWords = []string{"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by", "from", "up", "down", "left", "right"} ) func (r *RAG) RefineQuery(query string) string { @@ -210,11 +673,31 @@ func (r *RAG) RefineQuery(query string) string { if len(query) <= 3 { return original } + // If query already contains double quotes, assume it's a phrase query and skip refinement + if strings.Contains(query, "\"") { + return original + } query = strings.ToLower(query) - for _, stopWord := range stopWords { - wordPattern := `\b` + stopWord + `\b` - re := regexp.MustCompile(wordPattern) - query = re.ReplaceAllString(query, "") + words := strings.Fields(query) + if len(words) >= 3 { + // Detect phrases and protect words that are part of phrases + phrases := detectPhrases(query) + protectedWords := make(map[string]bool) + for _, phrase := range phrases { + for _, word := range strings.Fields(phrase) { + protectedWords[word] = true + } + } + + // Remove stop words that are not protected + for _, stopWord := range stopWords { + if protectedWords[stopWord] { + continue + } + wordPattern := `\b` + stopWord + `\b` + re := regexp.MustCompile(wordPattern) + query = re.ReplaceAllString(query, "") + } } query = strings.TrimSpace(query) if len(query) < 5 { @@ -246,7 +729,7 @@ func (r *RAG) extractImportantPhrases(query string) string { break } } - if isImportant || len(word) > 3 { + if isImportant || len(word) >= 3 { important = append(important, word) } } @@ -265,6 +748,36 @@ func (r *RAG) GenerateQueryVariations(query string) []string { if len(parts) == 0 { return variations } + // Get loaded filenames to filter out filename terms + filenames, err := r.storage.ListFiles() + if err == nil && len(filenames) > 0 { + // Convert to lowercase for case-insensitive matching + lowerFilenames := make([]string, len(filenames)) + for i, f := range filenames { + lowerFilenames[i] = strings.ToLower(f) + } + filteredParts := make([]string, 0, len(parts)) + for _, part := range parts { + partLower := strings.ToLower(part) + skip := false + for _, fn := range lowerFilenames { + if strings.Contains(fn, partLower) || strings.Contains(partLower, fn) { + skip = true + break + } + } + if !skip { + filteredParts = append(filteredParts, part) + } + } + // If filteredParts not empty and different from original, add filtered query + if len(filteredParts) > 0 && len(filteredParts) != len(parts) { + filteredQuery := strings.Join(filteredParts, " ") + if len(filteredQuery) >= 5 { + variations = append(variations, filteredQuery) + } + } + } if len(parts) >= 2 { trimmed := strings.Join(parts[:len(parts)-1], " ") if len(trimmed) >= 5 { @@ -289,13 +802,57 @@ func (r *RAG) GenerateQueryVariations(query string) []string { if !strings.HasSuffix(query, " summary") { variations = append(variations, query+" summary") } + + // Add phrase-quoted variations for better FTS5 matching + phrases := detectPhrases(query) + if len(phrases) > 0 { + // Sort phrases by length descending to prioritize longer phrases + sort.Slice(phrases, func(i, j int) bool { + return len(phrases[i]) > len(phrases[j]) + }) + + // Create a version with all phrases quoted + quotedQuery := query + for _, phrase := range phrases { + // Only quote if not already quoted + quotedPhrase := "\"" + phrase + "\"" + if !strings.Contains(strings.ToLower(quotedQuery), strings.ToLower(quotedPhrase)) { + // Case-insensitive replacement of phrase with quoted version + re := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(phrase) + `\b`) + quotedQuery = re.ReplaceAllString(quotedQuery, quotedPhrase) + } + } + // Disabled malformed quoted query for now + // if quotedQuery != query { + // variations = append(variations, quotedQuery) + // } + + // Also add individual phrase variations for short queries + if len(phrases) <= 5 { + for _, phrase := range phrases { + // Create a focused query with just this phrase quoted + // Keep original context but emphasize this phrase + quotedPhrase := "\"" + phrase + "\"" + re := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(phrase) + `\b`) + focusedQuery := re.ReplaceAllString(query, quotedPhrase) + if focusedQuery != query && focusedQuery != quotedQuery { + variations = append(variations, focusedQuery) + } + // Add the phrase alone (quoted) as a separate variation + variations = append(variations, quotedPhrase) + } + } + } + return variations } func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.VectorRow { + phraseCount := len(detectPhrases(query)) type scoredResult struct { - row models.VectorRow - distance float32 + row models.VectorRow + distance float32 + phraseMatches int } scored := make([]scoredResult, 0, len(results)) for i := range results { @@ -320,27 +877,69 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V if row.FileName == "chat" || strings.Contains(strings.ToLower(row.FileName), "conversation") { score += 3 } + + // Phrase match bonus: extra points for containing detected phrases + phraseMatches := countPhraseMatches(row.RawText, query) + if phraseMatches > 0 { + // Significant bonus per phrase to prioritize exact phrase matches + r.logger.Debug("phrase match bonus", "slug", row.Slug, "phraseMatches", phraseMatches, "score", score) + score += float32(phraseMatches) * 100 + } + + // Cross-chunk adjacency bonus: if this chunk has adjacent siblings in results, + // boost score to promote narrative continuity + adjacentCount := 0 + for _, other := range results { + if other.Slug == row.Slug { + continue + } + if areSlugsAdjacent(row.Slug, other.Slug) { + adjacentCount++ + } + } + if adjacentCount > 0 { + // Bonus per adjacent chunk, but diminishing returns + score += float32(adjacentCount) * 4 + } distance := row.Distance - score/100 - scored = append(scored, scoredResult{row: row, distance: distance}) + scored = append(scored, scoredResult{row: row, distance: distance, phraseMatches: phraseMatches}) } sort.Slice(scored, func(i, j int) bool { return scored[i].distance < scored[j].distance }) unique := make([]models.VectorRow, 0) seen := make(map[string]bool) + maxPerFile := 2 + if phraseCount > 0 { + maxPerFile = 10 + } + fileCounts := make(map[string]int) for i := range scored { if !seen[scored[i].row.Slug] { + // Allow phrase-matching chunks to bypass per-file limit (up to +5 extra) + allowed := fileCounts[scored[i].row.FileName] < maxPerFile + if !allowed && scored[i].phraseMatches > 0 { + // If chunk has phrase matches, allow extra slots (up to maxPerFile + 5) + allowed = fileCounts[scored[i].row.FileName] < maxPerFile+5 + } + if !allowed { + continue + } seen[scored[i].row.Slug] = true + fileCounts[scored[i].row.FileName]++ unique = append(unique, scored[i].row) } } - if len(unique) > 10 { - unique = unique[:10] + if len(unique) > 30 { + unique = unique[:30] } return unique } func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string, error) { + r.mu.RLock() + defer r.mu.RUnlock() + r.resetIdleTimer() if len(results) == 0 { return "No relevant information found in the vector database.", nil } @@ -369,7 +968,7 @@ func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string Embedding: emb, Index: 0, } - topResults, err := r.SearchEmb(embResp) + topResults, err := r.searchEmb(embResp, 1) if err != nil { r.logger.Error("failed to search for synthesis context", "error", err) return "", err @@ -396,9 +995,15 @@ func truncateString(s string, maxLen int) string { } func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) { + r.mu.RLock() + defer r.mu.RUnlock() + r.resetIdleTimer() refined := r.RefineQuery(query) variations := r.GenerateQueryVariations(refined) - allResults := make([]models.VectorRow, 0) + r.logger.Debug("query variations", "original", query, "refined", refined, "variations", variations) + + // Collect embedding search results from all variations + var embResults []models.VectorRow seen := make(map[string]bool) for _, q := range variations { emb, err := r.LineToVector(q) @@ -406,29 +1011,119 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) { r.logger.Error("failed to embed query variation", "error", err, "query", q) continue } - embResp := &models.EmbeddingResp{ Embedding: emb, Index: 0, } - - results, err := r.SearchEmb(embResp) + results, err := r.searchEmb(embResp, limit*2) // Get more candidates if err != nil { r.logger.Error("failed to search embeddings", "error", err, "query", q) continue } - for _, row := range results { if !seen[row.Slug] { seen[row.Slug] = true - allResults = append(allResults, row) + embResults = append(embResults, row) } } } - reranked := r.RerankResults(allResults, query) - if len(reranked) > limit { - reranked = reranked[:limit] + // Sort embedding results by distance (lower is better) + sort.Slice(embResults, func(i, j int) bool { + return embResults[i].Distance < embResults[j].Distance + }) + + // Perform keyword search on all variations + var kwResults []models.VectorRow + seenKw := make(map[string]bool) + for _, q := range variations { + results, err := r.searchKeyword(q, limit) + if err != nil { + r.logger.Debug("keyword search failed for variation", "error", err, "query", q) + continue + } + for _, row := range results { + if !seenKw[row.Slug] { + seenKw[row.Slug] = true + kwResults = append(kwResults, row) + } + } + } + // Sort keyword results by distance (lower is better) + sort.Slice(kwResults, func(i, j int) bool { + return kwResults[i].Distance < kwResults[j].Distance + }) + + // Combine using Reciprocal Rank Fusion (RRF) + // Use smaller K for phrase-heavy queries to give more weight to top ranks + phraseCount := len(detectPhrases(query)) + rrfK := 60.0 + if phraseCount > 0 { + rrfK = 30.0 + } + r.logger.Debug("RRF parameters", "phraseCount", phraseCount, "rrfK", rrfK, "query", query) + type scoredRow struct { + row models.VectorRow + score float64 + } + scoreMap := make(map[string]float64) + // Add embedding results + for rank, row := range embResults { + score := 1.0 / (float64(rank) + rrfK) + scoreMap[row.Slug] += score + if row.Slug == "kjv_bible.epub_1786_0" { + r.logger.Debug("target chunk embedding rank", "rank", rank, "score", score) + } + } + // Add keyword results with weight boost when phrases are present + kwWeight := 1.0 + if phraseCount > 0 { + kwWeight = 100.0 + } + r.logger.Debug("keyword weight", "kwWeight", kwWeight, "phraseCount", phraseCount) + for rank, row := range kwResults { + score := kwWeight * (1.0 / (float64(rank) + rrfK)) + scoreMap[row.Slug] += score + if row.Slug == "kjv_bible.epub_1786_0" { + r.logger.Debug("target chunk keyword rank", "rank", rank, "score", score, "kwWeight", kwWeight, "rrfK", rrfK) + } + // Ensure row exists in combined results + if _, exists := seen[row.Slug]; !exists { + embResults = append(embResults, row) + } } + // Create slice of scored rows + scoredRows := make([]scoredRow, 0, len(embResults)) + for _, row := range embResults { + score := scoreMap[row.Slug] + scoredRows = append(scoredRows, scoredRow{row: row, score: score}) + } + // Debug: log scores for target chunk and top chunks + if strings.Contains(strings.ToLower(query), "bald") || strings.Contains(strings.ToLower(query), "she bears") { + for _, sr := range scoredRows { + if sr.row.Slug == "kjv_bible.epub_1786_0" { + r.logger.Debug("target chunk score", "slug", sr.row.Slug, "score", sr.score, "distance", sr.row.Distance) + } + } + // Log top 5 scores + for i := 0; i < len(scoredRows) && i < 5; i++ { + r.logger.Debug("top scored row", "rank", i+1, "slug", scoredRows[i].row.Slug, "score", scoredRows[i].score, "distance", scoredRows[i].row.Distance) + } + } + // Sort by descending RRF score + sort.Slice(scoredRows, func(i, j int) bool { + return scoredRows[i].score > scoredRows[j].score + }) + // Take top limit + if len(scoredRows) > limit { + scoredRows = scoredRows[:limit] + } + // Convert back to VectorRow + finalResults := make([]models.VectorRow, len(scoredRows)) + for i, sr := range scoredRows { + finalResults[i] = sr.row + } + // Apply reranking heuristics + reranked := r.RerankResults(finalResults, query) return reranked, nil } @@ -437,16 +1132,66 @@ var ( ragOnce sync.Once ) +func (r *RAG) FallbackMessage() string { + return r.fallbackMsg +} + func Init(c *config.Config, l *slog.Logger, s storage.FullRepo) error { + var err error ragOnce.Do(func() { if c == nil || l == nil || s == nil { return } - ragInstance = New(l, s, c) + ragInstance, err = New(l, s, c) }) - return nil + return err } func GetInstance() *RAG { return ragInstance } + +func (r *RAG) resetIdleTimer() { + r.idleMu.Lock() + defer r.idleMu.Unlock() + if r.idleTimer != nil { + r.idleTimer.Stop() + } + r.idleTimer = time.AfterFunc(r.idleTimeout, func() { + r.freeONNXMemory() + }) +} + +func (r *RAG) freeONNXMemory() { + r.mu.Lock() + defer r.mu.Unlock() + if onnx, ok := r.embedder.(*ONNXEmbedder); ok { + if err := onnx.Destroy(); err != nil { + r.logger.Error("failed to free ONNX memory", "error", err) + } else { + r.logger.Info("freed ONNX VRAM after idle timeout") + } + } +} + +func (r *RAG) Destroy() { + r.mu.Lock() + defer r.mu.Unlock() + if r.idleTimer != nil { + r.idleTimer.Stop() + r.idleTimer = nil + } + if onnx, ok := r.embedder.(*ONNXEmbedder); ok { + if err := onnx.Destroy(); err != nil { + r.logger.Error("failed to destroy ONNX embedder", "error", err) + } + } +} + +// SetEmbedderForTesting replaces the internal embedder with a mock. +// This function is only available when compiling with the "test" build tag. +func (r *RAG) SetEmbedderForTesting(e Embedder) { + r.mu.Lock() + defer r.mu.Unlock() + r.embedder = e +} |
