diff options
Diffstat (limited to 'storage')
| -rw-r--r-- | storage/memory.go | 14 | ||||
| -rw-r--r-- | storage/migrate.go | 5 | ||||
| -rw-r--r-- | storage/migrations/001_init.up.sql | 1 | ||||
| -rw-r--r-- | storage/migrations/002_add_vector.down.sql | 34 | ||||
| -rw-r--r-- | storage/migrations/002_add_vector.up.sql | 98 | ||||
| -rw-r--r-- | storage/storage.go | 53 | ||||
| -rw-r--r-- | storage/storage_test.go | 34 | ||||
| -rw-r--r-- | storage/vector.go | 255 |
8 files changed, 474 insertions, 20 deletions
diff --git a/storage/memory.go b/storage/memory.go index a7bf8cc..406182f 100644 --- a/storage/memory.go +++ b/storage/memory.go @@ -1,6 +1,6 @@ package storage -import "elefant/models" +import "gf-lt/models" type Memories interface { Memorise(m *models.Memory) (*models.Memory, error) @@ -9,15 +9,23 @@ type Memories interface { } func (p ProviderSQL) Memorise(m *models.Memory) (*models.Memory, error) { - query := "INSERT INTO memories (agent, topic, mind) VALUES (:agent, :topic, :mind) RETURNING *;" + query := ` + INSERT INTO memories (agent, topic, mind) + VALUES (:agent, :topic, :mind) + ON CONFLICT (agent, topic) DO UPDATE + SET mind = excluded.mind, + updated_at = CURRENT_TIMESTAMP + RETURNING *;` stmt, err := p.db.PrepareNamed(query) if err != nil { + p.logger.Error("failed to prepare stmt", "query", query, "error", err) return nil, err } defer stmt.Close() var memory models.Memory err = stmt.Get(&memory, m) if err != nil { + p.logger.Error("failed to upsert memory", "query", query, "error", err) return nil, err } return &memory, nil @@ -28,6 +36,7 @@ func (p ProviderSQL) Recall(agent, topic string) (string, error) { var mind string err := p.db.Get(&mind, query, agent, topic) if err != nil { + p.logger.Error("failed to get memory", "query", query, "error", err) return "", err } return mind, nil @@ -38,6 +47,7 @@ func (p ProviderSQL) RecallTopics(agent string) ([]string, error) { var topics []string err := p.db.Select(&topics, query, agent) if err != nil { + p.logger.Error("failed to get topics", "query", query, "error", err) return nil, err } return topics, nil diff --git a/storage/migrate.go b/storage/migrate.go index d97b99d..decfe9c 100644 --- a/storage/migrate.go +++ b/storage/migrate.go @@ -27,10 +27,11 @@ func (p *ProviderSQL) Migrate() { err := p.executeMigration(migrationsDir, file.Name()) if err != nil { p.logger.Error("Failed to execute migration %s: %v", file.Name(), err) + panic(err) } } } - p.logger.Info("All migrations executed successfully!") + p.logger.Debug("All migrations executed successfully!") } func (p *ProviderSQL) executeMigration(migrationsDir fs.FS, fileName string) error { @@ -50,7 +51,7 @@ func (p *ProviderSQL) executeMigration(migrationsDir fs.FS, fileName string) err } func (p *ProviderSQL) executeSQL(sqlContent []byte) error { - // Connect to the database (example using a simple connection) + // Execute the migration content using standard database connection _, err := p.db.Exec(string(sqlContent)) if err != nil { return fmt.Errorf("failed to execute SQL: %w", err) diff --git a/storage/migrations/001_init.up.sql b/storage/migrations/001_init.up.sql index 8980ccf..09bb5e6 100644 --- a/storage/migrations/001_init.up.sql +++ b/storage/migrations/001_init.up.sql @@ -2,6 +2,7 @@ CREATE TABLE IF NOT EXISTS chats ( id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, msgs TEXT NOT NULL, + agent TEXT NOT NULL DEFAULT 'assistant', created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ); diff --git a/storage/migrations/002_add_vector.down.sql b/storage/migrations/002_add_vector.down.sql new file mode 100644 index 0000000..a257b11 --- /dev/null +++ b/storage/migrations/002_add_vector.down.sql @@ -0,0 +1,34 @@ +-- Drop vector storage tables +DROP INDEX IF EXISTS idx_embeddings_384_filename; +DROP INDEX IF EXISTS idx_embeddings_768_filename; +DROP INDEX IF EXISTS idx_embeddings_1024_filename; +DROP INDEX IF EXISTS idx_embeddings_1536_filename; +DROP INDEX IF EXISTS idx_embeddings_2048_filename; +DROP INDEX IF EXISTS idx_embeddings_3072_filename; +DROP INDEX IF EXISTS idx_embeddings_4096_filename; +DROP INDEX IF EXISTS idx_embeddings_5120_filename; +DROP INDEX IF EXISTS idx_embeddings_384_slug; +DROP INDEX IF EXISTS idx_embeddings_768_slug; +DROP INDEX IF EXISTS idx_embeddings_1024_slug; +DROP INDEX IF EXISTS idx_embeddings_1536_slug; +DROP INDEX IF EXISTS idx_embeddings_2048_slug; +DROP INDEX IF EXISTS idx_embeddings_3072_slug; +DROP INDEX IF EXISTS idx_embeddings_4096_slug; +DROP INDEX IF EXISTS idx_embeddings_5120_slug; +DROP INDEX IF EXISTS idx_embeddings_384_created_at; +DROP INDEX IF EXISTS idx_embeddings_768_created_at; +DROP INDEX IF EXISTS idx_embeddings_1024_created_at; +DROP INDEX IF EXISTS idx_embeddings_1536_created_at; +DROP INDEX IF EXISTS idx_embeddings_2048_created_at; +DROP INDEX IF EXISTS idx_embeddings_3072_created_at; +DROP INDEX IF EXISTS idx_embeddings_4096_created_at; +DROP INDEX IF EXISTS idx_embeddings_5120_created_at; + +DROP TABLE IF EXISTS embeddings_384; +DROP TABLE IF EXISTS embeddings_768; +DROP TABLE IF EXISTS embeddings_1024; +DROP TABLE IF EXISTS embeddings_1536; +DROP TABLE IF EXISTS embeddings_2048; +DROP TABLE IF EXISTS embeddings_3072; +DROP TABLE IF EXISTS embeddings_4096; +DROP TABLE IF EXISTS embeddings_5120;
\ No newline at end of file diff --git a/storage/migrations/002_add_vector.up.sql b/storage/migrations/002_add_vector.up.sql new file mode 100644 index 0000000..baf703d --- /dev/null +++ b/storage/migrations/002_add_vector.up.sql @@ -0,0 +1,98 @@ +-- Create tables for vector storage (replacing vec0 plugin usage) +CREATE TABLE IF NOT EXISTS embeddings_384 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_768 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_1024 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_1536 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_2048 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_3072 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_4096 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS embeddings_5120 ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embeddings BLOB NOT NULL, + slug TEXT NOT NULL, + raw_text TEXT NOT NULL, + filename TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +-- Indexes for better performance +CREATE INDEX IF NOT EXISTS idx_embeddings_384_filename ON embeddings_384(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_768_filename ON embeddings_768(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_1024_filename ON embeddings_1024(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_1536_filename ON embeddings_1536(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_2048_filename ON embeddings_2048(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_3072_filename ON embeddings_3072(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_4096_filename ON embeddings_4096(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_5120_filename ON embeddings_5120(filename); +CREATE INDEX IF NOT EXISTS idx_embeddings_384_slug ON embeddings_384(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_768_slug ON embeddings_768(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_1024_slug ON embeddings_1024(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_1536_slug ON embeddings_1536(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_2048_slug ON embeddings_2048(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_3072_slug ON embeddings_3072(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_4096_slug ON embeddings_4096(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_5120_slug ON embeddings_5120(slug); +CREATE INDEX IF NOT EXISTS idx_embeddings_384_created_at ON embeddings_384(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_768_created_at ON embeddings_768(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_1024_created_at ON embeddings_1024(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_1536_created_at ON embeddings_1536(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_2048_created_at ON embeddings_2048(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_3072_created_at ON embeddings_3072(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_4096_created_at ON embeddings_4096(created_at); +CREATE INDEX IF NOT EXISTS idx_embeddings_5120_created_at ON embeddings_5120(created_at); diff --git a/storage/storage.go b/storage/storage.go index 67b8dd8..a092f8d 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -1,19 +1,28 @@ package storage import ( - "elefant/models" + "gf-lt/models" "log/slog" _ "github.com/glebarez/go-sqlite" "github.com/jmoiron/sqlx" ) +type FullRepo interface { + ChatHistory + Memories + VectorRepo +} + type ChatHistory interface { ListChats() ([]models.Chat, error) GetChatByID(id uint32) (*models.Chat, error) + GetChatByChar(char string) ([]models.Chat, error) GetLastChat() (*models.Chat, error) + GetLastChatByAgent(agent string) (*models.Chat, error) UpsertChat(chat *models.Chat) (*models.Chat, error) RemoveChat(id uint32) error + ChatGetMaxID() (uint32, error) } type ProviderSQL struct { @@ -27,6 +36,12 @@ func (p ProviderSQL) ListChats() ([]models.Chat, error) { return resp, err } +func (p ProviderSQL) GetChatByChar(char string) ([]models.Chat, error) { + resp := []models.Chat{} + err := p.db.Select(&resp, "SELECT * FROM chats WHERE agent=$1;", char) + return resp, err +} + func (p ProviderSQL) GetChatByID(id uint32) (*models.Chat, error) { resp := models.Chat{} err := p.db.Get(&resp, "SELECT * FROM chats WHERE id=$1;", id) @@ -39,16 +54,28 @@ func (p ProviderSQL) GetLastChat() (*models.Chat, error) { return &resp, err } +func (p ProviderSQL) GetLastChatByAgent(agent string) (*models.Chat, error) { + resp := models.Chat{} + query := "SELECT * FROM chats WHERE agent=$1 ORDER BY updated_at DESC LIMIT 1" + err := p.db.Get(&resp, query, agent) + return &resp, err +} + +// https://sqlite.org/lang_upsert.html +// on conflict was added func (p ProviderSQL) UpsertChat(chat *models.Chat) (*models.Chat, error) { // Prepare the SQL statement query := ` - INSERT OR REPLACE INTO chats (id, name, msgs, created_at, updated_at) - VALUES (:id, :name, :msgs, :created_at, :updated_at) + INSERT INTO chats (id, name, msgs, agent, created_at, updated_at) + VALUES (:id, :name, :msgs, :agent, :created_at, :updated_at) + ON CONFLICT(id) DO UPDATE SET msgs=excluded.msgs, + updated_at=excluded.updated_at RETURNING *;` stmt, err := p.db.PrepareNamed(query) if err != nil { return nil, err } + defer stmt.Close() // Execute the query and scan the result into a new chat object var resp models.Chat err = stmt.Get(&resp, chat) @@ -61,13 +88,27 @@ func (p ProviderSQL) RemoveChat(id uint32) error { return err } -func NewProviderSQL(dbPath string, logger *slog.Logger) ChatHistory { +func (p ProviderSQL) ChatGetMaxID() (uint32, error) { + query := "SELECT MAX(id) FROM chats;" + var id uint32 + err := p.db.Get(&id, query) + return id, err +} + +// opens database connection +func NewProviderSQL(dbPath string, logger *slog.Logger) FullRepo { db, err := sqlx.Open("sqlite", dbPath) if err != nil { - panic(err) + logger.Error("failed to open db connection", "error", err) + return nil } - // get SQLite version p := ProviderSQL{db: db, logger: logger} + p.Migrate() return p } + +// DB returns the underlying database connection +func (p ProviderSQL) DB() *sqlx.DB { + return p.db +} diff --git a/storage/storage_test.go b/storage/storage_test.go index ad1f1bf..a4f2bdd 100644 --- a/storage/storage_test.go +++ b/storage/storage_test.go @@ -1,8 +1,8 @@ package storage import ( - "elefant/models" "fmt" + "gf-lt/models" "log/slog" "os" "testing" @@ -35,22 +35,27 @@ CREATE TABLE IF NOT EXISTS memories ( logger: slog.New(slog.NewJSONHandler(os.Stdout, nil)), } // Create a sample memory for testing - sampleMemory := &models.Memory{ + sampleMemory := models.Memory{ Agent: "testAgent", Topic: "testTopic", Mind: "testMind", CreatedAt: time.Now(), UpdatedAt: time.Now(), } + sampleMemoryRewrite := models.Memory{ + Agent: "testAgent", + Topic: "testTopic", + Mind: "same topic, new mind", + } cases := []struct { - memory *models.Memory + memories []models.Memory }{ - {memory: sampleMemory}, + {memories: []models.Memory{sampleMemory, sampleMemoryRewrite}}, } for i, tc := range cases { t.Run(fmt.Sprintf("run_%d", i), func(t *testing.T) { // Recall topics: get no rows - topics, err := provider.RecallTopics(tc.memory.Agent) + topics, err := provider.RecallTopics(tc.memories[0].Agent) if err != nil { t.Fatalf("Failed to recall topics: %v", err) } @@ -58,12 +63,12 @@ CREATE TABLE IF NOT EXISTS memories ( t.Fatalf("Expected no topics, got: %v", topics) } // Memorise - _, err = provider.Memorise(tc.memory) + _, err = provider.Memorise(&tc.memories[0]) if err != nil { t.Fatalf("Failed to memorise: %v", err) } // Recall topics: has topics - topics, err = provider.RecallTopics(tc.memory.Agent) + topics, err = provider.RecallTopics(tc.memories[0].Agent) if err != nil { t.Fatalf("Failed to recall topics: %v", err) } @@ -71,12 +76,20 @@ CREATE TABLE IF NOT EXISTS memories ( t.Fatalf("Expected topics, got none") } // Recall - content, err := provider.Recall(tc.memory.Agent, tc.memory.Topic) + content, err := provider.Recall(tc.memories[0].Agent, tc.memories[0].Topic) if err != nil { t.Fatalf("Failed to recall: %v", err) } - if content != tc.memory.Mind { - t.Fatalf("Expected content: %v, got: %v", tc.memory.Mind, content) + if content != tc.memories[0].Mind { + t.Fatalf("Expected content: %v, got: %v", tc.memories[0].Mind, content) + } + // rewrite mind of same agent-topic + newMem, err := provider.Memorise(&tc.memories[1]) + if err != nil { + t.Fatalf("Failed to memorise: %v", err) + } + if newMem.Mind == tc.memories[0].Mind { + t.Fatalf("Failed to change mind: %v", newMem.Mind) } }) } @@ -95,6 +108,7 @@ func TestChatHistory(t *testing.T) { id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, msgs TEXT NOT NULL, + agent TEXT NOT NULL, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP );`) diff --git a/storage/vector.go b/storage/vector.go new file mode 100644 index 0000000..32b4731 --- /dev/null +++ b/storage/vector.go @@ -0,0 +1,255 @@ +package storage + +import ( + "encoding/binary" + "fmt" + "gf-lt/models" + "unsafe" + + "github.com/jmoiron/sqlx" +) + +type VectorRepo interface { + WriteVector(*models.VectorRow) error + SearchClosest(q []float32) ([]models.VectorRow, error) + ListFiles() ([]string, error) + RemoveEmbByFileName(filename string) error + DB() *sqlx.DB +} + +// SerializeVector converts []float32 to binary blob +func SerializeVector(vec []float32) []byte { + buf := make([]byte, len(vec)*4) // 4 bytes per float32 + for i, v := range vec { + binary.LittleEndian.PutUint32(buf[i*4:], mathFloat32bits(v)) + } + return buf +} + +// DeserializeVector converts binary blob back to []float32 +func DeserializeVector(data []byte) []float32 { + count := len(data) / 4 + vec := make([]float32, count) + for i := 0; i < count; i++ { + vec[i] = mathBitsToFloat32(binary.LittleEndian.Uint32(data[i*4:])) + } + return vec +} + +// mathFloat32bits and mathBitsToFloat32 are helpers to convert between float32 and uint32 +func mathFloat32bits(f float32) uint32 { + return binary.LittleEndian.Uint32((*(*[4]byte)(unsafe.Pointer(&f)))[:4]) +} + +func mathBitsToFloat32(b uint32) float32 { + return *(*float32)(unsafe.Pointer(&b)) +} + +func fetchTableName(emb []float32) (string, error) { + switch len(emb) { + case 384: + return "embeddings_384", nil + case 768: + return "embeddings_768", nil + case 1024: + return "embeddings_1024", nil + case 1536: + return "embeddings_1536", nil + case 2048: + return "embeddings_2048", nil + case 3072: + return "embeddings_3072", nil + case 4096: + return "embeddings_4096", nil + case 5120: + return "embeddings_5120", nil + default: + return "", fmt.Errorf("no table for the size of %d", len(emb)) + } +} + +func (p ProviderSQL) WriteVector(row *models.VectorRow) error { + tableName, err := fetchTableName(row.Embeddings) + if err != nil { + return err + } + + serializedEmbeddings := SerializeVector(row.Embeddings) + + query := fmt.Sprintf("INSERT INTO %s(embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)", tableName) + _, err = p.db.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName) + + return err +} + +func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) { + tableName, err := fetchTableName(q) + if err != nil { + return nil, err + } + + querySQL := "SELECT embeddings, slug, raw_text, filename FROM " + tableName + rows, err := p.db.Query(querySQL) + if err != nil { + return nil, err + } + defer rows.Close() + + type SearchResult struct { + vector models.VectorRow + distance float32 + } + + var topResults []SearchResult + + for rows.Next() { + var ( + embeddingsBlob []byte + slug, rawText, fileName string + ) + + if err := rows.Scan(&embeddingsBlob, &slug, &rawText, &fileName); err != nil { + continue + } + + storedEmbeddings := DeserializeVector(embeddingsBlob) + + // Calculate cosine similarity (returns value between -1 and 1, where 1 is most similar) + similarity := cosineSimilarity(q, storedEmbeddings) + distance := 1 - similarity // Convert to distance where 0 is most similar + + result := SearchResult{ + vector: models.VectorRow{ + Embeddings: storedEmbeddings, + Slug: slug, + RawText: rawText, + FileName: fileName, + }, + distance: distance, + } + + // Add to top results and maintain only top results + topResults = append(topResults, result) + + // Sort and keep only top results + // We'll keep the top 3 closest vectors + if len(topResults) > 3 { + // Simple sort and truncate to maintain only 3 best matches + for i := 0; i < len(topResults); i++ { + for j := i + 1; j < len(topResults); j++ { + if topResults[i].distance > topResults[j].distance { + topResults[i], topResults[j] = topResults[j], topResults[i] + } + } + } + topResults = topResults[:3] + } + } + + // Convert back to VectorRow slice + results := make([]models.VectorRow, len(topResults)) + for i, result := range topResults { + result.vector.Distance = result.distance + results[i] = result.vector + } + + return results, nil +} + +// cosineSimilarity calculates the cosine similarity between two vectors +func cosineSimilarity(a, b []float32) float32 { + if len(a) != len(b) { + return 0.0 + } + + var dotProduct, normA, normB float32 + for i := 0; i < len(a); i++ { + dotProduct += a[i] * b[i] + normA += a[i] * a[i] + normB += b[i] * b[i] + } + + if normA == 0 || normB == 0 { + return 0.0 + } + + return dotProduct / (sqrt(normA) * sqrt(normB)) +} + +// sqrt returns the square root of a float32 +func sqrt(f float32) float32 { + // A simple implementation of square root using Newton's method + if f == 0 { + return 0 + } + guess := f / 2 + for i := 0; i < 10; i++ { // 10 iterations should be enough for good precision + guess = (guess + f/guess) / 2 + } + return guess +} + +func (p ProviderSQL) ListFiles() ([]string, error) { + fileLists := make([][]string, 0) + + // Query all supported tables and combine results + tableNames := []string{ + "embeddings_384", "embeddings_768", "embeddings_1024", "embeddings_1536", + "embeddings_2048", "embeddings_3072", "embeddings_4096", "embeddings_5120", + } + for _, table := range tableNames { + query := "SELECT DISTINCT filename FROM " + table + rows, err := p.db.Query(query) + if err != nil { + // Continue if one table doesn't exist + continue + } + + var files []string + for rows.Next() { + var filename string + if err := rows.Scan(&filename); err != nil { + continue + } + files = append(files, filename) + } + rows.Close() + + fileLists = append(fileLists, files) + } + + // Combine and deduplicate + fileSet := make(map[string]bool) + var allFiles []string + for _, files := range fileLists { + for _, file := range files { + if !fileSet[file] { + fileSet[file] = true + allFiles = append(allFiles, file) + } + } + } + + return allFiles, nil +} + +func (p ProviderSQL) RemoveEmbByFileName(filename string) error { + var errors []string + + tableNames := []string{ + "embeddings_384", "embeddings_768", "embeddings_1024", "embeddings_1536", + "embeddings_2048", "embeddings_3072", "embeddings_4096", "embeddings_5120", + } + for _, table := range tableNames { + query := fmt.Sprintf("DELETE FROM %s WHERE filename = ?", table) + if _, err := p.db.Exec(query, filename); err != nil { + errors = append(errors, err.Error()) + } + } + + if len(errors) > 0 { + return fmt.Errorf("errors occurred: %v", errors) + } + + return nil +} |
