From 25b2e2f592bd8df9a5cbd3c77322b572eb8f829c Mon Sep 17 00:00:00 2001 From: Grail Finder Date: Wed, 19 Nov 2025 12:32:46 +0300 Subject: Fix: migration use of vec0; rag cleanup --- storage/vector.go | 185 +++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 155 insertions(+), 30 deletions(-) (limited to 'storage/vector.go') diff --git a/storage/vector.go b/storage/vector.go index 900803c..73bfe29 100644 --- a/storage/vector.go +++ b/storage/vector.go @@ -1,9 +1,9 @@ package storage import ( - "gf-lt/models" "encoding/binary" "fmt" + "gf-lt/models" "unsafe" "github.com/jmoiron/sqlx" @@ -26,7 +26,7 @@ func SerializeVector(vec []float32) []byte { return buf } -// DeserializeVector converts binary blob back to []float32 +// DeserializeVector converts binary blob back to []float32 func DeserializeVector(data []byte) []float32 { count := len(data) / 4 vec := make([]float32, count) @@ -66,50 +66,175 @@ func (p ProviderSQL) WriteVector(row *models.VectorRow) error { if err != nil { return err } - + serializedEmbeddings := SerializeVector(row.Embeddings) - - query := fmt.Sprintf("INSERT INTO %s(embedding, slug, raw_text, filename) VALUES (?, ?, ?, ?)", tableName) + + query := fmt.Sprintf("INSERT INTO %s(embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)", tableName) _, err = p.db.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName) - + return err } - - func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) { - // TODO: This function has been temporarily disabled to avoid deprecated library usage. - // In the new RAG implementation, this functionality is now in rag_new package. - // For compatibility, return empty result instead of using deprecated vector extension. - return []models.VectorRow{}, nil -} + tableName, err := fetchTableName(q) + if err != nil { + return nil, err + } -func (p ProviderSQL) ListFiles() ([]string, error) { - q := fmt.Sprintf("SELECT filename FROM %s GROUP BY filename", vecTableName384) - rows, err := p.db.Query(q) + querySQL := fmt.Sprintf("SELECT embedding, slug, raw_text, filename FROM %s", tableName) + rows, err := p.db.Query(querySQL) if err != nil { return nil, err } defer rows.Close() - - resp := []string{} + + type SearchResult struct { + vector models.VectorRow + distance float32 + } + + var topResults []SearchResult + for rows.Next() { - var filename string - if err := rows.Scan(&filename); err != nil { - return nil, err + var ( + embeddingsBlob []byte + slug, rawText, fileName string + ) + + if err := rows.Scan(&embeddingsBlob, &slug, &rawText, &fileName); err != nil { + continue + } + + storedEmbeddings := DeserializeVector(embeddingsBlob) + + // Calculate cosine similarity (returns value between -1 and 1, where 1 is most similar) + similarity := cosineSimilarity(q, storedEmbeddings) + distance := 1 - similarity // Convert to distance where 0 is most similar + + result := SearchResult{ + vector: models.VectorRow{ + Embeddings: storedEmbeddings, + Slug: slug, + RawText: rawText, + FileName: fileName, + }, + distance: distance, + } + + // Add to top results and maintain only top results + topResults = append(topResults, result) + + // Sort and keep only top results + // We'll keep the top 3 closest vectors + if len(topResults) > 3 { + // Simple sort and truncate to maintain only 3 best matches + for i := 0; i < len(topResults); i++ { + for j := i + 1; j < len(topResults); j++ { + if topResults[i].distance > topResults[j].distance { + topResults[i], topResults[j] = topResults[j], topResults[i] + } + } + } + topResults = topResults[:3] } - resp = append(resp, filename) } - - if err := rows.Err(); err != nil { - return nil, err + + // Convert back to VectorRow slice + results := make([]models.VectorRow, len(topResults)) + for i, result := range topResults { + result.vector.Distance = result.distance + results[i] = result.vector } - - return resp, nil + + return results, nil +} + +// cosineSimilarity calculates the cosine similarity between two vectors +func cosineSimilarity(a, b []float32) float32 { + if len(a) != len(b) { + return 0.0 + } + + var dotProduct, normA, normB float32 + for i := 0; i < len(a); i++ { + dotProduct += a[i] * b[i] + normA += a[i] * a[i] + normB += b[i] * b[i] + } + + if normA == 0 || normB == 0 { + return 0.0 + } + + return dotProduct / (sqrt(normA) * sqrt(normB)) +} + +// sqrt returns the square root of a float32 +func sqrt(f float32) float32 { + // A simple implementation of square root using Newton's method + if f == 0 { + return 0 + } + guess := f / 2 + for i := 0; i < 10; i++ { // 10 iterations should be enough for good precision + guess = (guess + f/guess) / 2 + } + return guess +} + +func (p ProviderSQL) ListFiles() ([]string, error) { + fileLists := make([][]string, 0) + + // Query both tables and combine results + for _, table := range []string{vecTableName384, vecTableName5120} { + query := fmt.Sprintf("SELECT DISTINCT filename FROM %s", table) + rows, err := p.db.Query(query) + if err != nil { + // Continue if one table doesn't exist + continue + } + + var files []string + for rows.Next() { + var filename string + if err := rows.Scan(&filename); err != nil { + continue + } + files = append(files, filename) + } + rows.Close() + + fileLists = append(fileLists, files) + } + + // Combine and deduplicate + fileSet := make(map[string]bool) + var allFiles []string + for _, files := range fileLists { + for _, file := range files { + if !fileSet[file] { + fileSet[file] = true + allFiles = append(allFiles, file) + } + } + } + + return allFiles, nil } func (p ProviderSQL) RemoveEmbByFileName(filename string) error { - q := fmt.Sprintf("DELETE FROM %s WHERE filename = ?", vecTableName384) - _, err := p.db.Exec(q, filename) - return err + var errors []string + + for _, table := range []string{vecTableName384, vecTableName5120} { + query := fmt.Sprintf("DELETE FROM %s WHERE filename = ?", table) + if _, err := p.db.Exec(query, filename); err != nil { + errors = append(errors, err.Error()) + } + } + + if len(errors) > 0 { + return fmt.Errorf("errors occurred: %s", fmt.Sprintf("%v", errors)) + } + + return nil } -- cgit v1.2.3