summary | refs | log | tree | commit | diff
path: root/rag
diff options
context:
space:
mode:
Diffstat (limited to 'rag')
-rw-r--r-- rag/rag.go 308
1 files changed, 308 insertions, 0 deletions
diff --git a/rag/rag.go b/rag/rag.go
index f554924..b49bd97 100644
--- a/rag/rag.go
+++ b/rag/rag.go
@@ -9,6 +9,8 @@ import (
"log/slog"
"os"
"path"
+ "regexp"
+ "sort"
"strings"
"sync"
@@ -195,3 +197,309 @@ func (r *RAG) ListLoaded() ([]string, error) {
// RemoveFile asks the storage layer to remove the embeddings stored under
// filename and returns whatever error the storage call reports.
func (r *RAG) RemoveFile(filename string) error {
	if err := r.storage.RemoveEmbByFileName(filename); err != nil {
		return err
	}
	return nil
}
+
+var (
+ queryRefinementPattern = regexp.MustCompile(`(?i)(based on my (vector db|vector db|vector database|rags?|past (conversations?|chat|messages?))|from my (files?|documents?|data|information|memory)|search (in|my) (vector db|database|rags?)|rag search for)`)
+ importantKeywords = []string{"project", "architecture", "code", "file", "chat", "conversation", "topic", "summary", "details", "history", "previous", "my", "user", "me"}
+ stopWords = []string{"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by", "from", "up", "down", "left", "right"}
+)
+
+func (r *RAG) RefineQuery(query string) string {
+ original := query
+ query = strings.TrimSpace(query)
+
+ if len(query) == 0 {
+ return original
+ }
+
+ if len(query) <= 3 {
+ return original
+ }
+
+ query = strings.ToLower(query)
+
+ for _, stopWord := range stopWords {
+ wordPattern := `\b` + stopWord + `\b`
+ re := regexp.MustCompile(wordPattern)
+ query = re.ReplaceAllString(query, "")
+ }
+
+ query = strings.TrimSpace(query)
+
+ if len(query) < 5 {
+ return original
+ }
+
+ if queryRefinementPattern.MatchString(original) {
+ cleaned := queryRefinementPattern.ReplaceAllString(original, "")
+ cleaned = strings.TrimSpace(cleaned)
+ if len(cleaned) >= 5 {
+ return cleaned
+ }
+ }
+
+ query = r.extractImportantPhrases(query)
+
+ if len(query) < 5 {
+ return original
+ }
+
+ return query
+}
+
+func (r *RAG) extractImportantPhrases(query string) string {
+ words := strings.Fields(query)
+
+ var important []string
+ for _, word := range words {
+ word = strings.Trim(word, ".,!?;:'\"()[]{}")
+
+ isImportant := false
+ for _, kw := range importantKeywords {
+ if strings.Contains(strings.ToLower(word), kw) {
+ isImportant = true
+ break
+ }
+ }
+
+ if isImportant || len(word) > 3 {
+ important = append(important, word)
+ }
+ }
+
+ if len(important) == 0 {
+ return query
+ }
+
+ return strings.Join(important, " ")
+}
+
+func (r *RAG) GenerateQueryVariations(query string) []string {
+ variations := []string{query}
+
+ if len(query) < 5 {
+ return variations
+ }
+
+ parts := strings.Fields(query)
+ if len(parts) == 0 {
+ return variations
+ }
+
+ if len(parts) >= 2 {
+ trimmed := strings.Join(parts[:len(parts)-1], " ")
+ if len(trimmed) >= 5 {
+ variations = append(variations, trimmed)
+ }
+ }
+
+ if len(parts) >= 2 {
+ trimmed := strings.Join(parts[1:], " ")
+ if len(trimmed) >= 5 {
+ variations = append(variations, trimmed)
+ }
+ }
+
+ if !strings.HasSuffix(query, " explanation") {
+ variations = append(variations, query+" explanation")
+ }
+ if !strings.HasPrefix(query, "what is ") {
+ variations = append(variations, "what is "+query)
+ }
+ if !strings.HasSuffix(query, " details") {
+ variations = append(variations, query+" details")
+ }
+ if !strings.HasSuffix(query, " summary") {
+ variations = append(variations, query+" summary")
+ }
+
+ return variations
+}
+
+func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.VectorRow {
+ type scoredResult struct {
+ row models.VectorRow
+ distance float32
+ }
+
+ scored := make([]scoredResult, 0, len(results))
+
+ for i := range results {
+ row := results[i]
+
+ score := float32(0)
+
+ rawTextLower := strings.ToLower(row.RawText)
+ queryLower := strings.ToLower(query)
+
+ if strings.Contains(rawTextLower, queryLower) {
+ score += 10
+ }
+
+ queryWords := strings.Fields(queryLower)
+ matchCount := 0
+ for _, word := range queryWords {
+ if len(word) > 2 && strings.Contains(rawTextLower, word) {
+ matchCount++
+ }
+ }
+ if len(queryWords) > 0 {
+ score += float32(matchCount) / float32(len(queryWords)) * 5
+ }
+
+ if row.FileName == "chat" || strings.Contains(strings.ToLower(row.FileName), "conversation") {
+ score += 3
+ }
+
+ distance := row.Distance - score/100
+
+ scored = append(scored, scoredResult{row: row, distance: distance})
+ }
+
+ sort.Slice(scored, func(i, j int) bool {
+ return scored[i].distance < scored[j].distance
+ })
+
+ unique := make([]models.VectorRow, 0)
+ seen := make(map[string]bool)
+
+ for i := range scored {
+ if !seen[scored[i].row.Slug] {
+ seen[scored[i].row.Slug] = true
+ unique = append(unique, scored[i].row)
+ }
+ }
+
+ if len(unique) > 10 {
+ unique = unique[:10]
+ }
+
+ return unique
+}
+
+func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string, error) {
+ if len(results) == 0 {
+ return "No relevant information found in the vector database.", nil
+ }
+
+ var contextBuilder strings.Builder
+ contextBuilder.WriteString("User Query: ")
+ contextBuilder.WriteString(query)
+ contextBuilder.WriteString("\n\nRetrieved Context:\n")
+
+ for i, row := range results {
+ contextBuilder.WriteString(fmt.Sprintf("[Source %d: %s]\n", i+1, row.FileName))
+ contextBuilder.WriteString(row.RawText)
+ contextBuilder.WriteString("\n\n")
+ }
+
+ contextBuilder.WriteString("Instructions: ")
+ contextBuilder.WriteString("Based on the retrieved context above, provide a concise, coherent answer to the user's query. ")
+ contextBuilder.WriteString("Extract only the most relevant information. ")
+ contextBuilder.WriteString("If no relevant information is found, state that clearly. ")
+ contextBuilder.WriteString("Cite sources by filename when relevant. ")
+ contextBuilder.WriteString("Do not include unnecessary preamble or explanations.")
+
+ synthesisPrompt := contextBuilder.String()
+
+ emb, err := r.LineToVector(synthesisPrompt)
+ if err != nil {
+ r.logger.Error("failed to embed synthesis prompt", "error", err)
+ return "", err
+ }
+
+ embResp := &models.EmbeddingResp{
+ Embedding: emb,
+ Index: 0,
+ }
+
+ topResults, err := r.SearchEmb(embResp)
+ if err != nil {
+ r.logger.Error("failed to search for synthesis context", "error", err)
+ return "", err
+ }
+
+ if len(topResults) > 0 && topResults[0].RawText != synthesisPrompt {
+ return topResults[0].RawText, nil
+ }
+
+ var finalAnswer strings.Builder
+ finalAnswer.WriteString("Based on the retrieved context:\n\n")
+
+ for i, row := range results {
+ if i >= 5 {
+ break
+ }
+ finalAnswer.WriteString(fmt.Sprintf("- From %s: %s\n", row.FileName, truncateString(row.RawText, 200)))
+ }
+
+ return finalAnswer.String(), nil
+}
+
// truncateString returns s unchanged when it is at most maxLen bytes long.
// Otherwise it returns a prefix of s no longer than maxLen bytes, followed by
// "...". The cut is moved back to the nearest UTF-8 character boundary so a
// multi-byte character is never split into invalid bytes. A negative maxLen
// is treated as zero.
func truncateString(s string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	if len(s) <= maxLen {
		return s
	}

	// Bytes of the form 10xxxxxx are UTF-8 continuation bytes; if the first
	// excluded byte is one, the cut sits inside a character, so back up to
	// that character's lead byte.
	cut := maxLen
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut] + "..."
}
+
+func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) {
+ refined := r.RefineQuery(query)
+ variations := r.GenerateQueryVariations(refined)
+
+ allResults := make([]models.VectorRow, 0)
+ seen := make(map[string]bool)
+
+ for _, q := range variations {
+ emb, err := r.LineToVector(q)
+ if err != nil {
+ r.logger.Error("failed to embed query variation", "error", err, "query", q)
+ continue
+ }
+
+ embResp := &models.EmbeddingResp{
+ Embedding: emb,
+ Index: 0,
+ }
+
+ results, err := r.SearchEmb(embResp)
+ if err != nil {
+ r.logger.Error("failed to search embeddings", "error", err, "query", q)
+ continue
+ }
+
+ for _, row := range results {
+ if !seen[row.Slug] {
+ seen[row.Slug] = true
+ allResults = append(allResults, row)
+ }
+ }
+ }
+
+ reranked := r.RerankResults(allResults, query)
+
+ if len(reranked) > limit {
+ reranked = reranked[:limit]
+ }
+
+ return reranked, nil
+}
+
+var (
+ ragInstance *RAG
+ ragOnce sync.Once
+)
+
+func Init(c *config.Config, l *slog.Logger, s storage.FullRepo) error {
+ ragOnce.Do(func() {
+ if c == nil || l == nil || s == nil {
+ return
+ }
+ ragInstance = New(l, s, c)
+ })
+ return nil
+}
+
+func GetInstance() *RAG {
+ return ragInstance
+}