diff options
| author | Grail Finder <wohilas@gmail.com> | 2026-02-25 20:06:56 +0300 |
|---|---|---|
| committer | Grail Finder <wohilas@gmail.com> | 2026-02-25 20:06:56 +0300 |
| commit | 888c9fec652b82174702c710f54f7d64f194315c (patch) | |
| tree | 883051d653dda2d57b227670bfd3721bf6cf426a /rag | |
| parent | 4f07994bdc3d23421cf3941af3edc18c05ffc94b (diff) | |
Chore: linter complaints
Diffstat (limited to 'rag')
| -rw-r--r-- | rag/embedder.go | 1 | ||||
| -rw-r--r-- | rag/extractors.go | 3 | ||||
| -rw-r--r-- | rag/rag.go | 55 | ||||
| -rw-r--r-- | rag/storage.go | 34 |
4 files changed, 10 insertions, 83 deletions
diff --git a/rag/embedder.go b/rag/embedder.go index bed1b41..1d29877 100644 --- a/rag/embedder.go +++ b/rag/embedder.go @@ -131,7 +131,6 @@ func (a *APIEmbedder) EmbedSlice(lines []string) ([][]float32, error) { } embeddings[data.Index] = data.Embedding } - return embeddings, nil } diff --git a/rag/extractors.go b/rag/extractors.go index 4255fdb..0f9f3f4 100644 --- a/rag/extractors.go +++ b/rag/extractors.go @@ -95,9 +95,7 @@ func extractTextFromEpub(fpath string) (string, error) { return "", fmt.Errorf("failed to open epub: %w", err) } defer r.Close() - var sb strings.Builder - for _, f := range r.File { ext := strings.ToLower(path.Ext(f.Name)) if ext != ".xhtml" && ext != ".html" && ext != ".htm" && ext != ".xml" { @@ -129,7 +127,6 @@ func extractTextFromEpub(fpath string) (string, error) { sb.WriteString(stripHTML(string(buf))) } } - if sb.Len() == 0 { return "", errors.New("no content extracted from epub") } @@ -36,7 +36,6 @@ type RAG struct { func New(l *slog.Logger, s storage.FullRepo, cfg *config.Config) *RAG { // Initialize with API embedder by default, could be configurable later embedder := NewAPIEmbedder(l, cfg) - rag := &RAG{ logger: l, store: s, @@ -205,29 +204,22 @@ var ( func (r *RAG) RefineQuery(query string) string { original := query query = strings.TrimSpace(query) - if len(query) == 0 { return original } - if len(query) <= 3 { return original } - query = strings.ToLower(query) - for _, stopWord := range stopWords { wordPattern := `\b` + stopWord + `\b` re := regexp.MustCompile(wordPattern) query = re.ReplaceAllString(query, "") } - query = strings.TrimSpace(query) - if len(query) < 5 { return original } - if queryRefinementPattern.MatchString(original) { cleaned := queryRefinementPattern.ReplaceAllString(original, "") cleaned = strings.TrimSpace(cleaned) @@ -235,23 +227,18 @@ func (r *RAG) RefineQuery(query string) string { return cleaned } } - query = r.extractImportantPhrases(query) - if len(query) < 5 { return original } - return query } func (r *RAG) extractImportantPhrases(query string) string { words := strings.Fields(query) - var important []string for _, word := range words { word = strings.Trim(word, ".,!?;:'\"()[]{}") - isImportant := false for _, kw := range importantKeywords { if strings.Contains(strings.ToLower(word), kw) { @@ -259,45 +246,37 @@ func (r *RAG) extractImportantPhrases(query string) string { break } } - if isImportant || len(word) > 3 { important = append(important, word) } } - if len(important) == 0 { return query } - return strings.Join(important, " ") } func (r *RAG) GenerateQueryVariations(query string) []string { variations := []string{query} - if len(query) < 5 { return variations } - parts := strings.Fields(query) if len(parts) == 0 { return variations } - if len(parts) >= 2 { trimmed := strings.Join(parts[:len(parts)-1], " ") if len(trimmed) >= 5 { variations = append(variations, trimmed) } } - if len(parts) >= 2 { trimmed := strings.Join(parts[1:], " ") if len(trimmed) >= 5 { variations = append(variations, trimmed) } } - if !strings.HasSuffix(query, " explanation") { variations = append(variations, query+" explanation") } @@ -310,7 +289,6 @@ func (r *RAG) GenerateQueryVariations(query string) []string { if !strings.HasSuffix(query, " summary") { variations = append(variations, query+" summary") } - return variations } @@ -319,21 +297,16 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V row models.VectorRow distance float32 } - scored := make([]scoredResult, 0, len(results)) - for i := range results { row := results[i] score := float32(0) - rawTextLower := strings.ToLower(row.RawText) queryLower := strings.ToLower(query) - if strings.Contains(rawTextLower, queryLower) { score += 10 } - queryWords := strings.Fields(queryLower) matchCount := 0 for _, word := range queryWords { @@ -344,34 +317,26 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V if len(queryWords) > 0 { score += float32(matchCount) / float32(len(queryWords)) * 5 } - if row.FileName == "chat" || strings.Contains(strings.ToLower(row.FileName), "conversation") { score += 3 } - distance := row.Distance - score/100 - scored = append(scored, scoredResult{row: row, distance: distance}) } - sort.Slice(scored, func(i, j int) bool { return scored[i].distance < scored[j].distance }) - unique := make([]models.VectorRow, 0) seen := make(map[string]bool) - for i := range scored { if !seen[scored[i].row.Slug] { seen[scored[i].row.Slug] = true unique = append(unique, scored[i].row) } } - if len(unique) > 10 { unique = unique[:10] } - return unique } @@ -379,58 +344,47 @@ func (r *RAG) SynthesizeAnswer(results []models.VectorRow, query string) (string if len(results) == 0 { return "No relevant information found in the vector database.", nil } - var contextBuilder strings.Builder contextBuilder.WriteString("User Query: ") contextBuilder.WriteString(query) contextBuilder.WriteString("\n\nRetrieved Context:\n") - for i, row := range results { - contextBuilder.WriteString(fmt.Sprintf("[Source %d: %s]\n", i+1, row.FileName)) + fmt.Fprintf(&contextBuilder, "[Source %d: %s]\n", i+1, row.FileName) contextBuilder.WriteString(row.RawText) contextBuilder.WriteString("\n\n") } - contextBuilder.WriteString("Instructions: ") contextBuilder.WriteString("Based on the retrieved context above, provide a concise, coherent answer to the user's query. ") contextBuilder.WriteString("Extract only the most relevant information. ") contextBuilder.WriteString("If no relevant information is found, state that clearly. ") contextBuilder.WriteString("Cite sources by filename when relevant. ") contextBuilder.WriteString("Do not include unnecessary preamble or explanations.") - synthesisPrompt := contextBuilder.String() - emb, err := r.LineToVector(synthesisPrompt) if err != nil { r.logger.Error("failed to embed synthesis prompt", "error", err) return "", err } - embResp := &models.EmbeddingResp{ Embedding: emb, Index: 0, } - topResults, err := r.SearchEmb(embResp) if err != nil { r.logger.Error("failed to search for synthesis context", "error", err) return "", err } - if len(topResults) > 0 && topResults[0].RawText != synthesisPrompt { return topResults[0].RawText, nil } - var finalAnswer strings.Builder finalAnswer.WriteString("Based on the retrieved context:\n\n") - for i, row := range results { if i >= 5 { break } - finalAnswer.WriteString(fmt.Sprintf("- From %s: %s\n", row.FileName, truncateString(row.RawText, 200))) + fmt.Fprintf(&finalAnswer, "- From %s: %s\n", row.FileName, truncateString(row.RawText, 200)) } - return finalAnswer.String(), nil } @@ -444,10 +398,8 @@ func truncateString(s string, maxLen int) string { func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) { refined := r.RefineQuery(query) variations := r.GenerateQueryVariations(refined) - allResults := make([]models.VectorRow, 0) seen := make(map[string]bool) - for _, q := range variations { emb, err := r.LineToVector(q) if err != nil { @@ -473,13 +425,10 @@ func (r *RAG) Search(query string, limit int) ([]models.VectorRow, error) { } } } - reranked := r.RerankResults(allResults, query) - if len(reranked) > limit { reranked = reranked[:limit] } - return reranked, nil } diff --git a/rag/storage.go b/rag/storage.go index 782c504..52f6859 100644 --- a/rag/storage.go +++ b/rag/storage.go @@ -28,7 +28,6 @@ func NewVectorStorage(logger *slog.Logger, store storage.FullRepo) *VectorStorag } } - // SerializeVector converts []float32 to binary blob func SerializeVector(vec []float32) []byte { buf := make([]byte, len(vec)*4) // 4 bytes per float32 @@ -66,17 +65,14 @@ func (vs *VectorStorage) WriteVector(row *models.VectorRow) error { // Serialize the embeddings to binary serializedEmbeddings := SerializeVector(row.Embeddings) - query := fmt.Sprintf( "INSERT INTO %s (embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)", tableName, ) - if _, err := vs.sqlxDB.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName); err != nil { vs.logger.Error("failed to write vector", "error", err, "slug", row.Slug) return err } - return nil } @@ -86,20 +82,18 @@ func (vs *VectorStorage) getTableName(emb []float32) (string, error) { // Check if we support this embedding size supportedSizes := map[int]bool{ - 384: true, - 768: true, - 1024: true, - 1536: true, - 2048: true, - 3072: true, - 4096: true, - 5120: true, + 384: true, + 768: true, + 1024: true, + 1536: true, + 2048: true, + 3072: true, + 4096: true, + 5120: true, } - if supportedSizes[size] { return fmt.Sprintf("embeddings_%d", size), nil } - return "", fmt.Errorf("no table for embedding size of %d", size) } @@ -126,9 +120,7 @@ func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, err vector models.VectorRow distance float32 } - var topResults []SearchResult - // Process vectors one by one to avoid loading everything into memory for rows.Next() { var ( @@ -176,14 +168,12 @@ func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, err result.vector.Distance = result.distance results = append(results, result.vector) } - return results, nil } // ListFiles returns a list of all loaded files func (vs *VectorStorage) ListFiles() ([]string, error) { fileLists := make([][]string, 0) - // Query all supported tables and combine results embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120} for _, size := range embeddingSizes { @@ -219,14 +209,12 @@ func (vs *VectorStorage) ListFiles() ([]string, error) { } } } - return allFiles, nil } // RemoveEmbByFileName removes all embeddings associated with a specific filename func (vs *VectorStorage) RemoveEmbByFileName(filename string) error { var errors []string - embeddingSizes := []int{384, 768, 1024, 1536, 2048, 3072, 4096, 5120} for _, size := range embeddingSizes { table := fmt.Sprintf("embeddings_%d", size) @@ -235,11 +223,9 @@ func (vs *VectorStorage) RemoveEmbByFileName(filename string) error { errors = append(errors, err.Error()) } } - if len(errors) > 0 { return fmt.Errorf("errors occurred: %s", strings.Join(errors, "; ")) } - return nil } @@ -248,18 +234,15 @@ func cosineSimilarity(a, b []float32) float32 { if len(a) != len(b) { return 0.0 } - var dotProduct, normA, normB float32 for i := 0; i < len(a); i++ { dotProduct += a[i] * b[i] normA += a[i] * a[i] normB += b[i] * b[i] } - if normA == 0 || normB == 0 { return 0.0 } - return dotProduct / (sqrt(normA) * sqrt(normB)) } @@ -275,4 +258,3 @@ func sqrt(f float32) float32 { } return guess } - |
