diff options
Diffstat (limited to 'rag/rag.go')
| -rw-r--r-- | rag/rag.go | 55 |
1 files changed, 2 insertions, 53 deletions
@@ -36,7 +36,6 @@ type RAG struct { func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG { // Initialize with API embedder by default, could be configurable later embedder := NewAPIEmbedder(l, cfg) - rag := &RAG{ logger: l, store: s, @@ -205,29 +204,22 @@ var ( func (r *RAG) RefineQuery(query string) string { original := query query = strings.TrimSpace(query) - if len(query) == 0 { return original } - if len(query) <= 3 { return original } - query = strings.ToLower(query) - for _, stopWord := range stopWords { wordPattern := `\b` + stopWord + `\b` re := regexp.MustCompile(wordPattern) query = re.ReplaceAllString(query, "") } - query = strings.TrimSpace(query) - if len(query) < 5 { return original } - if queryRefinementPattern.MatchString(original) { cleaned := queryRefinementPattern.ReplaceAllString(original, "") cleaned = strings.TrimSpace(cleaned) @@ -235,23 +227,18 @@ func (r *RAG) RefineQuery(query string) string { return cleaned } } - query = r.extractImportantPhrases(query) - if len(query) < 5 { return original } - return query } func (r *RAG) extractImportantPhrases(query string) string { words := strings.Fields(query) - var important []string for _, word := range words { word = strings.Trim(word, ".,!?;:'\"()[]{}") - isImportant := false for _, kw := range importantKeywords { if strings.Contains(strings.ToLower(word), kw) { @@ -259,45 +246,37 @@ func (r *RAG) extractImportantPhrases(query string) string { break } } - if isImportant || len(word) > 3 { important = append(important, word) } } - if len(important) == 0 { return query } - return strings.Join(important, " ") } func (r *RAG) GenerateQueryVariations(query string) []string { variations := []string{query} - if len(query) < 5 { return variations } - parts := strings.Fields(query) if len(parts) == 0 { return variations } - if len(parts) >= 2 { trimmed := strings.Join(parts[:len(parts)-1], " ") if len(trimmed) >= 5 { variations = append(variations, trimmed) } } - if len(parts) >= 2 { trimmed := strings.Join(parts[1:], " ") if len(trimmed) >= 5 { variations = append(variations, trimmed) } } - if !strings.HasSuffix(query, " explanation") { variations = append(variations, query+" explanation") } @@ -310,7 +289,6 @@ func (r *RAG) GenerateQueryVariations(query string) []string { if !strings.HasSuffix(query, " summary") { variations = append(variations, query+" summary") } - return variations } @@ -319,21 +297,16 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V row models.VectorRow distance float32 } - scored := make([]scoredResult, 0, len(results)) - for i := range results { row := results[i] score := float32(0) - rawTextLower := strings.ToLower(row.RawText) queryLower := strings.ToLower(query) - if strings.Contains(rawTextLower, queryLower) { score += 10 } - queryWords := strings.Fields(queryLower) matchCount := 0 for _, word := range queryWords { @@ -344,34 +317,26 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V if len(queryWords) > 0 { score += float32(matchCount) / float32(len(queryWords)) * 5 } - if row.FileName == "chat" || strings.Contains(strings.ToLower(row.FileName), "conversation") { score += 3 } - distance := row.Distance - score/100 - scored = append(scored, scoredResult{row: row, distance: distance}) } - sort.Slice(scored, func(i, j int) bool { return scored[i].distance < scored[j].distance }) - unique := make([]models.VectorRow, 0) seen := make(map[string]bool) - for i := range scored { if !seen[scored[i].row.Slug] { seen[scored[i].row.Slug] = true unique = append(unique, scored[i].row) } } - if len(unique) > 10 { unique = unique[:10] } - return unique } @@ -379,58 +344,47 @@ func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string if len(results) == 0 { return "No relevant information found in the vector database.", nil } - var contextBuilder strings.Builder contextBuilder.WriteString("User Query: ") contextBuilder.WriteString(query) contextBuilder.WriteString("\n\nRetrieved Context:\n") - for i, row := range results { - contextBuilder.WriteString(fmt.Sprintf("[Source %d: %s]\n", i+1, row.FileName)) + fmt.Fprintf(&contextBuilder, "[Source %d: %s]\n", i+1, row.FileName) contextBuilder.WriteString(row.RawText) contextBuilder.WriteString("\n\n") } - contextBuilder.WriteString("Instructions: ") contextBuilder.WriteString("Based on the retrieved context above, provide a concise, coherent answer to the user's query. ") contextBuilder.WriteString("Extract only the most relevant information. ") contextBuilder.WriteString("If no relevant information is found, state that clearly. ") contextBuilder.WriteString("Cite sources by filename when relevant. ") contextBuilder.WriteString("Do not include unnecessary preamble or explanations.") - synthesisPrompt := contextBuilder.String() - emb, err := r.LineToVector(synthesisPrompt) if err != nil { r.logger.Error("failed to embed synthesis prompt", "error", err) return "", err } - embResp := &models.EmbeddingResp{ Embedding: emb, Index: 0, } - topResults, err := r.SearchEmb(embResp) if err != nil { r.logger.Error("failed to search for synthesis context", "error", err) return "", err } - if len(topResults) > 0 && topResults[0].RawText != synthesisPrompt { return topResults[0].RawText, nil } - var finalAnswer strings.Builder finalAnswer.WriteString("Based on the retrieved context:\n\n") - for i, row := range results { if i >= 5 { break } - finalAnswer.WriteString(fmt.Sprintf("- From %s: %s\n", row.FileName, truncateString(row.RawText, 200))) + fmt.Fprintf(&finalAnswer, "- From %s: %s\n", row.FileName, truncateString(row.RawText, 200)) } - return finalAnswer.String(), nil } @@ -444,10 +398,8 @@ func truncateString(s string, maxLen int) string { func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) { refined := r.RefineQuery(query) variations := r.GenerateQueryVariations(refined) - allResults := make([]models.VectorRow, 0) seen := make(map[string]bool) - for _, q := range variations { emb, err := r.LineToVector(q) if err != nil { @@ -473,13 +425,10 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) { } } } - reranked := r.RerankResults(allResults, query) - if len(reranked) > limit { reranked = reranked[:limit] } - return reranked, nil } |
