diff options
Diffstat (limited to 'rag/rag.go')
| -rw-r--r-- | rag/rag.go | 136 |
1 files changed, 116 insertions, 20 deletions
@@ -74,6 +74,22 @@ func detectPhrases(query string) []string { return phrases } +// countPhraseMatches returns the number of query phrases found in text +func countPhraseMatches(text, query string) int { + phrases := detectPhrases(query) + if len(phrases) == 0 { + return 0 + } + textLower := strings.ToLower(text) + count := 0 + for _, phrase := range phrases { + if strings.Contains(textLower, phrase) { + count++ + } + } + return count +} + // parseSlugIndices extracts batch and chunk indices from a slug // slug format: filename_batch_chunk (e.g., "kjv_bible.epub_1786_0") func parseSlugIndices(slug string) (batch, chunk int, ok bool) { @@ -120,6 +136,9 @@ func areSlugsAdjacent(slug1, slug2 string) bool { // Check if they're in sequential batches and chunk indices suggest continuity // This is heuristic but useful for cross-batch adjacency + if (batch1 == batch2+1 && chunk1 == 0) || (batch2 == batch1+1 && chunk2 == 0) { + return true + } return false } @@ -654,6 +673,10 @@ func (r *RAG) RefineQuery(query string) string { if len(query) <= 3 { return original } + // If query already contains double quotes, assume it's a phrase query and skip refinement + if strings.Contains(query, "\"") { + return original + } query = strings.ToLower(query) words := strings.Fields(query) if len(words) >= 3 { @@ -799,12 +822,13 @@ func (r *RAG) GenerateQueryVariations(query string) []string { quotedQuery = re.ReplaceAllString(quotedQuery, quotedPhrase) } } - if quotedQuery != query { - variations = append(variations, quotedQuery) - } + // Disabled malformed quoted query for now + // if quotedQuery != query { + // variations = append(variations, quotedQuery) + // } // Also add individual phrase variations for short queries - if len(phrases) <= 3 { + if len(phrases) <= 5 { for _, phrase := range phrases { // Create a focused query with just this phrase quoted // Keep original context but emphasize this phrase @@ -814,6 +838,8 @@ func (r *RAG) GenerateQueryVariations(query string) []string { if focusedQuery != query && focusedQuery != quotedQuery { variations = append(variations, focusedQuery) } + // Add the phrase alone (quoted) as a separate variation + variations = append(variations, quotedPhrase) } } } @@ -822,9 +848,11 @@ func (r *RAG) GenerateQueryVariations(query string) []string { } func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.VectorRow { + phraseCount := len(detectPhrases(query)) type scoredResult struct { - row models.VectorRow - distance float32 + row models.VectorRow + distance float32 + phraseMatches int } scored := make([]scoredResult, 0, len(results)) for i := range results { @@ -850,6 +878,14 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V score += 3 } + // Phrase match bonus: extra points for containing detected phrases + phraseMatches := countPhraseMatches(row.RawText, query) + if phraseMatches > 0 { + // Significant bonus per phrase to prioritize exact phrase matches + r.logger.Debug("phrase match bonus", "slug", row.Slug, "phraseMatches", phraseMatches, "score", score) + score += float32(phraseMatches) * 100 + } + // Cross-chunk adjacency bonus: if this chunk has adjacent siblings in results, // boost score to promote narrative continuity adjacentCount := 0 @@ -866,17 +902,27 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V score += float32(adjacentCount) * 4 } distance := row.Distance - score/100 - scored = append(scored, scoredResult{row: row, distance: distance}) + scored = append(scored, scoredResult{row: row, distance: distance, phraseMatches: phraseMatches}) } sort.Slice(scored, func(i, j int) bool { return scored[i].distance < scored[j].distance }) unique := make([]models.VectorRow, 0) seen := make(map[string]bool) + maxPerFile := 2 + if phraseCount > 0 { + maxPerFile = 10 + } fileCounts := make(map[string]int) for i := range scored { if !seen[scored[i].row.Slug] { - if fileCounts[scored[i].row.FileName] >= 2 { + // Allow phrase-matching chunks to bypass per-file limit (up to +5 extra) + allowed := fileCounts[scored[i].row.FileName] < maxPerFile + if !allowed && scored[i].phraseMatches > 0 { + // If chunk has phrase matches, allow extra slots (up to maxPerFile + 5) + allowed = fileCounts[scored[i].row.FileName] < maxPerFile+5 + } + if !allowed { continue } seen[scored[i].row.Slug] = true @@ -884,8 +930,8 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V unique = append(unique, scored[i].row) } } - if len(unique) > 10 { - unique = unique[:10] + if len(unique) > 30 { + unique = unique[:30] } return unique } @@ -954,6 +1000,7 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) { r.resetIdleTimer() refined := r.RefineQuery(query) variations := r.GenerateQueryVariations(refined) + r.logger.Debug("query variations", "original", query, "refined", refined, "variations", variations) // Collect embedding search results from all variations var embResults []models.VectorRow @@ -985,17 +1032,35 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) { return embResults[i].Distance < embResults[j].Distance }) - // Perform keyword search - kwResults, err := r.searchKeyword(refined, limit*2) - if err != nil { - r.logger.Warn("keyword search failed, using only embeddings", "error", err) - kwResults = nil + // Perform keyword search on all variations + var kwResults []models.VectorRow + seenKw := make(map[string]bool) + for _, q := range variations { + results, err := r.searchKeyword(q, limit) + if err != nil { + r.logger.Debug("keyword search failed for variation", "error", err, "query", q) + continue + } + for _, row := range results { + if !seenKw[row.Slug] { + seenKw[row.Slug] = true + kwResults = append(kwResults, row) + } + } } - // Sort keyword results by distance (already sorted by BM25 score) - // kwResults already sorted by distance (lower is better) + // Sort keyword results by distance (lower is better) + sort.Slice(kwResults, func(i, j int) bool { + return kwResults[i].Distance < kwResults[j].Distance + }) // Combine using Reciprocal Rank Fusion (RRF) - const rrfK = 60 + // Use smaller K for phrase-heavy queries to give more weight to top ranks + phraseCount := len(detectPhrases(query)) + rrfK := 60.0 + if phraseCount > 0 { + rrfK = 30.0 + } + r.logger.Debug("RRF parameters", "phraseCount", phraseCount, "rrfK", rrfK, "query", query) type scoredRow struct { row models.VectorRow score float64 @@ -1005,11 +1070,22 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) { for rank, row := range embResults { score := 1.0 / (float64(rank) + rrfK) scoreMap[row.Slug] += score + if row.Slug == "kjv_bible.epub_1786_0" { + r.logger.Debug("target chunk embedding rank", "rank", rank, "score", score) + } } - // Add keyword results + // Add keyword results with weight boost when phrases are present + kwWeight := 1.0 + if phraseCount > 0 { + kwWeight = 100.0 + } + r.logger.Debug("keyword weight", "kwWeight", kwWeight, "phraseCount", phraseCount) for rank, row := range kwResults { - score := 1.0 / (float64(rank) + rrfK) + score := kwWeight * (1.0 / (float64(rank) + rrfK)) scoreMap[row.Slug] += score + if row.Slug == "kjv_bible.epub_1786_0" { + r.logger.Debug("target chunk keyword rank", "rank", rank, "score", score, "kwWeight", kwWeight, "rrfK", rrfK) + } // Ensure row exists in combined results if _, exists := seen[row.Slug]; !exists { embResults = append(embResults, row) @@ -1021,6 +1097,18 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) { score := scoreMap[row.Slug] scoredRows = append(scoredRows, scoredRow{row: row, score: score}) } + // Debug: log scores for target chunk and top chunks + if strings.Contains(strings.ToLower(query), "bald") || strings.Contains(strings.ToLower(query), "she bears") { + for _, sr := range scoredRows { + if sr.row.Slug == "kjv_bible.epub_1786_0" { + r.logger.Debug("target chunk score", "slug", sr.row.Slug, "score", sr.score, "distance", sr.row.Distance) + } + } + // Log top 5 scores + for i := 0; i < len(scoredRows) && i < 5; i++ { + r.logger.Debug("top scored row", "rank", i+1, "slug", scoredRows[i].row.Slug, "score", scoredRows[i].score, "distance", scoredRows[i].row.Distance) + } + } // Sort by descending RRF score sort.Slice(scoredRows, func(i, j int) bool { return scoredRows[i].score > scoredRows[j].score @@ -1099,3 +1187,11 @@ func (r *RAG) Destroy() { } } } + +// SetEmbedderForTesting replaces the internal embedder with a mock. +// This function is only available when compiling with the "test" build tag. +func (r *RAG) SetEmbedderForTesting(e Embedder) { + r.mu.Lock() + defer r.mu.Unlock() + r.embedder = e +} |
