summaryrefslogtreecommitdiff
path: root/rag
diff options
context:
space:
mode:
authorGrail Finder <wohilas@gmail.com>2026-03-06 13:17:49 +0300
committerGrail Finder <wohilas@gmail.com>2026-03-06 13:17:49 +0300
commit62ec55505ca07701ee6a976895d910b051e725b9 (patch)
tree0c7da8d65d398e27fabe9876a5bccf9fc70010d5 /rag
parentf9866bcf5a7369e28246d51b951e81b5b2a8489f (diff)
Enha (rag): query each doc
Diffstat (limited to 'rag')
-rw-r--r--rag/rag.go46
-rw-r--r--rag/storage.go33
2 files changed, 74 insertions, 5 deletions
diff --git a/rag/rag.go b/rag/rag.go
index 4e11a0d..9271b60 100644
--- a/rag/rag.go
+++ b/rag/rag.go
@@ -286,10 +286,13 @@ func (r *RAG) RefineQuery(query string) string {
return original
}
query = strings.ToLower(query)
- for _, stopWord := range stopWords {
- wordPattern := `\b` + stopWord + `\b`
- re := regexp.MustCompile(wordPattern)
- query = re.ReplaceAllString(query, "")
+ words := strings.Fields(query)
+ if len(words) >= 3 {
+ for _, stopWord := range stopWords {
+ wordPattern := `\b` + stopWord + `\b`
+ re := regexp.MustCompile(wordPattern)
+ query = re.ReplaceAllString(query, "")
+ }
}
query = strings.TrimSpace(query)
if len(query) < 5 {
@@ -340,6 +343,36 @@ func (r *RAG) GenerateQueryVariations(query string) []string {
if len(parts) == 0 {
return variations
}
+ // Get loaded filenames to filter out filename terms
+ filenames, err := r.storage.ListFiles()
+ if err == nil && len(filenames) > 0 {
+ // Convert to lowercase for case-insensitive matching
+ lowerFilenames := make([]string, len(filenames))
+ for i, f := range filenames {
+ lowerFilenames[i] = strings.ToLower(f)
+ }
+ filteredParts := make([]string, 0, len(parts))
+ for _, part := range parts {
+ partLower := strings.ToLower(part)
+ skip := false
+ for _, fn := range lowerFilenames {
+ if strings.Contains(fn, partLower) || strings.Contains(partLower, fn) {
+ skip = true
+ break
+ }
+ }
+ if !skip {
+ filteredParts = append(filteredParts, part)
+ }
+ }
+ // If filteredParts not empty and different from original, add filtered query
+ if len(filteredParts) > 0 && len(filteredParts) != len(parts) {
+ filteredQuery := strings.Join(filteredParts, " ")
+ if len(filteredQuery) >= 5 {
+ variations = append(variations, filteredQuery)
+ }
+ }
+ }
if len(parts) >= 2 {
trimmed := strings.Join(parts[:len(parts)-1], " ")
if len(trimmed) >= 5 {
@@ -403,9 +436,14 @@ func (r *RAG) RerankResults(results []models.VectorRow, query string) []models.V
})
unique := make([]models.VectorRow, 0)
seen := make(map[string]bool)
+ fileCounts := make(map[string]int)
for i := range scored {
if !seen[scored[i].row.Slug] {
+ if fileCounts[scored[i].row.FileName] >= 2 {
+ continue
+ }
seen[scored[i].row.Slug] = true
+ fileCounts[scored[i].row.FileName]++
unique = append(unique, scored[i].row)
}
}
diff --git a/rag/storage.go b/rag/storage.go
index 08e9d2a..110cea2 100644
--- a/rag/storage.go
+++ b/rag/storage.go
@@ -1,6 +1,7 @@
package rag
import (
+ "database/sql"
"encoding/binary"
"fmt"
"gf-lt/models"
@@ -221,11 +222,41 @@ func (vs *VectorStorage) SearchKeyword(query string, limit int) ([]models.Vector
WHERE fts_embeddings MATCH ?
ORDER BY score
LIMIT ?`
+
+ // Try original query first
rows, err := vs.sqlxDB.Query(ftsQuery, query, limit)
if err != nil {
return nil, fmt.Errorf("FTS search failed: %w", err)
}
- defer rows.Close()
+ results, err := vs.scanRows(rows)
+ rows.Close()
+ if err != nil {
+ return nil, err
+ }
+
+ // If no results and query contains multiple terms, try OR fallback
+ if len(results) == 0 && strings.Contains(query, " ") && !strings.Contains(strings.ToUpper(query), " OR ") {
+ // Build OR query: term1 OR term2 OR term3
+ terms := strings.Fields(query)
+ if len(terms) > 1 {
+ orQuery := strings.Join(terms, " OR ")
+ rows, err := vs.sqlxDB.Query(ftsQuery, orQuery, limit)
+ if err != nil {
+ // Return original empty results rather than error
+ return results, nil
+ }
+ orResults, err := vs.scanRows(rows)
+ rows.Close()
+ if err == nil {
+ results = orResults
+ }
+ }
+ }
+ return results, nil
+}
+
+// scanRows converts SQL rows to VectorRow slice
+func (vs *VectorStorage) scanRows(rows *sql.Rows) ([]models.VectorRow, error) {
var results []models.VectorRow
for rows.Next() {
var slug, rawText, fileName string