From 888c9fec652b82174702c710f54f7d64f194315c Mon Sep 17 00:00:00 2001 From: Grail Finder Date: Wed, 25 Feb 2026 20:06:56 +0300 Subject: Chore: linter complaints --- rag/rag.go | 55 ++----------------------------------------------------- 1 file changed, 2 insertions(+), 53 deletions(-) (limited to 'rag/rag.go') diff --git a/rag/rag.go b/rag/rag.go index b8b5447..b63cb08 100644 --- a/rag/rag.go +++ b/rag/rag.go @@ -36,7 +36,6 @@ type RAG struct { func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG { // Initialize with API embedder by default, could be configurable later embedder := NewAPIEmbedder(l, cfg) - rag := &RAG{ logger: l, store: s, @@ -205,29 +204,22 @@ var ( func (r *RAG) RefineQuery(query string) string { original := query query = strings.TrimSpace(query) - if len(query) == 0 { return original } - if len(query) <= 3 { return original } - query = strings.ToLower(query) - for _, stopWord := range stopWords { wordPattern := `\b` + stopWord + `\b` re := regexp.MustCompile(wordPattern) query = re.ReplaceAllString(query, "") } - query = strings.TrimSpace(query) - if len(query) < 5 { return original } - if queryRefinementPattern.MatchString(original) { cleaned := queryRefinementPattern.ReplaceAllString(original, "") cleaned = strings.TrimSpace(cleaned) @@ -235,23 +227,18 @@ func (r *RAG) RefineQuery(query string) string { return cleaned } } - query = r.extractImportantPhrases(query) - if len(query) < 5 { return original } - return query } func (r *RAG) extractImportantPhrases(query string) string { words := strings.Fields(query) - var important []string for _, word := range words { word = strings.Trim(word, ".,!?;:'\"()[]{}") - isImportant := false for _, kw := range importantKeywords { if strings.Contains(strings.ToLower(word), kw) { @@ -259,45 +246,37 @@ func (r *RAG) extractImportantPhrases(query string) string { break } } - if isImportant || len(word) > 3 { important = append(important, word) } } - if len(important) == 0 { return query } - return strings.Join(important, " ") } func (r *RAG) GenerateQueryVariations(query string) []string { variations := []string{query} - if len(query) < 5 { return variations } - parts := strings.Fields(query) if len(parts) == 0 { return variations } - if len(parts) >= 2 { trimmed := strings.Join(parts[:len(parts)-1], " ") if len(trimmed) >= 5 { variations = append(variations, trimmed) } } - if len(parts) >= 2 { trimmed := strings.Join(parts[1:], " ") if len(trimmed) >= 5 { variations = append(variations, trimmed) } } - if !strings.HasSuffix(query, " explanation") { variations = append(variations, query+" explanation") } @@ -310,7 +289,6 @@ func (r *RAG) GenerateQueryVariations(query string) []string { if !strings.HasSuffix(query, " summary") { variations = append(variations, query+" summary") } - return variations } @@ -319,21 +297,16 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V row models.VectorRow distance float32 } - scored := make([]scoredResult, 0, len(results)) - for i := range results { row := results[i] score := float32(0) - rawTextLower := strings.ToLower(row.RawText) queryLower := strings.ToLower(query) - if strings.Contains(rawTextLower, queryLower) { score += 10 } - queryWords := strings.Fields(queryLower) matchCount := 0 for _, word := range queryWords { @@ -344,34 +317,26 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V if len(queryWords) > 0 { score += float32(matchCount) / float32(len(queryWords)) * 5 } - if row.FileName == "chat" || strings.Contains(strings.ToLower(row.FileName), "conversation") { score += 3 } - distance := row.Distance - score/100 - scored = append(scored, scoredResult{row: row, distance: distance}) } - sort.Slice(scored, func(i, j int) bool { return scored[i].distance < scored[j].distance }) - unique := make([]models.VectorRow, 0) seen := make(map[string]bool) - for i := range scored { if !seen[scored[i].row.Slug] { seen[scored[i].row.Slug] = true unique = append(unique, scored[i].row) } } - if len(unique) > 10 { unique = unique[:10] } - return unique } @@ -379,58 +344,47 @@ func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string if len(results) == 0 { return "No relevant information found in the vector database.", nil } - var contextBuilder strings.Builder contextBuilder.WriteString("User Query: ") contextBuilder.WriteString(query) contextBuilder.WriteString("\n\nRetrieved Context:\n") - for i, row := range results { - contextBuilder.WriteString(fmt.Sprintf("[Source %d: %s]\n", i+1, row.FileName)) + fmt.Fprintf(&contextBuilder, "[Source %d: %s]\n", i+1, row.FileName) contextBuilder.WriteString(row.RawText) contextBuilder.WriteString("\n\n") } - contextBuilder.WriteString("Instructions: ") contextBuilder.WriteString("Based on the retrieved context above, provide a concise, coherent answer to the user's query. ") contextBuilder.WriteString("Extract only the most relevant information. ") contextBuilder.WriteString("If no relevant information is found, state that clearly. ") contextBuilder.WriteString("Cite sources by filename when relevant. ") contextBuilder.WriteString("Do not include unnecessary preamble or explanations.") - synthesisPrompt := contextBuilder.String() - emb, err := r.LineToVector(synthesisPrompt) if err != nil { r.logger.Error("failed to embed synthesis prompt", "error", err) return "", err } - embResp := &models.EmbeddingResp{ Embedding: emb, Index: 0, } - topResults, err := r.SearchEmb(embResp) if err != nil { r.logger.Error("failed to search for synthesis context", "error", err) return "", err } - if len(topResults) > 0 && topResults[0].RawText != synthesisPrompt { return topResults[0].RawText, nil } - var finalAnswer strings.Builder finalAnswer.WriteString("Based on the retrieved context:\n\n") - for i, row := range results { if i >= 5 { break } - finalAnswer.WriteString(fmt.Sprintf("- From %s: %s\n", row.FileName, truncateString(row.RawText, 200))) + fmt.Fprintf(&finalAnswer, "- From %s: %s\n", row.FileName, truncateString(row.RawText, 200)) } - return finalAnswer.String(), nil } @@ -444,10 +398,8 @@ func truncateString(s string, maxLen int) string { func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) { refined := r.RefineQuery(query) variations := r.GenerateQueryVariations(refined) - allResults := make([]models.VectorRow, 0) seen := make(map[string]bool) - for _, q := range variations { emb, err := r.LineToVector(q) if err != nil { @@ -473,13 +425,10 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) { } } } - reranked := r.RerankResults(allResults, query) - if len(reranked) > limit { reranked = reranked[:limit] } - return reranked, nil } -- cgit v1.2.3