From 6c03a1a277e486c877eb8d632b5a0ed321ece73e Mon Sep 17 00:00:00 2001 From: Grail Finder Date: Tue, 24 Feb 2026 20:24:44 +0300 Subject: Feat: rag tool --- rag/rag.go | 308 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 308 insertions(+) (limited to 'rag') diff --git a/rag/rag.go b/rag/rag.go index f554924..b49bd97 100644 --- a/rag/rag.go +++ b/rag/rag.go @@ -9,6 +9,8 @@ import ( "log/slog" "os" "path" + "regexp" + "sort" "strings" "sync" @@ -195,3 +197,309 @@ func (r *RAG) ListLoaded() ([]string, error) { func (r *RAG) RemoveFile(filename string) error { return r.storage.RemoveEmbByFileName(filename) } + +var ( + queryRefinementPattern = regexp.MustCompile(`(?i)(based on my (vector db|vector db|vector database|rags?|past (conversations?|chat|messages?))|from my (files?|documents?|data|information|memory)|search (in|my) (vector db|database|rags?)|rag search for)`) + importantKeywords = []string{"project", "architecture", "code", "file", "chat", "conversation", "topic", "summary", "details", "history", "previous", "my", "user", "me"} + stopWords = []string{"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by", "from", "up", "down", "left", "right"} +) + +func (r *RAG) RefineQuery(query string) string { + original := query + query = strings.TrimSpace(query) + + if len(query) == 0 { + return original + } + + if len(query) <= 3 { + return original + } + + query = strings.ToLower(query) + + for _, stopWord := range stopWords { + wordPattern := `\b` + stopWord + `\b` + re := regexp.MustCompile(wordPattern) + query = re.ReplaceAllString(query, "") + } + + query = strings.TrimSpace(query) + + if len(query) < 5 { + return original + } + + if queryRefinementPattern.MatchString(original) { + cleaned := queryRefinementPattern.ReplaceAllString(original, "") + cleaned = strings.TrimSpace(cleaned) + if len(cleaned) >= 5 { + return cleaned + } + } + + query = r.extractImportantPhrases(query) + + if len(query) < 5 { + return original + } + + return query +} + +func (r *RAG) extractImportantPhrases(query string) string { + words := strings.Fields(query) + + var important []string + for _, word := range words { + word = strings.Trim(word, ".,!?;:'\"()[]{}") + + isImportant := false + for _, kw := range importantKeywords { + if strings.Contains(strings.ToLower(word), kw) { + isImportant = true + break + } + } + + if isImportant || len(word) > 3 { + important = append(important, word) + } + } + + if len(important) == 0 { + return query + } + + return strings.Join(important, " ") +} + +func (r *RAG) GenerateQueryVariations(query string) []string { + variations := []string{query} + + if len(query) < 5 { + return variations + } + + parts := strings.Fields(query) + if len(parts) == 0 { + return variations + } + + if len(parts) >= 2 { + trimmed := strings.Join(parts[:len(parts)-1], " ") + if len(trimmed) >= 5 { + variations = append(variations, trimmed) + } + } + + if len(parts) >= 2 { + trimmed := strings.Join(parts[1:], " ") + if len(trimmed) >= 5 { + variations = append(variations, trimmed) + } + } + + if !strings.HasSuffix(query, " explanation") { + variations = append(variations, query+" explanation") + } + if !strings.HasPrefix(query, "what is ") { + variations = append(variations, "what is "+query) + } + if !strings.HasSuffix(query, " details") { + variations = append(variations, query+" details") + } + if !strings.HasSuffix(query, " summary") { + variations = append(variations, query+" summary") + } + + return variations +} + +func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.VectorRow { + type scoredResult struct { + row models.VectorRow + distance float32 + } + + scored := make([]scoredResult, 0, len(results)) + + for i := range results { + row := results[i] + + score := float32(0) + + rawTextLower := strings.ToLower(row.RawText) + queryLower := strings.ToLower(query) + + if strings.Contains(rawTextLower, queryLower) { + score += 10 + } + + queryWords := strings.Fields(queryLower) + matchCount := 0 + for _, word := range queryWords { + if len(word) > 2 && strings.Contains(rawTextLower, word) { + matchCount++ + } + } + if len(queryWords) > 0 { + score += float32(matchCount) / float32(len(queryWords)) * 5 + } + + if row.FileName == "chat" || strings.Contains(strings.ToLower(row.FileName), "conversation") { + score += 3 + } + + distance := row.Distance - score/100 + + scored = append(scored, scoredResult{row: row, distance: distance}) + } + + sort.Slice(scored, func(i, j int) bool { + return scored[i].distance < scored[j].distance + }) + + unique := make([]models.VectorRow, 0) + seen := make(map[string]bool) + + for i := range scored { + if !seen[scored[i].row.Slug] { + seen[scored[i].row.Slug] = true + unique = append(unique, scored[i].row) + } + } + + if len(unique) > 10 { + unique = unique[:10] + } + + return unique +} + +func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string, error) { + if len(results) == 0 { + return "No relevant information found in the vector database.", nil + } + + var contextBuilder strings.Builder + contextBuilder.WriteString("User Query: ") + contextBuilder.WriteString(query) + contextBuilder.WriteString("\n\nRetrieved Context:\n") + + for i, row := range results { + contextBuilder.WriteString(fmt.Sprintf("[Source %d: %s]\n", i+1, row.FileName)) + contextBuilder.WriteString(row.RawText) + contextBuilder.WriteString("\n\n") + } + + contextBuilder.WriteString("Instructions: ") + contextBuilder.WriteString("Based on the retrieved context above, provide a concise, coherent answer to the user's query. ") + contextBuilder.WriteString("Extract only the most relevant information. ") + contextBuilder.WriteString("If no relevant information is found, state that clearly. ") + contextBuilder.WriteString("Cite sources by filename when relevant. ") + contextBuilder.WriteString("Do not include unnecessary preamble or explanations.") + + synthesisPrompt := contextBuilder.String() + + emb, err := r.LineToVector(synthesisPrompt) + if err != nil { + r.logger.Error("failed to embed synthesis prompt", "error", err) + return "", err + } + + embResp := &models.EmbeddingResp{ + Embedding: emb, + Index: 0, + } + + topResults, err := r.SearchEmb(embResp) + if err != nil { + r.logger.Error("failed to search for synthesis context", "error", err) + return "", err + } + + if len(topResults) > 0 && topResults[0].RawText != synthesisPrompt { + return topResults[0].RawText, nil + } + + var finalAnswer strings.Builder + finalAnswer.WriteString("Based on the retrieved context:\n\n") + + for i, row := range results { + if i >= 5 { + break + } + finalAnswer.WriteString(fmt.Sprintf("- From %s: %s\n", row.FileName, truncateString(row.RawText, 200))) + } + + return finalAnswer.String(), nil +} + +func truncateString(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] + "..." +} + +func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) { + refined := r.RefineQuery(query) + variations := r.GenerateQueryVariations(refined) + + allResults := make([]models.VectorRow, 0) + seen := make(map[string]bool) + + for _, q := range variations { + emb, err := r.LineToVector(q) + if err != nil { + r.logger.Error("failed to embed query variation", "error", err, "query", q) + continue + } + + embResp := &models.EmbeddingResp{ + Embedding: emb, + Index: 0, + } + + results, err := r.SearchEmb(embResp) + if err != nil { + r.logger.Error("failed to search embeddings", "error", err, "query", q) + continue + } + + for _, row := range results { + if !seen[row.Slug] { + seen[row.Slug] = true + allResults = append(allResults, row) + } + } + } + + reranked := r.RerankResults(allResults, query) + + if len(reranked) > limit { + reranked = reranked[:limit] + } + + return reranked, nil +} + +var ( + ragInstance *RAG + ragOnce sync.Once +) + +func Init(c *config.Config, l *slog.Logger, s storage.FullRepo) error { + ragOnce.Do(func() { + if c == nil || l == nil || s == nil { + return + } + ragInstance = New(l, s, c) + }) + return nil +} + +func GetInstance() *RAG { + return ragInstance +} -- cgit v1.2.3