diff options
Diffstat (limited to 'rag')
| -rw-r--r-- | rag/rag.go | 167 |
1 files changed, 164 insertions, 3 deletions
// Query-phrase and slug-adjacency helpers for the RAG package.
// They rely on the regexp, strconv, and strings packages.

var (
	// stopWords are common words that can be removed from queries when not
	// part of phrases.
	stopWords = []string{"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by", "from", "up", "down", "left", "right", "about", "like", "such", "than", "then", "also", "too"}

	// slugIndexPattern matches the trailing "_<batch>_<chunk>" part of a slug.
	// It is compiled once here because slug parsing runs for every pair of
	// results during reranking.
	slugIndexPattern = regexp.MustCompile(`_(\d+)_(\d+)$`)
)

// isStopWord reports whether word is in the stop words list, ignoring case.
func isStopWord(word string) bool {
	for _, stop := range stopWords {
		if strings.EqualFold(word, stop) {
			return true
		}
	}
	return false
}

// detectPhrases returns the multi-word phrases of query that should be treated
// as units.
//
// The query is lower-cased and split on whitespace, and surrounding
// punctuation is trimmed from each token. Every run of two consecutive tokens
// that are not stop words and are at least two bytes long yields a two-word
// phrase; if the following token also qualifies, the three-word phrase is
// added right after it. Each phrase is returned once, in order of first
// appearance. The result is nil when no phrase is found.
func detectPhrases(query string) []string {
	const cutset = ".,!?;:'\"()[]{}"

	fields := strings.Fields(strings.ToLower(query))
	words := make([]string, len(fields))
	usable := make([]bool, len(fields)) // usable[i]: words[i] may be part of a phrase
	for i, field := range fields {
		word := strings.Trim(field, cutset)
		words[i] = word
		usable[i] = len(word) >= 2 && !isStopWord(word)
	}

	var phrases []string
	seen := make(map[string]struct{})
	add := func(phrase string) {
		// Repeated word pairs would otherwise produce duplicate phrases and,
		// downstream, duplicate query variations.
		if _, dup := seen[phrase]; dup {
			return
		}
		seen[phrase] = struct{}{}
		phrases = append(phrases, phrase)
	}

	for i := 0; i+1 < len(words); i++ {
		if !usable[i] || !usable[i+1] {
			continue
		}
		pair := words[i] + " " + words[i+1]
		add(pair)
		if i+2 < len(words) && usable[i+2] {
			add(pair + " " + words[i+2])
		}
	}
	return phrases
}

// splitSlugIndices splits a slug of the form "<prefix>_<batch>_<chunk>" into
// its prefix and numeric indices. ok is false when the slug does not end in
// two underscore-separated decimal numbers or a number does not fit in an int.
func splitSlugIndices(slug string) (prefix string, batch, chunk int, ok bool) {
	loc := slugIndexPattern.FindStringSubmatchIndex(slug)
	if loc == nil {
		return "", 0, 0, false
	}
	b, err := strconv.Atoi(slug[loc[2]:loc[3]])
	if err != nil {
		return "", 0, 0, false
	}
	c, err := strconv.Atoi(slug[loc[4]:loc[5]])
	if err != nil {
		return "", 0, 0, false
	}
	return slug[:loc[0]], b, c, true
}

// parseSlugIndices extracts batch and chunk indices from a slug.
// Slug format: filename_batch_chunk (e.g., "kjv_bible.epub_1786_0").
// ok is false when the slug does not end in "_<number>_<number>".
func parseSlugIndices(slug string) (batch, chunk int, ok bool) {
	_, batch, chunk, ok = splitSlugIndices(slug)
	return batch, chunk, ok
}

// areSlugsAdjacent reports whether two slugs share the same filename prefix
// and batch index and have chunk indices that differ by exactly one.
//
// Adjacency across batch boundaries is not detected: chunks from different
// batches are never considered adjacent.
func areSlugsAdjacent(slug1, slug2 string) bool {
	prefix1, batch1, chunk1, ok1 := splitSlugIndices(slug1)
	prefix2, batch2, chunk2, ok2 := splitSlugIndices(slug2)
	if !ok1 || !ok2 || prefix1 != prefix2 || batch1 != batch2 {
		return false
	}
	return chunk1 == chunk2+1 || chunk2 == chunk1+1
}
error { var ( queryRefinementPattern = regexp.MustCompile(`(?i)(based on my (vector db|vector db|vector database|rags?|past (conversations?|chat|messages?))|from my (files?|documents?|data|information|memory)|search (in|my) (vector db|database|rags?)|rag search for)`) importantKeywords = []string{"project", "architecture", "code", "file", "chat", "conversation", "topic", "summary", "details", "history", "previous", "my", "user", "me"} - stopWords = []string{"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by", "from", "up", "down", "left", "right"} ) func (r *RAG) RefineQuery(query string) string { @@ -564,7 +657,20 @@ func (r *RAG) RefineQuery(query string) string { query = strings.ToLower(query) words := strings.Fields(query) if len(words) >= 3 { + // Detect phrases and protect words that are part of phrases + phrases := detectPhrases(query) + protectedWords := make(map[string]bool) + for _, phrase := range phrases { + for _, word := range strings.Fields(phrase) { + protectedWords[word] = true + } + } + + // Remove stop words that are not protected for _, stopWord := range stopWords { + if protectedWords[stopWord] { + continue + } wordPattern := `\b` + stopWord + `\b` re := regexp.MustCompile(wordPattern) query = re.ReplaceAllString(query, "") @@ -673,6 +779,45 @@ func (r *RAG) GenerateQueryVariations(query string) []string { if !strings.HasSuffix(query, " summary") { variations = append(variations, query+" summary") } + + // Add phrase-quoted variations for better FTS5 matching + phrases := detectPhrases(query) + if len(phrases) > 0 { + // Sort phrases by length descending to prioritize longer phrases + sort.Slice(phrases, func(i, j int) bool { + return len(phrases[i]) > len(phrases[j]) + }) + + // Create a version with all phrases quoted + quotedQuery := query + for _, phrase := range phrases { + // Only quote if not already quoted + quotedPhrase := "\"" + phrase + "\"" + if !strings.Contains(strings.ToLower(quotedQuery), 
strings.ToLower(quotedPhrase)) { + // Case-insensitive replacement of phrase with quoted version + re := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(phrase) + `\b`) + quotedQuery = re.ReplaceAllString(quotedQuery, quotedPhrase) + } + } + if quotedQuery != query { + variations = append(variations, quotedQuery) + } + + // Also add individual phrase variations for short queries + if len(phrases) <= 3 { + for _, phrase := range phrases { + // Create a focused query with just this phrase quoted + // Keep original context but emphasize this phrase + quotedPhrase := "\"" + phrase + "\"" + re := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(phrase) + `\b`) + focusedQuery := re.ReplaceAllString(query, quotedPhrase) + if focusedQuery != query && focusedQuery != quotedQuery { + variations = append(variations, focusedQuery) + } + } + } + } + return variations } @@ -704,6 +849,22 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V if row.FileName == "chat" || strings.Contains(strings.ToLower(row.FileName), "conversation") { score += 3 } + + // Cross-chunk adjacency bonus: if this chunk has adjacent siblings in results, + // boost score to promote narrative continuity + adjacentCount := 0 + for _, other := range results { + if other.Slug == row.Slug { + continue + } + if areSlugsAdjacent(row.Slug, other.Slug) { + adjacentCount++ + } + } + if adjacentCount > 0 { + // Bonus per adjacent chunk, but diminishing returns + score += float32(adjacentCount) * 4 + } distance := row.Distance - score/100 scored = append(scored, scoredResult{row: row, distance: distance}) } |
