diff options
Diffstat (limited to 'storage')
| -rw-r--r-- | storage/memory.go | 2 | ||||
| -rw-r--r-- | storage/migrate.go | 24 | ||||
| -rw-r--r-- | storage/migrations/002_add_vector.down.sql | 34 | ||||
| -rw-r--r-- | storage/migrations/002_add_vector.up.sql | 104 | ||||
| -rw-r--r-- | storage/migrations/003_add_fts.down.sql | 2 | ||||
| -rw-r--r-- | storage/migrations/003_add_fts.up.sql | 15 | ||||
| -rw-r--r-- | storage/migrations/004_populate_fts.down.sql | 2 | ||||
| -rw-r--r-- | storage/migrations/004_populate_fts.up.sql | 4 | ||||
| -rw-r--r-- | storage/migrations/005_drop_unused_embeddings.down.sql | 87 | ||||
| -rw-r--r-- | storage/migrations/005_drop_unused_embeddings.up.sql | 32 | ||||
| -rw-r--r-- | storage/storage.go | 33 | ||||
| -rw-r--r-- | storage/storage_test.go | 90 | ||||
| -rw-r--r-- | storage/vector.go | 241 |
13 files changed, 445 insertions, 225 deletions
diff --git a/storage/memory.go b/storage/memory.go index c9fc853..406182f 100644 --- a/storage/memory.go +++ b/storage/memory.go @@ -1,6 +1,6 @@ package storage -import "elefant/models" +import "gf-lt/models" type Memories interface { Memorise(m *models.Memory) (*models.Memory, error) diff --git a/storage/migrate.go b/storage/migrate.go index b05dddc..b6fed37 100644 --- a/storage/migrate.go +++ b/storage/migrate.go @@ -5,35 +5,47 @@ import ( "fmt" "io/fs" "strings" - - _ "github.com/asg017/sqlite-vec-go-bindings/ncruces" ) //go:embed migrations/* var migrationsFS embed.FS -func (p *ProviderSQL) Migrate() { +func (p *ProviderSQL) Migrate() error { // Get the embedded filesystem migrationsDir, err := fs.Sub(migrationsFS, "migrations") if err != nil { p.logger.Error("Failed to get embedded migrations directory;", "error", err) + return fmt.Errorf("failed to get embedded migrations directory: %w", err) } // List all .up.sql files files, err := migrationsFS.ReadDir("migrations") if err != nil { p.logger.Error("Failed to read migrations directory;", "error", err) + return fmt.Errorf("failed to read migrations directory: %w", err) } + + // Check if FTS already has data - skip populate migration if so + var ftsCount int + _ = p.db.QueryRow("SELECT COUNT(*) FROM fts_embeddings").Scan(&ftsCount) + skipFTSMigration := ftsCount > 0 + // Execute each .up.sql file for _, file := range files { if strings.HasSuffix(file.Name(), ".up.sql") { + // Skip FTS populate migration if already populated + if skipFTSMigration && strings.Contains(file.Name(), "004_populate_fts") { + p.logger.Debug("Skipping FTS migration - already populated", "file", file.Name()) + continue + } err := p.executeMigration(migrationsDir, file.Name()) if err != nil { p.logger.Error("Failed to execute migration %s: %v", file.Name(), err) - panic(err) + return fmt.Errorf("failed to execute migration %s: %w", file.Name(), err) } } } p.logger.Debug("All migrations executed successfully!") + return nil } func (p *ProviderSQL) executeMigration(migrationsDir fs.FS, fileName string) error { @@ -53,8 +65,8 @@ func (p *ProviderSQL) executeMigration(migrationsDir fs.FS, fileName string) err } func (p *ProviderSQL) executeSQL(sqlContent []byte) error { - // Connect to the database (example using a simple connection) - err := p.s3Conn.Exec(string(sqlContent)) + // Execute the migration content using standard database connection + _, err := p.db.Exec(string(sqlContent)) if err != nil { return fmt.Errorf("failed to execute SQL: %w", err) } diff --git a/storage/migrations/002_add_vector.down.sql b/storage/migrations/002_add_vector.down.sql new file mode 100644 index 0000000..a257b11 --- /dev/null +++ b/storage/migrations/002_add_vector.down.sql @@ -0,0 +1,34 @@ +-- Drop vector storage tables +DROP INDEX IF EXISTS idx_embeddings_384_filename; +DROP INDEX IF EXISTS idx_embeddings_768_filename; +DROP INDEX IF EXISTS idx_embeddings_1024_filename; +DROP INDEX IF EXISTS idx_embeddings_1536_filename; +DROP INDEX IF EXISTS idx_embeddings_2048_filename; +DROP INDEX IF EXISTS idx_embeddings_3072_filename; +DROP INDEX IF EXISTS idx_embeddings_4096_filename; +DROP INDEX IF EXISTS idx_embeddings_5120_filename; +DROP INDEX IF EXISTS idx_embeddings_384_slug; +DROP INDEX IF EXISTS idx_embeddings_768_slug; +DROP INDEX IF EXISTS idx_embeddings_1024_slug; +DROP INDEX IF EXISTS idx_embeddings_1536_slug; +DROP INDEX IF EXISTS idx_embeddings_2048_slug; +DROP INDEX IF EXISTS idx_embeddings_3072_slug; +DROP INDEX IF EXISTS idx_embeddings_4096_slug; +DROP INDEX IF EXISTS idx_embeddings_5120_slug; +DROP INDEX IF EXISTS idx_embeddings_384_created_at; +DROP INDEX IF EXISTS idx_embeddings_768_created_at; +DROP INDEX IF EXISTS idx_embeddings_1024_created_at; +DROP INDEX IF EXISTS idx_embeddings_1536_created_at; +DROP INDEX IF EXISTS idx_embeddings_2048_created_at; +DROP INDEX IF EXISTS idx_embeddings_3072_created_at; +DROP INDEX IF EXISTS idx_embeddings_4096_created_at; +DROP INDEX IF EXISTS idx_embeddings_5120_created_at; + +DROP TABLE IF EXISTS embeddings_384; +DROP TABLE IF EXISTS embeddings_768; +DROP TABLE IF EXISTS embeddings_1024; +DROP TABLE IF EXISTS embeddings_1536; +DROP TABLE IF EXISTS embeddings_2048; +DROP TABLE IF EXISTS embeddings_3072; +DROP TABLE IF EXISTS embeddings_4096; +DROP TABLE IF EXISTS embeddings_5120;
\ No newline at end of file diff --git a/storage/migrations/002_add_vector.up.sql b/storage/migrations/002_add_vector.up.sql index 2ac4621..baf703d 100644 --- a/storage/migrations/002_add_vector.up.sql +++ b/storage/migrations/002_add_vector.up.sql @@ -1,12 +1,98 @@ ---CREATE VIRTUAL TABLE IF NOT EXISTS embeddings_5120 USING vec0( --- embedding FLOAT[5120], --- slug TEXT NOT NULL, --- raw_text TEXT PRIMARY KEY, ---); +-- Create tables for vector storage (replacing vec0 plugin usage) +CREATE TABLE IF NOT EXISTS embeddings_384 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_768 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_1024 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_1536 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); -CREATE VIRTUAL TABLE IF NOT EXISTS embeddings_384 USING vec0( - embedding FLOAT[384], +CREATE TABLE IF NOT EXISTS embeddings_2048 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, slug TEXT NOT NULL, - raw_text TEXT PRIMARY KEY, - filename TEXT NOT NULL DEFAULT '' + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ); + +CREATE TABLE IF NOT EXISTS embeddings_3072 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_4096 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_5120 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +-- Indexes for better performance +CREATE INDEX IF NOT EXISTS idx_embeddings_384_filename ON embeddings_384(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_768_filename ON embeddings_768(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_1024_filename ON embeddings_1024(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_1536_filename ON embeddings_1536(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_2048_filename ON embeddings_2048(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_3072_filename ON embeddings_3072(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_4096_filename ON embeddings_4096(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_5120_filename ON embeddings_5120(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_384_slug ON embeddings_384(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_768_slug ON embeddings_768(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_1024_slug ON embeddings_1024(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_1536_slug ON embeddings_1536(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_2048_slug ON embeddings_2048(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_3072_slug ON embeddings_3072(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_4096_slug ON embeddings_4096(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_5120_slug ON embeddings_5120(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_384_created_at ON embeddings_384(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_768_created_at ON embeddings_768(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_1024_created_at ON embeddings_1024(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_1536_created_at ON embeddings_1536(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_2048_created_at ON embeddings_2048(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_3072_created_at ON embeddings_3072(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_4096_created_at ON embeddings_4096(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_5120_created_at ON embeddings_5120(created_at); diff --git a/storage/migrations/003_add_fts.down.sql b/storage/migrations/003_add_fts.down.sql new file mode 100644 index 0000000..e565fd5 --- /dev/null +++ b/storage/migrations/003_add_fts.down.sql @@ -0,0 +1,2 @@ +-- Drop FTS5 virtual table +DROP TABLE IF EXISTS fts_embeddings;
\ No newline at end of file diff --git a/storage/migrations/003_add_fts.up.sql b/storage/migrations/003_add_fts.up.sql new file mode 100644 index 0000000..114586a --- /dev/null +++ b/storage/migrations/003_add_fts.up.sql @@ -0,0 +1,15 @@ +-- Create FTS5 virtual table for full-text search +CREATE VIRTUAL TABLE IF NOT EXISTS fts_embeddings USING fts5( + slug UNINDEXED, + raw_text, + filename UNINDEXED, + embedding_size UNINDEXED, + tokenize='porter unicode61' -- Use porter stemmer and unicode61 tokenizer +); + +-- Create triggers to maintain FTS table when embeddings are inserted/deleted +-- Note: We'll handle inserts/deletes programmatically for simplicity +-- but triggers could be added here if needed. + +-- Indexes for performance (FTS5 manages its own indexes) +-- No additional indexes needed for FTS5 virtual table.
\ No newline at end of file diff --git a/storage/migrations/004_populate_fts.down.sql b/storage/migrations/004_populate_fts.down.sql new file mode 100644 index 0000000..2b5c756 --- /dev/null +++ b/storage/migrations/004_populate_fts.down.sql @@ -0,0 +1,2 @@ +-- Clear FTS table (optional) +DELETE FROM fts_embeddings;
\ No newline at end of file diff --git a/storage/migrations/004_populate_fts.up.sql b/storage/migrations/004_populate_fts.up.sql new file mode 100644 index 0000000..1068bf7 --- /dev/null +++ b/storage/migrations/004_populate_fts.up.sql @@ -0,0 +1,4 @@ +-- Populate FTS table with existing embeddings (incremental - only inserts missing rows) +-- Only use 768 embeddings as that's what we use +INSERT OR IGNORE INTO fts_embeddings (slug, raw_text, filename, embedding_size) +SELECT slug, raw_text, filename, 768 FROM embeddings_768;
\ No newline at end of file diff --git a/storage/migrations/005_drop_unused_embeddings.down.sql b/storage/migrations/005_drop_unused_embeddings.down.sql new file mode 100644 index 0000000..063cb88 --- /dev/null +++ b/storage/migrations/005_drop_unused_embeddings.down.sql @@ -0,0 +1,87 @@ +-- Recreate unused embedding tables (for rollback) +CREATE TABLE IF NOT EXISTS embeddings_384 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_1024 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_1536 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_2048 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_3072 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_4096 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_5120 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE INDEX IF NOT EXISTS idx_embeddings_384_filename ON embeddings_384(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_1024_filename ON embeddings_1024(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_1536_filename ON embeddings_1536(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_2048_filename ON embeddings_2048(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_3072_filename ON embeddings_3072(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_4096_filename ON embeddings_4096(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_5120_filename ON embeddings_5120(filename); + +CREATE INDEX IF NOT EXISTS idx_embeddings_384_slug ON embeddings_384(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_1024_slug ON embeddings_1024(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_1536_slug ON embeddings_1536(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_2048_slug ON embeddings_2048(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_3072_slug ON embeddings_3072(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_4096_slug ON embeddings_4096(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_5120_slug ON embeddings_5120(slug); + +CREATE INDEX IF NOT EXISTS idx_embeddings_384_created_at ON embeddings_384(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_1024_created_at ON embeddings_1024(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_1536_created_at ON embeddings_1536(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_2048_created_at ON embeddings_2048(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_3072_created_at ON embeddings_3072(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_4096_created_at ON embeddings_4096(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_5120_created_at ON embeddings_5120(created_at); diff --git a/storage/migrations/005_drop_unused_embeddings.up.sql b/storage/migrations/005_drop_unused_embeddings.up.sql new file mode 100644 index 0000000..f26e30f --- /dev/null +++ b/storage/migrations/005_drop_unused_embeddings.up.sql @@ -0,0 +1,32 @@ +-- Drop unused embedding tables (we only use 768) +DROP INDEX IF EXISTS idx_embeddings_384_filename; +DROP INDEX IF EXISTS idx_embeddings_1024_filename; +DROP INDEX IF EXISTS idx_embeddings_1536_filename; +DROP INDEX IF EXISTS idx_embeddings_2048_filename; +DROP INDEX IF EXISTS idx_embeddings_3072_filename; +DROP INDEX IF EXISTS idx_embeddings_4096_filename; +DROP INDEX IF EXISTS idx_embeddings_5120_filename; + +DROP INDEX IF EXISTS idx_embeddings_384_slug; +DROP INDEX IF EXISTS idx_embeddings_1024_slug; +DROP INDEX IF EXISTS idx_embeddings_1536_slug; +DROP INDEX IF EXISTS idx_embeddings_2048_slug; +DROP INDEX IF EXISTS idx_embeddings_3072_slug; +DROP INDEX IF EXISTS idx_embeddings_4096_slug; +DROP INDEX IF EXISTS idx_embeddings_5120_slug; + +DROP INDEX IF EXISTS idx_embeddings_384_created_at; +DROP INDEX IF EXISTS idx_embeddings_1024_created_at; +DROP INDEX IF EXISTS idx_embeddings_1536_created_at; +DROP INDEX IF EXISTS idx_embeddings_2048_created_at; +DROP INDEX IF EXISTS idx_embeddings_3072_created_at; +DROP INDEX IF EXISTS idx_embeddings_4096_created_at; +DROP INDEX IF EXISTS idx_embeddings_5120_created_at; + +DROP TABLE IF EXISTS embeddings_384; +DROP TABLE IF EXISTS embeddings_1024; +DROP TABLE IF EXISTS embeddings_1536; +DROP TABLE IF EXISTS embeddings_2048; +DROP TABLE IF EXISTS embeddings_3072; +DROP TABLE IF EXISTS embeddings_4096; +DROP TABLE IF EXISTS embeddings_5120; diff --git a/storage/storage.go b/storage/storage.go index f759700..57631da 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -1,12 +1,11 @@ package storage import ( - "elefant/models" + "gf-lt/models" "log/slog" _ "github.com/glebarez/go-sqlite" "github.com/jmoiron/sqlx" - "github.com/ncruces/go-sqlite3" ) type FullRepo interface { @@ -28,7 +27,6 @@ type ChatHistory interface { type ProviderSQL struct { db *sqlx.DB - s3Conn *sqlite3.Conn logger *slog.Logger } @@ -97,19 +95,38 @@ func (p ProviderSQL) ChatGetMaxID() (uint32, error) { return id, err } -// opens two connections +// opens database connection func NewProviderSQL(dbPath string, logger *slog.Logger) FullRepo { db, err := sqlx.Open("sqlite", dbPath) if err != nil { logger.Error("failed to open db connection", "error", err) return nil } + // Enable WAL mode for better concurrency and performance + if _, err := db.Exec("PRAGMA journal_mode = WAL;"); err != nil { + logger.Warn("failed to enable WAL mode", "error", err) + } + if _, err := db.Exec("PRAGMA synchronous = NORMAL;"); err != nil { + logger.Warn("failed to set synchronous mode", "error", err) + } + // Increase cache size for better performance + if _, err := db.Exec("PRAGMA cache_size = -2000;"); err != nil { + logger.Warn("failed to set cache size", "error", err) + } + // Log actual journal mode for debugging + var journalMode string + if err := db.QueryRow("PRAGMA journal_mode;").Scan(&journalMode); err == nil { + logger.Debug("SQLite journal mode", "mode", journalMode) + } p := ProviderSQL{db: db, logger: logger} - p.s3Conn, err = sqlite3.Open(dbPath) - if err != nil { - logger.Error("failed to open vecdb connection", "error", err) + if err := p.Migrate(); err != nil { + logger.Error("migration failed, app cannot start", "error", err) return nil } - p.Migrate() return p } + +// DB returns the underlying database connection +func (p ProviderSQL) DB() *sqlx.DB { + return p.db +} diff --git a/storage/storage_test.go b/storage/storage_test.go index ff3b5e6..a4f2bdd 100644 --- a/storage/storage_test.go +++ b/storage/storage_test.go @@ -1,18 +1,15 @@ package storage import ( - "elefant/models" "fmt" - "log" + "gf-lt/models" "log/slog" "os" "testing" "time" - sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/ncruces" _ "github.com/glebarez/go-sqlite" "github.com/jmoiron/sqlx" - "github.com/ncruces/go-sqlite3" ) func TestMemories(t *testing.T) { @@ -176,88 +173,3 @@ func TestChatHistory(t *testing.T) { t.Errorf("Expected 0 chats, got %d", len(chats)) } } - -func TestVecTable(t *testing.T) { - // healthcheck - db, err := sqlite3.Open(":memory:") - if err != nil { - t.Fatal(err) - } - stmt, _, err := db.Prepare(`SELECT sqlite_version(), vec_version()`) - if err != nil { - t.Fatal(err) - } - stmt.Step() - log.Printf("sqlite_version=%s, vec_version=%s\n", stmt.ColumnText(0), stmt.ColumnText(1)) - stmt.Close() - // migration - err = db.Exec("CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[4], chat_name TEXT NOT NULL)") - if err != nil { - t.Fatal(err) - } - // data prep and insert - items := map[int][]float32{ - 1: {0.1, 0.1, 0.1, 0.1}, - 2: {0.2, 0.2, 0.2, 0.2}, - 3: {0.3, 0.3, 0.3, 0.3}, - 4: {0.4, 0.4, 0.4, 0.4}, - 5: {0.5, 0.5, 0.5, 0.5}, - } - q := []float32{0.28, 0.3, 0.3, 0.3} - stmt, _, err = db.Prepare("INSERT INTO vec_items(rowid, embedding, chat_name) VALUES (?, ?, ?)") - if err != nil { - t.Fatal(err) - } - for id, values := range items { - v, err := sqlite_vec.SerializeFloat32(values) - if err != nil { - t.Fatal(err) - } - stmt.BindInt(1, id) - stmt.BindBlob(2, v) - stmt.BindText(3, "some_chat") - err = stmt.Exec() - if err != nil { - t.Fatal(err) - } - stmt.Reset() - } - stmt.Close() - // select | vec search - stmt, _, err = db.Prepare(` - SELECT - rowid, - distance, - embedding - FROM vec_items - WHERE embedding MATCH ? - ORDER BY distance - LIMIT 3 - `) - if err != nil { - t.Fatal(err) - } - query, err := sqlite_vec.SerializeFloat32(q) - if err != nil { - t.Fatal(err) - } - stmt.BindBlob(1, query) - for stmt.Step() { - rowid := stmt.ColumnInt64(0) - distance := stmt.ColumnFloat(1) - emb := stmt.ColumnRawText(2) - floats := decodeUnsafe(emb) - log.Printf("rowid=%d, distance=%f, floats=%v\n", rowid, distance, floats) - } - if err := stmt.Err(); err != nil { - t.Fatal(err) - } - err = stmt.Close() - if err != nil { - t.Fatal(err) - } - err = db.Close() - if err != nil { - t.Fatal(err) - } -} diff --git a/storage/vector.go b/storage/vector.go index 5e9069c..fed78a9 100644 --- a/storage/vector.go +++ b/storage/vector.go @@ -1,32 +1,55 @@ package storage import ( - "elefant/models" - "errors" + "encoding/binary" "fmt" + "gf-lt/models" + "sort" "unsafe" - sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/ncruces" + "github.com/jmoiron/sqlx" ) type VectorRepo interface { WriteVector(*models.VectorRow) error - SearchClosest(q []float32) ([]models.VectorRow, error) + SearchClosest(q []float32, limit int) ([]models.VectorRow, error) ListFiles() ([]string, error) RemoveEmbByFileName(filename string) error + DB() *sqlx.DB } -var ( - vecTableName5120 = "embeddings_5120" - vecTableName384 = "embeddings_384" -) +// SerializeVector converts []float32 to binary blob +func SerializeVector(vec []float32) []byte { + buf := make([]byte, len(vec)*4) // 4 bytes per float32 + for i, v := range vec { + binary.LittleEndian.PutUint32(buf[i*4:], mathFloat32bits(v)) + } + return buf +} + +// DeserializeVector converts binary blob back to []float32 +func DeserializeVector(data []byte) []float32 { + count := len(data) / 4 + vec := make([]float32, count) + for i := 0; i < count; i++ { + vec[i] = mathBitsToFloat32(binary.LittleEndian.Uint32(data[i*4:])) + } + return vec +} + +// mathFloat32bits and mathBitsToFloat32 are helpers to convert between float32 and uint32 +func mathFloat32bits(f float32) uint32 { + return binary.LittleEndian.Uint32((*(*[4]byte)(unsafe.Pointer(&f)))[:4]) +} + +func mathBitsToFloat32(b uint32) float32 { + return *(*float32)(unsafe.Pointer(&b)) +} func fetchTableName(emb []float32) (string, error) { switch len(emb) { - case 5120: - return vecTableName5120, nil - case 384: - return vecTableName384, nil + case 768: + return "embeddings_768", nil default: return "", fmt.Errorf("no table for the size of %d", len(emb)) } @@ -37,127 +60,121 @@ func (p ProviderSQL) WriteVector(row *models.VectorRow) error { if err != nil { return err } - stmt, _, err := p.s3Conn.Prepare( - fmt.Sprintf("INSERT INTO %s(embedding, slug, raw_text, filename) VALUES (?, ?, ?, ?)", tableName)) - if err != nil { - p.logger.Error("failed to prep a stmt", "error", err) - return err - } - defer stmt.Close() - v, err := sqlite_vec.SerializeFloat32(row.Embeddings) - if err != nil { - p.logger.Error("failed to serialize vector", - "emb-len", len(row.Embeddings), "error", err) - return err - } - if v == nil { - err = errors.New("empty vector after serialization") - p.logger.Error("empty vector after serialization", - "emb-len", len(row.Embeddings), "text", row.RawText, "error", err) - return err - } - if err := stmt.BindBlob(1, v); err != nil { - p.logger.Error("failed to bind", "error", err) - return err - } - if err := stmt.BindText(2, row.Slug); err != nil { - p.logger.Error("failed to bind", "error", err) - return err - } - if err := stmt.BindText(3, row.RawText); err != nil { - p.logger.Error("failed to bind", "error", err) - return err - } - if err := stmt.BindText(4, row.FileName); err != nil { - p.logger.Error("failed to bind", "error", err) - return err - } - err = stmt.Exec() - if err != nil { - return err - } - return nil -} - -func decodeUnsafe(bs []byte) []float32 { - return unsafe.Slice((*float32)(unsafe.Pointer(&bs[0])), len(bs)/4) + serializedEmbeddings := SerializeVector(row.Embeddings) + query := fmt.Sprintf("INSERT INTO %s(embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)", tableName) + _, err = p.db.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName) + return err } -func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) { +func (p ProviderSQL) SearchClosest(q []float32, limit int) ([]models.VectorRow, error) { tableName, err := fetchTableName(q) if err != nil { return nil, err } - stmt, _, err := p.s3Conn.Prepare( - fmt.Sprintf(`SELECT - distance, - embedding, - slug, - raw_text, - filename - FROM %s - WHERE embedding MATCH ? - ORDER BY distance - LIMIT 3 - `, tableName)) + querySQL := "SELECT embeddings, slug, raw_text, filename FROM " + tableName + rows, err := p.db.Query(querySQL) if err != nil { return nil, err } - query, err := sqlite_vec.SerializeFloat32(q[:]) - if err != nil { - return nil, err + defer rows.Close() + type SearchResult struct { + vector models.VectorRow + distance float32 + } + var allResults []SearchResult + for rows.Next() { + var ( + embeddingsBlob []byte + slug, rawText, fileName string + ) + if err := rows.Scan(&embeddingsBlob, &slug, &rawText, &fileName); err != nil { + continue + } + + storedEmbeddings := DeserializeVector(embeddingsBlob) + + // Calculate cosine similarity (returns value between -1 and 1, where 1 is most similar) + similarity := cosineSimilarity(q, storedEmbeddings) + distance := 1 - similarity // Convert to distance where 0 is most similar + + result := SearchResult{ + vector: models.VectorRow{ + Embeddings: storedEmbeddings, + Slug: slug, + RawText: rawText, + FileName: fileName, + }, + distance: distance, + } + allResults = append(allResults, result) + } + // Sort by distance + sort.Slice(allResults, func(i, j int) bool { + return allResults[i].distance < allResults[j].distance + }) + // Truncate to limit + if len(allResults) > limit { + allResults = allResults[:limit] + } + // Convert back to VectorRow slice + results := make([]models.VectorRow, len(allResults)) + for i, result := range allResults { + result.vector.Distance = result.distance + results[i] = result.vector + } + return results, nil +} + +// cosineSimilarity calculates the cosine similarity between two vectors +func cosineSimilarity(a, b []float32) float32 { + if len(a) != len(b) { + return 0.0 } - if err := stmt.BindBlob(1, query); err != nil { - p.logger.Error("failed to bind", "error", err) - return nil, err + var dotProduct, normA, normB float32 + for i := 0; i < len(a); i++ { + dotProduct += a[i] * b[i] + normA += a[i] * a[i] + normB += b[i] * b[i] } - resp := []models.VectorRow{} - for stmt.Step() { - res := models.VectorRow{} - res.Distance = float32(stmt.ColumnFloat(0)) - emb := stmt.ColumnRawText(1) - res.Embeddings = decodeUnsafe(emb) - res.Slug = stmt.ColumnText(2) - res.RawText = stmt.ColumnText(3) - res.FileName = stmt.ColumnText(4) - resp = append(resp, res) - } - if err := stmt.Err(); err != nil { - return nil, err + if normA == 0 || normB == 0 { + return 0.0 } - err = stmt.Close() - if err != nil { - return nil, err + return dotProduct / (sqrt(normA) * sqrt(normB)) +} + +// sqrt returns the square root of a float32 +func sqrt(f float32) float32 { + // A simple implementation of square root using Newton's method + if f == 0 { + return 0 + } + guess := f / 2 + for i := 0; i < 10; i++ { // 10 iterations should be enough for good precision + guess = (guess + f/guess) / 2 } - return resp, nil + return guess } func (p ProviderSQL) ListFiles() ([]string, error) { - q := fmt.Sprintf("SELECT filename FROM %s GROUP BY filename", vecTableName384) - stmt, _, err := p.s3Conn.Prepare(q) + query := "SELECT DISTINCT filename FROM embeddings_768" + rows, err := p.db.Query(query) if err != nil { return nil, err } - defer stmt.Close() - resp := []string{} - for stmt.Step() { - resp = append(resp, stmt.ColumnText(0)) - } - if err := stmt.Err(); err != nil { - return nil, err + defer rows.Close() + var allFiles []string + for rows.Next() { + var filename string + if err := rows.Scan(&filename); err != nil { + continue + } + allFiles = append(allFiles, filename) } - return resp, nil + return allFiles, nil } func (p ProviderSQL) RemoveEmbByFileName(filename string) error { - q := fmt.Sprintf("DELETE FROM %s WHERE filename = ?", vecTableName384) - stmt, _, err := p.s3Conn.Prepare(q) - if err != nil { - return err - } - defer stmt.Close() - if err := stmt.BindText(1, filename); err != nil { - return err - } - return stmt.Exec() + query := "DELETE FROM embeddings_768 WHERE filename = ?" + _, err := p.db.Exec(query, filename) + return err } |
