package rag_new

import (
	"encoding/binary"
	"fmt"
	"log/slog"
	"math"
	"strings"

	"gf-lt/models"
	"gf-lt/storage"

	"github.com/jmoiron/sqlx"
)

// VectorStorage handles storing and retrieving embedding vectors from SQLite.
// Each vector is stored in the table matching its dimension (see getTableName).
type VectorStorage struct {
	logger *slog.Logger
	sqlxDB *sqlx.DB
	store  storage.FullRepo
}

// NewVectorStorage returns a VectorStorage that runs its queries on the
// *sqlx.DB returned by store.DB(). It does not create any tables; call
// CreateTables for that.
func NewVectorStorage(logger *slog.Logger, store storage.FullRepo) *VectorStorage {
	return &VectorStorage{
		logger: logger,
		sqlxDB: store.DB(),
		store:  store,
	}
}

// CreateTables creates the per-dimension embedding tables (384 and 5120
// floats) and their indexes if they do not already exist. Statements run in
// order and the first failure is returned wrapped.
func (vs *VectorStorage) CreateTables() error {
	queries := []string{
		`CREATE TABLE IF NOT EXISTS embeddings_384 (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			embeddings BLOB NOT NULL,
			slug TEXT NOT NULL,
			raw_text TEXT NOT NULL,
			filename TEXT NOT NULL,
			created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
		)`,
		`CREATE TABLE IF NOT EXISTS embeddings_5120 (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			embeddings BLOB NOT NULL,
			slug TEXT NOT NULL,
			raw_text TEXT NOT NULL,
			filename TEXT NOT NULL,
			created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
		)`,
		// Lookups and deletes filter by filename; slug and created_at indexes
		// support other access patterns.
		`CREATE INDEX IF NOT EXISTS idx_embeddings_384_filename ON embeddings_384(filename)`,
		`CREATE INDEX IF NOT EXISTS idx_embeddings_5120_filename ON embeddings_5120(filename)`,
		`CREATE INDEX IF NOT EXISTS idx_embeddings_384_slug ON embeddings_384(slug)`,
		`CREATE INDEX IF NOT EXISTS idx_embeddings_5120_slug ON embeddings_5120(slug)`,
		`CREATE INDEX IF NOT EXISTS idx_embeddings_384_created_at ON embeddings_384(created_at)`,
		`CREATE INDEX IF NOT EXISTS idx_embeddings_5120_created_at ON embeddings_5120(created_at)`,
	}
	for _, query := range queries {
		if _, err := vs.sqlxDB.Exec(query); err != nil {
			return fmt.Errorf("failed to create table: %w", err)
		}
	}
	return nil
}

// SerializeVector encodes vec as a blob of little-endian IEEE-754 float32
// values, 4 bytes per element. The byte layout does not depend on the host
// architecture, so blobs decode identically on any machine.
func SerializeVector(vec []float32) []byte {
	buf := make([]byte, len(vec)*4)
	for i, v := range vec {
		binary.LittleEndian.PutUint32(buf[i*4:], math.Float32bits(v))
	}
	return buf
}

// DeserializeVector decodes a blob produced by SerializeVector back into
// []float32. Trailing bytes that do not form a whole float32 are ignored.
func DeserializeVector(data []byte) []float32 {
	vec := make([]float32, len(data)/4)
	for i := range vec {
		vec[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[i*4:]))
	}
	return vec
}

// mathFloat32bits returns the IEEE-754 bit pattern of f. It uses
// math.Float32bits rather than reading f's in-memory bytes as little-endian,
// which would byte-swap the result on big-endian hosts.
func mathFloat32bits(f float32) uint32 {
	return math.Float32bits(f)
}

// mathBitsToFloat32 returns the float32 whose IEEE-754 bit pattern is b.
func mathBitsToFloat32(b uint32) float32 {
	return math.Float32frombits(b)
}

// WriteVector inserts row into the table matching len(row.Embeddings).
// It returns an error if row is nil, if no table exists for that dimension,
// or if the insert fails (insert failures are also logged).
func (vs *VectorStorage) WriteVector(row *models.VectorRow) error {
	if row == nil {
		return fmt.Errorf("cannot write nil vector row")
	}
	tableName, err := vs.getTableName(row.Embeddings)
	if err != nil {
		return err
	}
	serializedEmbeddings := SerializeVector(row.Embeddings)
	// tableName comes from the fixed set in getTableName, never from user
	// input, so formatting it into the statement is safe.
	query := fmt.Sprintf(
		"INSERT INTO %s (embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)",
		tableName,
	)
	if _, err := vs.sqlxDB.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName); err != nil {
		vs.logger.Error("failed to write vector", "error", err, "slug", row.Slug)
		return err
	}
	return nil
}

// getTableName returns the table that stores embeddings of len(emb)
// dimensions, or an error for unsupported sizes.
func (vs *VectorStorage) getTableName(emb []float32) (string, error) {
	switch len(emb) {
	case 384:
		return "embeddings_384", nil
	case 5120:
		return "embeddings_5120", nil
	default:
		return "", fmt.Errorf("no table for embedding size of %d", len(emb))
	}
}

// SearchClosest returns up to three stored vectors closest to query by
// cosine distance (1 - cosine similarity), ordered from closest to farthest,
// with Distance set on each result. Only the table matching len(query) is
// scanned. Rows that fail to scan are logged and skipped; an error raised
// while iterating the result set is returned. If nothing matches, it returns
// a nil slice and a nil error.
func (vs *VectorStorage) SearchClosest(query []float32) ([]models.VectorRow, error) {
	// topK is the number of nearest neighbours returned.
	const topK = 3

	tableName, err := vs.getTableName(query)
	if err != nil {
		return nil, err
	}

	// Every row is compared, but rows are streamed so only the current row and
	// the best topK candidates are held in memory at once.
	querySQL := fmt.Sprintf("SELECT embeddings, slug, raw_text, filename FROM %s", tableName)
	rows, err := vs.sqlxDB.Query(querySQL)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	type searchResult struct {
		vector   models.VectorRow
		distance float32
	}
	// topResults is kept sorted by ascending distance and never exceeds topK
	// entries; the extra slot of capacity lets insertion avoid reallocating.
	topResults := make([]searchResult, 0, topK+1)

	for rows.Next() {
		var (
			embeddingsBlob          []byte
			slug, rawText, fileName string
		)
		if err := rows.Scan(&embeddingsBlob, &slug, &rawText, &fileName); err != nil {
			vs.logger.Error("failed to scan row", "error", err)
			continue
		}
		storedEmbeddings := DeserializeVector(embeddingsBlob)
		// Cosine similarity lies in [-1, 1]; 1 - similarity turns it into a
		// distance where 0 is most similar.
		distance := 1 - cosineSimilarity(query, storedEmbeddings)

		// Find the insertion point that keeps topResults sorted; equal
		// distances keep their arrival order.
		pos := len(topResults)
		for pos > 0 && topResults[pos-1].distance > distance {
			pos--
		}
		if pos >= topK {
			continue // farther than every candidate already kept
		}
		topResults = append(topResults, searchResult{})
		copy(topResults[pos+1:], topResults[pos:])
		topResults[pos] = searchResult{
			vector: models.VectorRow{
				Embeddings: storedEmbeddings,
				Slug:       slug,
				RawText:    rawText,
				FileName:   fileName,
			},
			distance: distance,
		}
		if len(topResults) > topK {
			// Clear the evicted slot so its vector can be garbage collected.
			topResults[topK] = searchResult{}
			topResults = topResults[:topK]
		}
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterating %s: %w", tableName, err)
	}

	var results []models.VectorRow
	for _, r := range topResults {
		r.vector.Distance = r.distance
		results = append(results, r.vector)
	}
	return results, nil
}

// ListFiles returns the distinct filenames stored across all embedding
// tables, in first-seen order. It is best effort: a table that cannot be
// queried or read is logged and skipped, so the returned error is always nil.
func (vs *VectorStorage) ListFiles() ([]string, error) {
	seen := make(map[string]struct{})
	var allFiles []string
	for _, table := range []string{"embeddings_384", "embeddings_5120"} {
		query := fmt.Sprintf("SELECT DISTINCT filename FROM %s", table)
		rows, err := vs.sqlxDB.Query(query)
		if err != nil {
			// Keep going so files in the other table are still listed.
			vs.logger.Warn("failed to list files", "table", table, "error", err)
			continue
		}
		for rows.Next() {
			var filename string
			if err := rows.Scan(&filename); err != nil {
				vs.logger.Warn("failed to scan filename", "table", table, "error", err)
				continue
			}
			if _, ok := seen[filename]; ok {
				continue
			}
			seen[filename] = struct{}{}
			allFiles = append(allFiles, filename)
		}
		if err := rows.Err(); err != nil {
			vs.logger.Warn("failed to read filenames", "table", table, "error", err)
		}
		rows.Close()
	}
	return allFiles, nil
}

// RemoveEmbByFileName deletes every embedding with the given filename from
// all embedding tables. Each table is attempted even if an earlier delete
// fails; all failures are combined into the returned error.
func (vs *VectorStorage) RemoveEmbByFileName(filename string) error {
	var failures []string
	for _, table := range []string{"embeddings_384", "embeddings_5120"} {
		query := fmt.Sprintf("DELETE FROM %s WHERE filename = ?", table)
		if _, err := vs.sqlxDB.Exec(query, filename); err != nil {
			failures = append(failures, err.Error())
		}
	}
	if len(failures) > 0 {
		return fmt.Errorf("errors occurred: %s", strings.Join(failures, "; "))
	}
	return nil
}

// cosineSimilarity returns the cosine similarity of a and b, in [-1, 1]
// where 1 is most similar. It returns 0 when the lengths differ or either
// vector has zero magnitude. Sums are accumulated in float64 to limit
// rounding error over long (e.g. 5120-dim) vectors.
func cosineSimilarity(a, b []float32) float32 {
	if len(a) != len(b) {
		return 0.0
	}
	var dotProduct, normA, normB float64
	for i := range a {
		x, y := float64(a[i]), float64(b[i])
		dotProduct += x * y
		normA += x * x
		normB += y * y
	}
	if normA == 0 || normB == 0 {
		return 0.0
	}
	return float32(dotProduct / (math.Sqrt(normA) * math.Sqrt(normB)))
}

// sqrt returns the square root of f (NaN for negative f). It delegates to
// math.Sqrt, which is accurate at every magnitude; a fixed-iteration Newton
// loop seeded with f/2 fails to converge for very large or very small f.
func sqrt(f float32) float32 {
	return float32(math.Sqrt(float64(f)))
}