diff options
Diffstat (limited to 'rag/rag.go')
| -rw-r--r-- | rag/rag.go | 176 |
1 files changed, 139 insertions, 37 deletions
@@ -73,6 +73,74 @@ func wordCounter(sentence string) int { return len(strings.Split(strings.TrimSpace(sentence), " ")) } +func createChunks(sentences []string, wordLimit, overlapWords uint32) []string { + if len(sentences) == 0 { + return nil + } + if overlapWords >= wordLimit { + overlapWords = wordLimit / 2 + } + var chunks []string + i := 0 + for i < len(sentences) { + var chunkWords []string + wordCount := 0 + j := i + for j < len(sentences) && wordCount <= int(wordLimit) { + sentence := sentences[j] + words := strings.Fields(sentence) + chunkWords = append(chunkWords, sentence) + wordCount += len(words) + j++ + // If this sentence alone exceeds limit, still include it and stop + if wordCount > int(wordLimit) { + break + } + } + if len(chunkWords) == 0 { + break + } + chunk := strings.Join(chunkWords, " ") + chunks = append(chunks, chunk) + if j >= len(sentences) { + break + } + // Move i forward by skipping overlap + if overlapWords == 0 { + i = j + continue + } + // Calculate how many sentences to skip to achieve overlapWords + overlapRemaining := int(overlapWords) + newI := i + for newI < j && overlapRemaining > 0 { + words := len(strings.Fields(sentences[newI])) + overlapRemaining -= words + if overlapRemaining >= 0 { + newI++ + } + } + if newI == i { + newI = j + } + i = newI + } + return chunks +} + +func sanitizeFTSQuery(query string) string { + // Remove double quotes and other problematic characters for FTS5 + query = strings.ReplaceAll(query, "\"", " ") + query = strings.ReplaceAll(query, "'", " ") + query = strings.ReplaceAll(query, ";", " ") + query = strings.ReplaceAll(query, "\\", " ") + query = strings.TrimSpace(query) + if query == "" { + return "*" // match all + } + return query +} + func (r *RAG) LoadRAG(fpath string) error { r.mu.Lock() defer r.mu.Unlock() @@ -95,31 +163,8 @@ func (r *RAG) LoadRAG(fpath string) error { for i, s := range sentences { sents[i] = s.Text } - // Group sentences into paragraphs based on word limit - paragraphs := []string{} - par := strings.Builder{} - for i := 0; i < len(sents); i++ { - if strings.TrimSpace(sents[i]) != "" { - if par.Len() > 0 { - par.WriteString(" ") - } - par.WriteString(sents[i]) - } - if wordCounter(par.String()) > int(r.cfg.RAGWordLimit) { - paragraph := strings.TrimSpace(par.String()) - if paragraph != "" { - paragraphs = append(paragraphs, paragraph) - } - par.Reset() - } - } - // Handle any remaining content in the paragraph buffer - if par.Len() > 0 { - paragraph := strings.TrimSpace(par.String()) - if paragraph != "" { - paragraphs = append(paragraphs, paragraph) - } - } + // Create chunks with overlap + paragraphs := createChunks(sents, r.cfg.RAGWordLimit, r.cfg.RAGOverlapWords) // Adjust batch size if needed if len(paragraphs) < r.cfg.RAGBatchSize && len(paragraphs) > 0 { r.cfg.RAGBatchSize = len(paragraphs) @@ -205,9 +250,15 @@ func (r *RAG) LineToVector(line string) ([]float32, error) { return r.embedder.Embed(line) } -func (r *RAG) SearchEmb(emb *models.EmbeddingResp) ([]models.VectorRow, error) { +func (r *RAG) SearchEmb(emb *models.EmbeddingResp, limit int) ([]models.VectorRow, error) { r.resetIdleTimer() - return r.storage.SearchClosest(emb.Embedding) + return r.storage.SearchClosest(emb.Embedding, limit) +} + +func (r *RAG) SearchKeyword(query string, limit int) ([]models.VectorRow, error) { + r.resetIdleTimer() + sanitized := sanitizeFTSQuery(query) + return r.storage.SearchKeyword(sanitized, limit) } func (r *RAG) ListLoaded() ([]string, error) { @@ -393,7 +444,7 @@ func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string Embedding: emb, Index: 0, } - topResults, err := r.SearchEmb(embResp) + topResults, err := r.SearchEmb(embResp, 1) if err != nil { r.logger.Error("failed to search for synthesis context", "error", err) return "", err @@ -422,7 +473,9 @@ func truncateString(s string, maxLen int) string { func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) { refined := r.RefineQuery(query) variations := r.GenerateQueryVariations(refined) - allResults := make([]models.VectorRow, 0) + + // Collect embedding search results from all variations + var embResults []models.VectorRow seen := make(map[string]bool) for _, q := range variations { emb, err := r.LineToVector(q) @@ -430,29 +483,78 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) { r.logger.Error("failed to embed query variation", "error", err, "query", q) continue } - embResp := &models.EmbeddingResp{ Embedding: emb, Index: 0, } - - results, err := r.SearchEmb(embResp) + results, err := r.SearchEmb(embResp, limit*2) // Get more candidates if err != nil { r.logger.Error("failed to search embeddings", "error", err, "query", q) continue } - for _, row := range results { if !seen[row.Slug] { seen[row.Slug] = true - allResults = append(allResults, row) + embResults = append(embResults, row) } } } - reranked := r.RerankResults(allResults, query) - if len(reranked) > limit { - reranked = reranked[:limit] + // Sort embedding results by distance (lower is better) + sort.Slice(embResults, func(i, j int) bool { + return embResults[i].Distance < embResults[j].Distance + }) + + // Perform keyword search + kwResults, err := r.SearchKeyword(refined, limit*2) + if err != nil { + r.logger.Warn("keyword search failed, using only embeddings", "error", err) + kwResults = nil + } + // Sort keyword results by distance (already sorted by BM25 score) + // kwResults already sorted by distance (lower is better) + + // Combine using Reciprocal Rank Fusion (RRF) + const rrfK = 60 + type scoredRow struct { + row models.VectorRow + score float64 + } + scoreMap := make(map[string]float64) + // Add embedding results + for rank, row := range embResults { + score := 1.0 / (float64(rank) + rrfK) + scoreMap[row.Slug] += score + } + // Add keyword results + for rank, row := range kwResults { + score := 1.0 / (float64(rank) + rrfK) + scoreMap[row.Slug] += score + // Ensure row exists in combined results + if _, exists := seen[row.Slug]; !exists { + embResults = append(embResults, row) + } + } + // Create slice of scored rows + scoredRows := make([]scoredRow, 0, len(embResults)) + for _, row := range embResults { + score := scoreMap[row.Slug] + scoredRows = append(scoredRows, scoredRow{row: row, score: score}) + } + // Sort by descending RRF score + sort.Slice(scoredRows, func(i, j int) bool { + return scoredRows[i].score > scoredRows[j].score + }) + // Take top limit + if len(scoredRows) > limit { + scoredRows = scoredRows[:limit] + } + // Convert back to VectorRow + finalResults := make([]models.VectorRow, len(scoredRows)) + for i, sr := range scoredRows { + finalResults[i] = sr.row } + // Apply reranking heuristics + reranked := r.RerankResults(finalResults, query) return reranked, nil } |
