summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--rag/rag.go167
1 files changed, 164 insertions, 3 deletions
diff --git a/rag/rag.go b/rag/rag.go
index ef85e7f..6f12dd9 100644
--- a/rag/rag.go
+++ b/rag/rag.go
@@ -12,6 +12,7 @@ import (
"regexp"
"runtime"
"sort"
+ "strconv"
"strings"
"sync"
"time"
@@ -27,8 +28,101 @@ var (
FinishedRAGStatus = "finished loading RAG file; press x to exit"
LoadedFileRAGStatus = "loaded file"
ErrRAGStatus = "some error occurred; failed to transfer data to vector db"
+
+ // stopWords are common words that can be removed from queries when not part of phrases
+ stopWords = []string{"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by", "from", "up", "down", "left", "right", "about", "like", "such", "than", "then", "also", "too"}
)
+// isStopWord checks if a word is in the stop words list
+func isStopWord(word string) bool {
+ for _, stop := range stopWords {
+ if strings.EqualFold(word, stop) {
+ return true
+ }
+ }
+ return false
+}
+
+// detectPhrases returns multi-word phrases from a query that should be treated as units
+func detectPhrases(query string) []string {
+ words := strings.Fields(strings.ToLower(query))
+ var phrases []string
+
+ for i := 0; i < len(words)-1; i++ {
+ word1 := strings.Trim(words[i], ".,!?;:'\"()[]{}")
+ word2 := strings.Trim(words[i+1], ".,!?;:'\"()[]{}")
+
+ // Skip if either word is a stop word or too short
+ if isStopWord(word1) || isStopWord(word2) || len(word1) < 2 || len(word2) < 2 {
+ continue
+ }
+
+ // Check if this pair appears to be a meaningful phrase
+ // Simple heuristic: consecutive non-stop words of reasonable length
+ phrase := word1 + " " + word2
+ phrases = append(phrases, phrase)
+
+ // Optionally check for 3-word phrases
+ if i < len(words)-2 {
+ word3 := strings.Trim(words[i+2], ".,!?;:'\"()[]{}")
+ if !isStopWord(word3) && len(word3) >= 2 {
+ phrases = append(phrases, word1+" "+word2+" "+word3)
+ }
+ }
+ }
+
+ return phrases
+}
+
+// parseSlugIndices extracts batch and chunk indices from a slug
+// slug format: filename_batch_chunk (e.g., "kjv_bible.epub_1786_0")
+func parseSlugIndices(slug string) (batch, chunk int, ok bool) {
+ // Find the last two numbers separated by underscores
+ re := regexp.MustCompile(`_(\d+)_(\d+)$`)
+ matches := re.FindStringSubmatch(slug)
+ if matches == nil || len(matches) != 3 {
+ return 0, 0, false
+ }
+ batch, err1 := strconv.Atoi(matches[1])
+ chunk, err2 := strconv.Atoi(matches[2])
+ if err1 != nil || err2 != nil {
+ return 0, 0, false
+ }
+ return batch, chunk, true
+}
+
+// areSlugsAdjacent returns true if two slugs are from the same file and have sequential indices
+func areSlugsAdjacent(slug1, slug2 string) bool {
+ // Extract filename prefix (everything before the last underscore sequence)
+ parts1 := strings.Split(slug1, "_")
+ parts2 := strings.Split(slug2, "_")
+ if len(parts1) < 3 || len(parts2) < 3 {
+ return false
+ }
+
+ // Compare filename prefixes (all parts except last two)
+ prefix1 := strings.Join(parts1[:len(parts1)-2], "_")
+ prefix2 := strings.Join(parts2[:len(parts2)-2], "_")
+ if prefix1 != prefix2 {
+ return false
+ }
+
+ batch1, chunk1, ok1 := parseSlugIndices(slug1)
+ batch2, chunk2, ok2 := parseSlugIndices(slug2)
+ if !ok1 || !ok2 {
+ return false
+ }
+
+ // Check if they're in same batch and chunks are sequential
+ if batch1 == batch2 && (chunk1 == chunk2+1 || chunk2 == chunk1+1) {
+ return true
+ }
+
+ // Check if they're in sequential batches and chunk indices suggest continuity
+ // This is heuristic but useful for cross-batch adjacency
+ return false
+}
+
type RAG struct {
logger *slog.Logger
store storage.FullRepo
@@ -155,8 +249,8 @@ func createChunks(sentences []string, wordLimit, overlapWords uint32) []string {
}
func sanitizeFTSQuery(query string) string {
- // Remove double quotes and other problematic characters for FTS5
- // query = strings.ReplaceAll(query, "\"", " ")
+ // Keep double quotes for FTS5 phrase matching
+ // Remove other problematic characters
query = strings.ReplaceAll(query, "'", " ")
query = strings.ReplaceAll(query, ";", " ")
query = strings.ReplaceAll(query, "\\", " ")
@@ -549,7 +643,6 @@ func (r *RAG) RemoveFile(filename string) error {
var (
queryRefinementPattern = regexp.MustCompile(`(?i)(based on my (vector db|vector db|vector database|rags?|past (conversations?|chat|messages?))|from my (files?|documents?|data|information|memory)|search (in|my) (vector db|database|rags?)|rag search for)`)
importantKeywords = []string{"project", "architecture", "code", "file", "chat", "conversation", "topic", "summary", "details", "history", "previous", "my", "user", "me"}
- stopWords = []string{"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by", "from", "up", "down", "left", "right"}
)
func (r *RAG) RefineQuery(query string) string {
@@ -564,7 +657,20 @@ func (r *RAG) RefineQuery(query string) string {
query = strings.ToLower(query)
words := strings.Fields(query)
if len(words) >= 3 {
+ // Detect phrases and protect words that are part of phrases
+ phrases := detectPhrases(query)
+ protectedWords := make(map[string]bool)
+ for _, phrase := range phrases {
+ for _, word := range strings.Fields(phrase) {
+ protectedWords[word] = true
+ }
+ }
+
+ // Remove stop words that are not protected
for _, stopWord := range stopWords {
+ if protectedWords[stopWord] {
+ continue
+ }
wordPattern := `\b` + stopWord + `\b`
re := regexp.MustCompile(wordPattern)
query = re.ReplaceAllString(query, "")
@@ -673,6 +779,45 @@ func (r *RAG) GenerateQueryVariations(query string) []string {
if !strings.HasSuffix(query, " summary") {
variations = append(variations, query+" summary")
}
+
+ // Add phrase-quoted variations for better FTS5 matching
+ phrases := detectPhrases(query)
+ if len(phrases) > 0 {
+ // Sort phrases by length descending to prioritize longer phrases
+ sort.Slice(phrases, func(i, j int) bool {
+ return len(phrases[i]) > len(phrases[j])
+ })
+
+ // Create a version with all phrases quoted
+ quotedQuery := query
+ for _, phrase := range phrases {
+ // Only quote if not already quoted
+ quotedPhrase := "\"" + phrase + "\""
+ if !strings.Contains(strings.ToLower(quotedQuery), strings.ToLower(quotedPhrase)) {
+ // Case-insensitive replacement of phrase with quoted version
+ re := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(phrase) + `\b`)
+ quotedQuery = re.ReplaceAllString(quotedQuery, quotedPhrase)
+ }
+ }
+ if quotedQuery != query {
+ variations = append(variations, quotedQuery)
+ }
+
+ // Also add individual phrase variations for short queries
+ if len(phrases) <= 3 {
+ for _, phrase := range phrases {
+ // Create a focused query with just this phrase quoted
+ // Keep original context but emphasize this phrase
+ quotedPhrase := "\"" + phrase + "\""
+ re := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(phrase) + `\b`)
+ focusedQuery := re.ReplaceAllString(query, quotedPhrase)
+ if focusedQuery != query && focusedQuery != quotedQuery {
+ variations = append(variations, focusedQuery)
+ }
+ }
+ }
+ }
+
return variations
}
@@ -704,6 +849,22 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V
if row.FileName == "chat" || strings.Contains(strings.ToLower(row.FileName), "conversation") {
score += 3
}
+
+ // Cross-chunk adjacency bonus: if this chunk has adjacent siblings in results,
+ // boost score to promote narrative continuity
+ adjacentCount := 0
+ for _, other := range results {
+ if other.Slug == row.Slug {
+ continue
+ }
+ if areSlugsAdjacent(row.Slug, other.Slug) {
+ adjacentCount++
+ }
+ }
+ if adjacentCount > 0 {
+ // Bonus per adjacent chunk, but diminishing returns
+ score += float32(adjacentCount) * 4
+ }
distance := row.Distance - score/100
scored = append(scored, scoredResult{row: row, distance: distance})
}