diff options
Diffstat (limited to 'storage')
| -rw-r--r-- | storage/memory.go | 14 | ||||
| -rw-r--r-- | storage/migrate.go | 10 | ||||
| -rw-r--r-- | storage/migrations/001_init.up.sql | 1 | ||||
| -rw-r--r-- | storage/migrations/002_add_vector.down.sql | 34 | ||||
| -rw-r--r-- | storage/migrations/002_add_vector.up.sql | 98 | ||||
| -rw-r--r-- | storage/migrations/003_add_fts.down.sql | 2 | ||||
| -rw-r--r-- | storage/migrations/003_add_fts.up.sql | 15 | ||||
| -rw-r--r-- | storage/migrations/004_populate_fts.down.sql | 2 | ||||
| -rw-r--r-- | storage/migrations/004_populate_fts.up.sql | 26 | ||||
| -rw-r--r-- | storage/storage.go | 73 | ||||
| -rw-r--r-- | storage/storage_test.go | 34 | ||||
| -rw-r--r-- | storage/vector.go | 231 |
12 files changed, 518 insertions(+), 22 deletions(-)
diff --git a/storage/memory.go b/storage/memory.go index a7bf8cc..406182f 100644 --- a/storage/memory.go +++ b/storage/memory.go @@ -1,6 +1,6 @@ package storage -import "elefant/models" +import "gf-lt/models" type Memories interface { Memorise(m *models.Memory) (*models.Memory, error) @@ -9,15 +9,23 @@ type Memories interface { } func (p ProviderSQL) Memorise(m *models.Memory) (*models.Memory, error) { - query := "INSERT INTO memories (agent, topic, mind) VALUES (:agent, :topic, :mind) RETURNING *;" + query := ` + INSERT INTO memories (agent, topic, mind) + VALUES (:agent, :topic, :mind) + ON CONFLICT (agent, topic) DO UPDATE + SET mind = excluded.mind, + updated_at = CURRENT_TIMESTAMP + RETURNING *;` stmt, err := p.db.PrepareNamed(query) if err != nil { + p.logger.Error("failed to prepare stmt", "query", query, "error", err) return nil, err } defer stmt.Close() var memory models.Memory err = stmt.Get(&memory, m) if err != nil { + p.logger.Error("failed to upsert memory", "query", query, "error", err) return nil, err } return &memory, nil @@ -28,6 +36,7 @@ func (p ProviderSQL) Recall(agent, topic string) (string, error) { var mind string err := p.db.Get(&mind, query, agent, topic) if err != nil { + p.logger.Error("failed to get memory", "query", query, "error", err) return "", err } return mind, nil @@ -38,6 +47,7 @@ func (p ProviderSQL) RecallTopics(agent string) ([]string, error) { var topics []string err := p.db.Select(&topics, query, agent) if err != nil { + p.logger.Error("failed to get topics", "query", query, "error", err) return nil, err } return topics, nil diff --git a/storage/migrate.go b/storage/migrate.go index d97b99d..38f9854 100644 --- a/storage/migrate.go +++ b/storage/migrate.go @@ -10,16 +10,18 @@ import ( //go:embed migrations/* var migrationsFS embed.FS -func (p *ProviderSQL) Migrate() { +func (p *ProviderSQL) Migrate() error { // Get the embedded filesystem migrationsDir, err := fs.Sub(migrationsFS, "migrations") if err != nil { 
p.logger.Error("Failed to get embedded migrations directory;", "error", err) + return fmt.Errorf("failed to get embedded migrations directory: %w", err) } // List all .up.sql files files, err := migrationsFS.ReadDir("migrations") if err != nil { p.logger.Error("Failed to read migrations directory;", "error", err) + return fmt.Errorf("failed to read migrations directory: %w", err) } // Execute each .up.sql file for _, file := range files { @@ -27,10 +29,12 @@ func (p *ProviderSQL) Migrate() { err := p.executeMigration(migrationsDir, file.Name()) if err != nil { p.logger.Error("Failed to execute migration %s: %v", file.Name(), err) + return fmt.Errorf("failed to execute migration %s: %w", file.Name(), err) } } } - p.logger.Info("All migrations executed successfully!") + p.logger.Debug("All migrations executed successfully!") + return nil } func (p *ProviderSQL) executeMigration(migrationsDir fs.FS, fileName string) error { @@ -50,7 +54,7 @@ func (p *ProviderSQL) executeMigration(migrationsDir fs.FS, fileName string) err } func (p *ProviderSQL) executeSQL(sqlContent []byte) error { - // Connect to the database (example using a simple connection) + // Execute the migration content using standard database connection _, err := p.db.Exec(string(sqlContent)) if err != nil { return fmt.Errorf("failed to execute SQL: %w", err) diff --git a/storage/migrations/001_init.up.sql b/storage/migrations/001_init.up.sql index 8980ccf..09bb5e6 100644 --- a/storage/migrations/001_init.up.sql +++ b/storage/migrations/001_init.up.sql @@ -2,6 +2,7 @@ CREATE TABLE IF NOT EXISTS chats ( id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, msgs TEXT NOT NULL, + agent TEXT NOT NULL DEFAULT 'assistant', created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ); diff --git a/storage/migrations/002_add_vector.down.sql b/storage/migrations/002_add_vector.down.sql new file mode 100644 index 0000000..a257b11 --- /dev/null +++ 
b/storage/migrations/002_add_vector.down.sql @@ -0,0 +1,34 @@ +-- Drop vector storage tables +DROP INDEX IF EXISTS idx_embeddings_384_filename; +DROP INDEX IF EXISTS idx_embeddings_768_filename; +DROP INDEX IF EXISTS idx_embeddings_1024_filename; +DROP INDEX IF EXISTS idx_embeddings_1536_filename; +DROP INDEX IF EXISTS idx_embeddings_2048_filename; +DROP INDEX IF EXISTS idx_embeddings_3072_filename; +DROP INDEX IF EXISTS idx_embeddings_4096_filename; +DROP INDEX IF EXISTS idx_embeddings_5120_filename; +DROP INDEX IF EXISTS idx_embeddings_384_slug; +DROP INDEX IF EXISTS idx_embeddings_768_slug; +DROP INDEX IF EXISTS idx_embeddings_1024_slug; +DROP INDEX IF EXISTS idx_embeddings_1536_slug; +DROP INDEX IF EXISTS idx_embeddings_2048_slug; +DROP INDEX IF EXISTS idx_embeddings_3072_slug; +DROP INDEX IF EXISTS idx_embeddings_4096_slug; +DROP INDEX IF EXISTS idx_embeddings_5120_slug; +DROP INDEX IF EXISTS idx_embeddings_384_created_at; +DROP INDEX IF EXISTS idx_embeddings_768_created_at; +DROP INDEX IF EXISTS idx_embeddings_1024_created_at; +DROP INDEX IF EXISTS idx_embeddings_1536_created_at; +DROP INDEX IF EXISTS idx_embeddings_2048_created_at; +DROP INDEX IF EXISTS idx_embeddings_3072_created_at; +DROP INDEX IF EXISTS idx_embeddings_4096_created_at; +DROP INDEX IF EXISTS idx_embeddings_5120_created_at; + +DROP TABLE IF EXISTS embeddings_384; +DROP TABLE IF EXISTS embeddings_768; +DROP TABLE IF EXISTS embeddings_1024; +DROP TABLE IF EXISTS embeddings_1536; +DROP TABLE IF EXISTS embeddings_2048; +DROP TABLE IF EXISTS embeddings_3072; +DROP TABLE IF EXISTS embeddings_4096; +DROP TABLE IF EXISTS embeddings_5120;
\ No newline at end of file diff --git a/storage/migrations/002_add_vector.up.sql b/storage/migrations/002_add_vector.up.sql new file mode 100644 index 0000000..baf703d --- /dev/null +++ b/storage/migrations/002_add_vector.up.sql @@ -0,0 +1,98 @@ +-- Create tables for vector storage (replacing vec0 plugin usage) +CREATE TABLE IF NOT EXISTS embeddings_384 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_768 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_1024 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_1536 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_2048 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_3072 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_4096 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + 
filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_5120 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +-- Indexes for better performance +CREATE INDEX IF NOT EXISTS idx_embeddings_384_filename ON embeddings_384(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_768_filename ON embeddings_768(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_1024_filename ON embeddings_1024(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_1536_filename ON embeddings_1536(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_2048_filename ON embeddings_2048(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_3072_filename ON embeddings_3072(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_4096_filename ON embeddings_4096(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_5120_filename ON embeddings_5120(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_384_slug ON embeddings_384(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_768_slug ON embeddings_768(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_1024_slug ON embeddings_1024(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_1536_slug ON embeddings_1536(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_2048_slug ON embeddings_2048(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_3072_slug ON embeddings_3072(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_4096_slug ON embeddings_4096(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_5120_slug ON embeddings_5120(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_384_created_at ON embeddings_384(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_768_created_at ON embeddings_768(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_1024_created_at ON embeddings_1024(created_at); +CREATE INDEX IF NOT EXISTS 
idx_embeddings_1536_created_at ON embeddings_1536(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_2048_created_at ON embeddings_2048(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_3072_created_at ON embeddings_3072(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_4096_created_at ON embeddings_4096(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_5120_created_at ON embeddings_5120(created_at); diff --git a/storage/migrations/003_add_fts.down.sql b/storage/migrations/003_add_fts.down.sql new file mode 100644 index 0000000..e565fd5 --- /dev/null +++ b/storage/migrations/003_add_fts.down.sql @@ -0,0 +1,2 @@ +-- Drop FTS5 virtual table +DROP TABLE IF EXISTS fts_embeddings;
\ No newline at end of file diff --git a/storage/migrations/003_add_fts.up.sql b/storage/migrations/003_add_fts.up.sql new file mode 100644 index 0000000..114586a --- /dev/null +++ b/storage/migrations/003_add_fts.up.sql @@ -0,0 +1,15 @@ +-- Create FTS5 virtual table for full-text search +CREATE VIRTUAL TABLE IF NOT EXISTS fts_embeddings USING fts5( + slug UNINDEXED, + raw_text, + filename UNINDEXED, + embedding_size UNINDEXED, + tokenize='porter unicode61' -- Use porter stemmer and unicode61 tokenizer +); + +-- Create triggers to maintain FTS table when embeddings are inserted/deleted +-- Note: We'll handle inserts/deletes programmatically for simplicity +-- but triggers could be added here if needed. + +-- Indexes for performance (FTS5 manages its own indexes) +-- No additional indexes needed for FTS5 virtual table.
\ No newline at end of file diff --git a/storage/migrations/004_populate_fts.down.sql b/storage/migrations/004_populate_fts.down.sql new file mode 100644 index 0000000..2b5c756 --- /dev/null +++ b/storage/migrations/004_populate_fts.down.sql @@ -0,0 +1,2 @@ +-- Clear FTS table (optional) +DELETE FROM fts_embeddings;
\ No newline at end of file diff --git a/storage/migrations/004_populate_fts.up.sql b/storage/migrations/004_populate_fts.up.sql new file mode 100644 index 0000000..1d1b16a --- /dev/null +++ b/storage/migrations/004_populate_fts.up.sql @@ -0,0 +1,26 @@ +-- Populate FTS table with existing embeddings +DELETE FROM fts_embeddings; + +INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) +SELECT slug, raw_text, filename, 384 FROM embeddings_384; + +INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) +SELECT slug, raw_text, filename, 768 FROM embeddings_768; + +INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) +SELECT slug, raw_text, filename, 1024 FROM embeddings_1024; + +INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) +SELECT slug, raw_text, filename, 1536 FROM embeddings_1536; + +INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) +SELECT slug, raw_text, filename, 2048 FROM embeddings_2048; + +INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) +SELECT slug, raw_text, filename, 3072 FROM embeddings_3072; + +INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) +SELECT slug, raw_text, filename, 4096 FROM embeddings_4096; + +INSERT INTO fts_embeddings (slug, raw_text, filename, embedding_size) +SELECT slug, raw_text, filename, 5120 FROM embeddings_5120;
\ No newline at end of file diff --git a/storage/storage.go b/storage/storage.go index 67b8dd8..57631da 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -1,19 +1,28 @@ package storage import ( - "elefant/models" + "gf-lt/models" "log/slog" _ "github.com/glebarez/go-sqlite" "github.com/jmoiron/sqlx" ) +type FullRepo interface { + ChatHistory + Memories + VectorRepo +} + type ChatHistory interface { ListChats() ([]models.Chat, error) GetChatByID(id uint32) (*models.Chat, error) + GetChatByChar(char string) ([]models.Chat, error) GetLastChat() (*models.Chat, error) + GetLastChatByAgent(agent string) (*models.Chat, error) UpsertChat(chat *models.Chat) (*models.Chat, error) RemoveChat(id uint32) error + ChatGetMaxID() (uint32, error) } type ProviderSQL struct { @@ -27,6 +36,12 @@ func (p ProviderSQL) ListChats() ([]models.Chat, error) { return resp, err } +func (p ProviderSQL) GetChatByChar(char string) ([]models.Chat, error) { + resp := []models.Chat{} + err := p.db.Select(&resp, "SELECT * FROM chats WHERE agent=$1;", char) + return resp, err +} + func (p ProviderSQL) GetChatByID(id uint32) (*models.Chat, error) { resp := models.Chat{} err := p.db.Get(&resp, "SELECT * FROM chats WHERE id=$1;", id) @@ -39,16 +54,28 @@ func (p ProviderSQL) GetLastChat() (*models.Chat, error) { return &resp, err } +func (p ProviderSQL) GetLastChatByAgent(agent string) (*models.Chat, error) { + resp := models.Chat{} + query := "SELECT * FROM chats WHERE agent=$1 ORDER BY updated_at DESC LIMIT 1" + err := p.db.Get(&resp, query, agent) + return &resp, err +} + +// https://sqlite.org/lang_upsert.html +// on conflict was added func (p ProviderSQL) UpsertChat(chat *models.Chat) (*models.Chat, error) { // Prepare the SQL statement query := ` - INSERT OR REPLACE INTO chats (id, name, msgs, created_at, updated_at) - VALUES (:id, :name, :msgs, :created_at, :updated_at) + INSERT INTO chats (id, name, msgs, agent, created_at, updated_at) + VALUES (:id, :name, :msgs, :agent, :created_at, 
:updated_at) + ON CONFLICT(id) DO UPDATE SET msgs=excluded.msgs, + updated_at=excluded.updated_at RETURNING *;` stmt, err := p.db.PrepareNamed(query) if err != nil { return nil, err } + defer stmt.Close() // Execute the query and scan the result into a new chat object var resp models.Chat err = stmt.Get(&resp, chat) @@ -61,13 +88,45 @@ func (p ProviderSQL) RemoveChat(id uint32) error { return err } -func NewProviderSQL(dbPath string, logger *slog.Logger) ChatHistory { +func (p ProviderSQL) ChatGetMaxID() (uint32, error) { + query := "SELECT MAX(id) FROM chats;" + var id uint32 + err := p.db.Get(&id, query) + return id, err +} + +// opens database connection +func NewProviderSQL(dbPath string, logger *slog.Logger) FullRepo { db, err := sqlx.Open("sqlite", dbPath) if err != nil { - panic(err) + logger.Error("failed to open db connection", "error", err) + return nil + } + // Enable WAL mode for better concurrency and performance + if _, err := db.Exec("PRAGMA journal_mode = WAL;"); err != nil { + logger.Warn("failed to enable WAL mode", "error", err) + } + if _, err := db.Exec("PRAGMA synchronous = NORMAL;"); err != nil { + logger.Warn("failed to set synchronous mode", "error", err) + } + // Increase cache size for better performance + if _, err := db.Exec("PRAGMA cache_size = -2000;"); err != nil { + logger.Warn("failed to set cache size", "error", err) + } + // Log actual journal mode for debugging + var journalMode string + if err := db.QueryRow("PRAGMA journal_mode;").Scan(&journalMode); err == nil { + logger.Debug("SQLite journal mode", "mode", journalMode) } - // get SQLite version p := ProviderSQL{db: db, logger: logger} - p.Migrate() + if err := p.Migrate(); err != nil { + logger.Error("migration failed, app cannot start", "error", err) + return nil + } return p } + +// DB returns the underlying database connection +func (p ProviderSQL) DB() *sqlx.DB { + return p.db +} diff --git a/storage/storage_test.go b/storage/storage_test.go index ad1f1bf..a4f2bdd 100644 
--- a/storage/storage_test.go +++ b/storage/storage_test.go @@ -1,8 +1,8 @@ package storage import ( - "elefant/models" "fmt" + "gf-lt/models" "log/slog" "os" "testing" @@ -35,22 +35,27 @@ CREATE TABLE IF NOT EXISTS memories ( logger: slog.New(slog.NewJSONHandler(os.Stdout, nil)), } // Create a sample memory for testing - sampleMemory := &models.Memory{ + sampleMemory := models.Memory{ Agent: "testAgent", Topic: "testTopic", Mind: "testMind", CreatedAt: time.Now(), UpdatedAt: time.Now(), } + sampleMemoryRewrite := models.Memory{ + Agent: "testAgent", + Topic: "testTopic", + Mind: "same topic, new mind", + } cases := []struct { - memory *models.Memory + memories []models.Memory }{ - {memory: sampleMemory}, + {memories: []models.Memory{sampleMemory, sampleMemoryRewrite}}, } for i, tc := range cases { t.Run(fmt.Sprintf("run_%d", i), func(t *testing.T) { // Recall topics: get no rows - topics, err := provider.RecallTopics(tc.memory.Agent) + topics, err := provider.RecallTopics(tc.memories[0].Agent) if err != nil { t.Fatalf("Failed to recall topics: %v", err) } @@ -58,12 +63,12 @@ CREATE TABLE IF NOT EXISTS memories ( t.Fatalf("Expected no topics, got: %v", topics) } // Memorise - _, err = provider.Memorise(tc.memory) + _, err = provider.Memorise(&tc.memories[0]) if err != nil { t.Fatalf("Failed to memorise: %v", err) } // Recall topics: has topics - topics, err = provider.RecallTopics(tc.memory.Agent) + topics, err = provider.RecallTopics(tc.memories[0].Agent) if err != nil { t.Fatalf("Failed to recall topics: %v", err) } @@ -71,12 +76,20 @@ CREATE TABLE IF NOT EXISTS memories ( t.Fatalf("Expected topics, got none") } // Recall - content, err := provider.Recall(tc.memory.Agent, tc.memory.Topic) + content, err := provider.Recall(tc.memories[0].Agent, tc.memories[0].Topic) if err != nil { t.Fatalf("Failed to recall: %v", err) } - if content != tc.memory.Mind { - t.Fatalf("Expected content: %v, got: %v", tc.memory.Mind, content) + if content != tc.memories[0].Mind { + 
t.Fatalf("Expected content: %v, got: %v", tc.memories[0].Mind, content) + } + // rewrite mind of same agent-topic + newMem, err := provider.Memorise(&tc.memories[1]) + if err != nil { + t.Fatalf("Failed to memorise: %v", err) + } + if newMem.Mind == tc.memories[0].Mind { + t.Fatalf("Failed to change mind: %v", newMem.Mind) } }) } @@ -95,6 +108,7 @@ func TestChatHistory(t *testing.T) { id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, msgs TEXT NOT NULL, + agent TEXT NOT NULL, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP );`) diff --git a/storage/vector.go b/storage/vector.go new file mode 100644 index 0000000..e3bbb89 --- /dev/null +++ b/storage/vector.go @@ -0,0 +1,231 @@ +package storage + +import ( + "encoding/binary" + "fmt" + "gf-lt/models" + "sort" + "unsafe" + + "github.com/jmoiron/sqlx" +) + +type VectorRepo interface { + WriteVector(*models.VectorRow) error + SearchClosest(q []float32, limit int) ([]models.VectorRow, error) + ListFiles() ([]string, error) + RemoveEmbByFileName(filename string) error + DB() *sqlx.DB +} + +// SerializeVector converts []float32 to binary blob +func SerializeVector(vec []float32) []byte { + buf := make([]byte, len(vec)*4) // 4 bytes per float32 + for i, v := range vec { + binary.LittleEndian.PutUint32(buf[i*4:], mathFloat32bits(v)) + } + return buf +} + +// DeserializeVector converts binary blob back to []float32 +func DeserializeVector(data []byte) []float32 { + count := len(data) / 4 + vec := make([]float32, count) + for i := 0; i < count; i++ { + vec[i] = mathBitsToFloat32(binary.LittleEndian.Uint32(data[i*4:])) + } + return vec +} + +// mathFloat32bits and mathBitsToFloat32 are helpers to convert between float32 and uint32 +func mathFloat32bits(f float32) uint32 { + return binary.LittleEndian.Uint32((*(*[4]byte)(unsafe.Pointer(&f)))[:4]) +} + +func mathBitsToFloat32(b uint32) float32 { + return *(*float32)(unsafe.Pointer(&b)) +} + +func 
fetchTableName(emb []float32) (string, error) { + switch len(emb) { + case 384: + return "embeddings_384", nil + case 768: + return "embeddings_768", nil + case 1024: + return "embeddings_1024", nil + case 1536: + return "embeddings_1536", nil + case 2048: + return "embeddings_2048", nil + case 3072: + return "embeddings_3072", nil + case 4096: + return "embeddings_4096", nil + case 5120: + return "embeddings_5120", nil + default: + return "", fmt.Errorf("no table for the size of %d", len(emb)) + } +} + +func (p ProviderSQL) WriteVector(row *models.VectorRow) error { + tableName, err := fetchTableName(row.Embeddings) + if err != nil { + return err + } + serializedEmbeddings := SerializeVector(row.Embeddings) + query := fmt.Sprintf("INSERT INTO %s(embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)", tableName) + _, err = p.db.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName) + return err +} + +func (p ProviderSQL) SearchClosest(q []float32, limit int) ([]models.VectorRow, error) { + tableName, err := fetchTableName(q) + if err != nil { + return nil, err + } + querySQL := "SELECT embeddings, slug, raw_text, filename FROM " + tableName + rows, err := p.db.Query(querySQL) + if err != nil { + return nil, err + } + defer rows.Close() + type SearchResult struct { + vector models.VectorRow + distance float32 + } + var allResults []SearchResult + for rows.Next() { + var ( + embeddingsBlob []byte + slug, rawText, fileName string + ) + if err := rows.Scan(&embeddingsBlob, &slug, &rawText, &fileName); err != nil { + continue + } + + storedEmbeddings := DeserializeVector(embeddingsBlob) + + // Calculate cosine similarity (returns value between -1 and 1, where 1 is most similar) + similarity := cosineSimilarity(q, storedEmbeddings) + distance := 1 - similarity // Convert to distance where 0 is most similar + + result := SearchResult{ + vector: models.VectorRow{ + Embeddings: storedEmbeddings, + Slug: slug, + RawText: rawText, + FileName: fileName, + 
}, + distance: distance, + } + allResults = append(allResults, result) + } + // Sort by distance + sort.Slice(allResults, func(i, j int) bool { + return allResults[i].distance < allResults[j].distance + }) + // Truncate to limit + if len(allResults) > limit { + allResults = allResults[:limit] + } + // Convert back to VectorRow slice + results := make([]models.VectorRow, len(allResults)) + for i, result := range allResults { + result.vector.Distance = result.distance + results[i] = result.vector + } + return results, nil +} + +// cosineSimilarity calculates the cosine similarity between two vectors +func cosineSimilarity(a, b []float32) float32 { + if len(a) != len(b) { + return 0.0 + } + var dotProduct, normA, normB float32 + for i := 0; i < len(a); i++ { + dotProduct += a[i] * b[i] + normA += a[i] * a[i] + normB += b[i] * b[i] + } + if normA == 0 || normB == 0 { + return 0.0 + } + return dotProduct / (sqrt(normA) * sqrt(normB)) +} + +// sqrt returns the square root of a float32 +func sqrt(f float32) float32 { + // A simple implementation of square root using Newton's method + if f == 0 { + return 0 + } + guess := f / 2 + for i := 0; i < 10; i++ { // 10 iterations should be enough for good precision + guess = (guess + f/guess) / 2 + } + return guess +} + +func (p ProviderSQL) ListFiles() ([]string, error) { + fileLists := make([][]string, 0) + + // Query all supported tables and combine results + tableNames := []string{ + "embeddings_384", "embeddings_768", "embeddings_1024", "embeddings_1536", + "embeddings_2048", "embeddings_3072", "embeddings_4096", "embeddings_5120", + } + for _, table := range tableNames { + query := "SELECT DISTINCT filename FROM " + table + rows, err := p.db.Query(query) + if err != nil { + // Continue if one table doesn't exist + continue + } + + var files []string + for rows.Next() { + var filename string + if err := rows.Scan(&filename); err != nil { + continue + } + files = append(files, filename) + } + rows.Close() + + fileLists = 
append(fileLists, files) + } + + // Combine and deduplicate + fileSet := make(map[string]bool) + var allFiles []string + for _, files := range fileLists { + for _, file := range files { + if !fileSet[file] { + fileSet[file] = true + allFiles = append(allFiles, file) + } + } + } + return allFiles, nil +} + +func (p ProviderSQL) RemoveEmbByFileName(filename string) error { + var errors []string + tableNames := []string{ + "embeddings_384", "embeddings_768", "embeddings_1024", "embeddings_1536", + "embeddings_2048", "embeddings_3072", "embeddings_4096", "embeddings_5120", + } + for _, table := range tableNames { + query := fmt.Sprintf("DELETE FROM %s WHERE filename = ?", table) + if _, err := p.db.Exec(query, filename); err != nil { + errors = append(errors, err.Error()) + } + } + if len(errors) > 0 { + return fmt.Errorf("errors occurred: %v", errors) + } + return nil +} |
