diff options
| author | Grail Finder <wohilas@gmail.com> | 2025-11-19 12:32:46 +0300 |
|---|---|---|
| committer | Grail Finder <wohilas@gmail.com> | 2025-11-19 12:32:46 +0300 |
| commit | 25b2e2f592bd8df9a5cbd3c77322b572eb8f829c (patch) | |
| tree | a49630b91762e19a28dd500941e0b3f31cc9747c /storage | |
| parent | 88b45f04b73fa408a9d7565c604a59c307bf9652 (diff) | |
Fix: migration use of vec0; rag cleanup
Diffstat (limited to 'storage')
| -rw-r--r-- | storage/migrations/002_add_vector.down.sql | 10 | ||||
| -rw-r--r-- | storage/migrations/002_add_vector.up.sql | 32 | ||||
| -rw-r--r-- | storage/storage_test.go | 87 | ||||
| -rw-r--r-- | storage/vector.go | 185 | ||||
| -rw-r--r-- | storage/vector.go.bak | 179 |
5 files changed, 189 insertions, 304 deletions
diff --git a/storage/migrations/002_add_vector.down.sql b/storage/migrations/002_add_vector.down.sql new file mode 100644 index 0000000..71c1f51 --- /dev/null +++ b/storage/migrations/002_add_vector.down.sql @@ -0,0 +1,10 @@ +-- Drop vector storage tables +DROP INDEX IF EXISTS idx_embeddings_384_filename; +DROP INDEX IF EXISTS idx_embeddings_5120_filename; +DROP INDEX IF EXISTS idx_embeddings_384_slug; +DROP INDEX IF EXISTS idx_embeddings_5120_slug; +DROP INDEX IF EXISTS idx_embeddings_384_created_at; +DROP INDEX IF EXISTS idx_embeddings_5120_created_at; + +DROP TABLE IF EXISTS embeddings_384; +DROP TABLE IF EXISTS embeddings_5120;
\ No newline at end of file diff --git a/storage/migrations/002_add_vector.up.sql b/storage/migrations/002_add_vector.up.sql index 2ac4621..6e164ce 100644 --- a/storage/migrations/002_add_vector.up.sql +++ b/storage/migrations/002_add_vector.up.sql @@ -1,12 +1,26 @@ ---CREATE VIRTUAL TABLE IF NOT EXISTS embeddings_5120 USING vec0( --- embedding FLOAT[5120], --- slug TEXT NOT NULL, --- raw_text TEXT PRIMARY KEY, ---); +-- Create tables for vector storage (replacing vec0 plugin usage) +CREATE TABLE IF NOT EXISTS embeddings_384 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); -CREATE VIRTUAL TABLE IF NOT EXISTS embeddings_384 USING vec0( - embedding FLOAT[384], +CREATE TABLE IF NOT EXISTS embeddings_5120 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, slug TEXT NOT NULL, - raw_text TEXT PRIMARY KEY, - filename TEXT NOT NULL DEFAULT '' + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ); + +-- Indexes for better performance +CREATE INDEX IF NOT EXISTS idx_embeddings_384_filename ON embeddings_384(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_5120_filename ON embeddings_5120(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_384_slug ON embeddings_384(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_5120_slug ON embeddings_5120(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_384_created_at ON embeddings_384(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_5120_created_at ON embeddings_5120(created_at); diff --git a/storage/storage_test.go b/storage/storage_test.go index a1c4cf4..a4f2bdd 100644 --- a/storage/storage_test.go +++ b/storage/storage_test.go @@ -1,8 +1,8 @@ package storage import ( - "gf-lt/models" "fmt" + "gf-lt/models" "log/slog" "os" "testing" @@ -173,88 +173,3 @@ func TestChatHistory(t *testing.T) { t.Errorf("Expected 0 chats, got %d", len(chats)) } } - -// func TestVecTable(t *testing.T) { -// // healthcheck -// db, err := sqlite3.Open(":memory:") -// if err != nil { -// t.Fatal(err) -// } -// stmt, _, err := db.Prepare(`SELECT sqlite_version(), vec_version()`) -// if err != nil { -// t.Fatal(err) -// } -// stmt.Step() -// log.Printf("sqlite_version=%s, vec_version=%s\n", stmt.ColumnText(0), stmt.ColumnText(1)) -// stmt.Close() -// // migration -// err = db.Exec("CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[4], chat_name TEXT NOT NULL)") -// if err != nil { -// t.Fatal(err) -// } -// // data prep and insert -// items := map[int][]float32{ -// 1: {0.1, 0.1, 0.1, 0.1}, -// 2: {0.2, 0.2, 0.2, 0.2}, -// 3: {0.3, 0.3, 0.3, 0.3}, -// 4: {0.4, 0.4, 0.4, 0.4}, -// 5: {0.5, 0.5, 0.5, 0.5}, -// } -// q := []float32{0.4, 0.3, 0.3, 0.3} -// stmt, _, err = db.Prepare("INSERT INTO vec_items(rowid, embedding, chat_name) VALUES (?, ?, ?)") -// if err != nil { -// t.Fatal(err) -// } -// for id, values := range items { -// v, err := sqlite_vec.SerializeFloat32(values) -// if err != nil { -// t.Fatal(err) -// } -// stmt.BindInt(1, id) -// stmt.BindBlob(2, v) -// stmt.BindText(3, "some_chat") -// err = stmt.Exec() -// if err != nil { -// t.Fatal(err) -// } -// stmt.Reset() -// } -// stmt.Close() -// // select | vec search -// stmt, _, err = db.Prepare(` -// SELECT -// rowid, -// distance, -// embedding -// FROM vec_items -// WHERE embedding MATCH ? -// ORDER BY distance -// LIMIT 3 -// `) -// if err != nil { -// t.Fatal(err) -// } -// query, err := sqlite_vec.SerializeFloat32(q) -// if err != nil { -// t.Fatal(err) -// } -// stmt.BindBlob(1, query) -// for stmt.Step() { -// rowid := stmt.ColumnInt64(0) -// distance := stmt.ColumnFloat(1) -// emb := stmt.ColumnRawText(2) -// floats := decodeUnsafe(emb) -// log.Printf("rowid=%d, distance=%f, floats=%v\n", rowid, distance, floats) -// } -// if err := stmt.Err(); err != nil { -// t.Fatal(err) -// } -// err = stmt.Close() -// if err != nil { -// t.Fatal(err) -// } -// err = db.Close() -// if err != nil { -// t.Fatal(err) -// } -// } diff --git a/storage/vector.go b/storage/vector.go index 900803c..73bfe29 100644 --- a/storage/vector.go +++ b/storage/vector.go @@ -1,9 +1,9 @@ package storage import ( - "gf-lt/models" "encoding/binary" "fmt" + "gf-lt/models" "unsafe" "github.com/jmoiron/sqlx" @@ -26,7 +26,7 @@ func SerializeVector(vec []float32) []byte { return buf } -// DeserializeVector converts binary blob back to []float32 +// DeserializeVector converts binary blob back to []float32 func DeserializeVector(data []byte) []float32 { count := len(data) / 4 vec := make([]float32, count) @@ -66,50 +66,175 @@ func (p ProviderSQL) WriteVector(row *models.VectorRow) error { if err != nil { return err } - + serializedEmbeddings := SerializeVector(row.Embeddings) - - query := fmt.Sprintf("INSERT INTO %s(embedding, slug, raw_text, filename) VALUES (?, ?, ?, ?)", tableName) + + query := fmt.Sprintf("INSERT INTO %s(embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)", tableName) _, err = p.db.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName) - + return err } - - func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) { - // TODO: This function has been temporarily disabled to avoid deprecated library usage. - // In the new RAG implementation, this functionality is now in rag_new package. - // For compatibility, return empty result instead of using deprecated vector extension. - return []models.VectorRow{}, nil -} + tableName, err := fetchTableName(q) + if err != nil { + return nil, err + } -func (p ProviderSQL) ListFiles() ([]string, error) { - q := fmt.Sprintf("SELECT filename FROM %s GROUP BY filename", vecTableName384) - rows, err := p.db.Query(q) + querySQL := fmt.Sprintf("SELECT embedding, slug, raw_text, filename FROM %s", tableName) + rows, err := p.db.Query(querySQL) if err != nil { return nil, err } defer rows.Close() - - resp := []string{} + + type SearchResult struct { + vector models.VectorRow + distance float32 + } + + var topResults []SearchResult + for rows.Next() { - var filename string - if err := rows.Scan(&filename); err != nil { - return nil, err + var ( + embeddingsBlob []byte + slug, rawText, fileName string + ) + + if err := rows.Scan(&embeddingsBlob, &slug, &rawText, &fileName); err != nil { + continue + } + + storedEmbeddings := DeserializeVector(embeddingsBlob) + + // Calculate cosine similarity (returns value between -1 and 1, where 1 is most similar) + similarity := cosineSimilarity(q, storedEmbeddings) + distance := 1 - similarity // Convert to distance where 0 is most similar + + result := SearchResult{ + vector: models.VectorRow{ + Embeddings: storedEmbeddings, + Slug: slug, + RawText: rawText, + FileName: fileName, + }, + distance: distance, + } + + // Add to top results and maintain only top results + topResults = append(topResults, result) + + // Sort and keep only top results + // We'll keep the top 3 closest vectors + if len(topResults) > 3 { + // Simple sort and truncate to maintain only 3 best matches + for i := 0; i < len(topResults); i++ { + for j := i + 1; j < len(topResults); j++ { + if topResults[i].distance > topResults[j].distance { + topResults[i], topResults[j] = topResults[j], topResults[i] + } + } + } + topResults = topResults[:3] } - resp = append(resp, filename) } - - if err := rows.Err(); err != nil { - return nil, err + + // Convert back to VectorRow slice + results := make([]models.VectorRow, len(topResults)) + for i, result := range topResults { + result.vector.Distance = result.distance + results[i] = result.vector } - - return resp, nil + + return results, nil +} + +// cosineSimilarity calculates the cosine similarity between two vectors +func cosineSimilarity(a, b []float32) float32 { + if len(a) != len(b) { + return 0.0 + } + + var dotProduct, normA, normB float32 + for i := 0; i < len(a); i++ { + dotProduct += a[i] * b[i] + normA += a[i] * a[i] + normB += b[i] * b[i] + } + + if normA == 0 || normB == 0 { + return 0.0 + } + + return dotProduct / (sqrt(normA) * sqrt(normB)) +} + +// sqrt returns the square root of a float32 +func sqrt(f float32) float32 { + // A simple implementation of square root using Newton's method + if f == 0 { + return 0 + } + guess := f / 2 + for i := 0; i < 10; i++ { // 10 iterations should be enough for good precision + guess = (guess + f/guess) / 2 + } + return guess +} + +func (p ProviderSQL) ListFiles() ([]string, error) { + fileLists := make([][]string, 0) + + // Query both tables and combine results + for _, table := range []string{vecTableName384, vecTableName5120} { + query := fmt.Sprintf("SELECT DISTINCT filename FROM %s", table) + rows, err := p.db.Query(query) + if err != nil { + // Continue if one table doesn't exist + continue + } + + var files []string + for rows.Next() { + var filename string + if err := rows.Scan(&filename); err != nil { + continue + } + files = append(files, filename) + } + rows.Close() + + fileLists = append(fileLists, files) + } + + // Combine and deduplicate + fileSet := make(map[string]bool) + var allFiles []string + for _, files := range fileLists { + for _, file := range files { + if !fileSet[file] { + fileSet[file] = true + allFiles = append(allFiles, file) + } + } + } + + return allFiles, nil } func (p ProviderSQL) RemoveEmbByFileName(filename string) error { - q := fmt.Sprintf("DELETE FROM %s WHERE filename = ?", vecTableName384) - _, err := p.db.Exec(q, filename) - return err + var errors []string + + for _, table := range []string{vecTableName384, vecTableName5120} { + query := fmt.Sprintf("DELETE FROM %s WHERE filename = ?", table) + if _, err := p.db.Exec(query, filename); err != nil { + errors = append(errors, err.Error()) + } + } + + if len(errors) > 0 { + return fmt.Errorf("errors occurred: %s", fmt.Sprintf("%v", errors)) + } + + return nil } diff --git a/storage/vector.go.bak b/storage/vector.go.bak deleted file mode 100644 index f663beb..0000000 --- a/storage/vector.go.bak +++ /dev/null @@ -1,179 +0,0 @@ -package storage - -import ( - "gf-lt/models" - "encoding/binary" - "fmt" - "sort" - "unsafe" -) - -type VectorRepo interface { - WriteVector(*models.VectorRow) error - SearchClosest(q []float32) ([]models.VectorRow, error) - ListFiles() ([]string, error) - RemoveEmbByFileName(filename string) error -} - -// SerializeVector converts []float32 to binary blob -func SerializeVector(vec []float32) []byte { - buf := make([]byte, len(vec)*4) // 4 bytes per float32 - for i, v := range vec { - binary.LittleEndian.PutUint32(buf[i*4:], mathFloat32bits(v)) - } - return buf -} - -// DeserializeVector converts binary blob back to []float32 -func DeserializeVector(data []byte) []float32 { - count := len(data) / 4 - vec := make([]float32, count) - for i := 0; i < count; i++ { - vec[i] = mathBitsToFloat32(binary.LittleEndian.Uint32(data[i*4:])) - } - return vec -} - -// mathFloat32bits and mathBitsToFloat32 are helpers to convert between float32 and uint32 -func mathFloat32bits(f float32) uint32 { - return binary.LittleEndian.Uint32((*(*[4]byte)(unsafe.Pointer(&f)))[:4]) -} - -func mathBitsToFloat32(b uint32) float32 { - return *(*float32)(unsafe.Pointer(&b)) -} - -var ( - vecTableName5120 = "embeddings_5120" - vecTableName384 = "embeddings_384" -) - -func fetchTableName(emb []float32) (string, error) { - switch len(emb) { - case 5120: - return vecTableName5120, nil - case 384: - return vecTableName384, nil - default: - return "", fmt.Errorf("no table for the size of %d", len(emb)) - } -} - -func (p ProviderSQL) WriteVector(row *models.VectorRow) error { - tableName, err := fetchTableName(row.Embeddings) - if err != nil { - return err - } - stmt, _, err := p.s3Conn.Prepare( - fmt.Sprintf("INSERT INTO %s(embedding, slug, raw_text, filename) VALUES (?, ?, ?, ?)", tableName)) - if err != nil { - p.logger.Error("failed to prep a stmt", "error", err) - return err - } - defer stmt.Close() - serializedEmbeddings := SerializeVector(row.Embeddings) - if err := stmt.BindBlob(1, serializedEmbeddings); err != nil { - p.logger.Error("failed to bind", "error", err) - return err - } - if err := stmt.BindText(2, row.Slug); err != nil { - p.logger.Error("failed to bind", "error", err) - return err - } - if err := stmt.BindText(3, row.RawText); err != nil { - p.logger.Error("failed to bind", "error", err) - return err - } - if err := stmt.BindText(4, row.FileName); err != nil { - p.logger.Error("failed to bind", "error", err) - return err - } - err = stmt.Exec() - if err != nil { - return err - } - return nil -} - -func decodeUnsafe(bs []byte) []float32 { - return unsafe.Slice((*float32)(unsafe.Pointer(&bs[0])), len(bs)/4) -} - -func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) { - tableName, err := fetchTableName(q) - if err != nil { - return nil, err - } - stmt, _, err := p.s3Conn.Prepare( - fmt.Sprintf(`SELECT - distance, - embedding, - slug, - raw_text, - filename - FROM %s - WHERE embedding MATCH ? - ORDER BY distance - LIMIT 3 - `, tableName)) - if err != nil { - return nil, err - } - // This function needs to be completely rewritten to use the new binary storage approach - if err != nil { - return nil, err - } - if err := stmt.BindBlob(1, query); err != nil { - p.logger.Error("failed to bind", "error", err) - return nil, err - } - resp := []models.VectorRow{} - for stmt.Step() { - res := models.VectorRow{} - res.Distance = float32(stmt.ColumnFloat(0)) - emb := stmt.ColumnRawText(1) - res.Embeddings = decodeUnsafe(emb) - res.Slug = stmt.ColumnText(2) - res.RawText = stmt.ColumnText(3) - res.FileName = stmt.ColumnText(4) - resp = append(resp, res) - } - if err := stmt.Err(); err != nil { - return nil, err - } - err = stmt.Close() - if err != nil { - return nil, err - } - return resp, nil -} - -func (p ProviderSQL) ListFiles() ([]string, error) { - q := fmt.Sprintf("SELECT filename FROM %s GROUP BY filename", vecTableName384) - stmt, _, err := p.s3Conn.Prepare(q) - if err != nil { - return nil, err - } - defer stmt.Close() - resp := []string{} - for stmt.Step() { - resp = append(resp, stmt.ColumnText(0)) - } - if err := stmt.Err(); err != nil { - return nil, err - } - return resp, nil -} - -func (p ProviderSQL) RemoveEmbByFileName(filename string) error { - q := fmt.Sprintf("DELETE FROM %s WHERE filename = ?", vecTableName384) - stmt, _, err := p.s3Conn.Prepare(q) - if err != nil { - return err - } - defer stmt.Close() - if err := stmt.BindText(1, filename); err != nil { - return err - } - return stmt.Exec() -} |
