summaryrefslogtreecommitdiff
path: root/storage
diff options
context:
space:
mode:
Diffstat (limited to 'storage')
-rw-r--r--storage/memory.go14
-rw-r--r--storage/migrate.go5
-rw-r--r--storage/migrations/001_init.up.sql1
-rw-r--r--storage/migrations/002_add_vector.down.sql34
-rw-r--r--storage/migrations/002_add_vector.up.sql98
-rw-r--r--storage/storage.go53
-rw-r--r--storage/storage_test.go34
-rw-r--r--storage/vector.go255
8 files changed, 474 insertions, 20 deletions
diff --git a/storage/memory.go b/storage/memory.go
index a7bf8cc..406182f 100644
--- a/storage/memory.go
+++ b/storage/memory.go
@@ -1,6 +1,6 @@
package storage
-import "elefant/models"
+import "gf-lt/models"
type Memories interface {
Memorise(m *models.Memory) (*models.Memory, error)
@@ -9,15 +9,23 @@ type Memories interface {
}
func (p ProviderSQL) Memorise(m *models.Memory) (*models.Memory, error) {
- query := "INSERT INTO memories (agent, topic, mind) VALUES (:agent, :topic, :mind) RETURNING *;"
+ query := `
+ INSERT INTO memories (agent, topic, mind)
+ VALUES (:agent, :topic, :mind)
+ ON CONFLICT (agent, topic) DO UPDATE
+ SET mind = excluded.mind,
+ updated_at = CURRENT_TIMESTAMP
+ RETURNING *;`
stmt, err := p.db.PrepareNamed(query)
if err != nil {
+ p.logger.Error("failed to prepare stmt", "query", query, "error", err)
return nil, err
}
defer stmt.Close()
var memory models.Memory
err = stmt.Get(&memory, m)
if err != nil {
+ p.logger.Error("failed to upsert memory", "query", query, "error", err)
return nil, err
}
return &memory, nil
@@ -28,6 +36,7 @@ func (p ProviderSQL) Recall(agent, topic string) (string, error) {
var mind string
err := p.db.Get(&mind, query, agent, topic)
if err != nil {
+ p.logger.Error("failed to get memory", "query", query, "error", err)
return "", err
}
return mind, nil
@@ -38,6 +47,7 @@ func (p ProviderSQL) RecallTopics(agent string) ([]string, error) {
var topics []string
err := p.db.Select(&topics, query, agent)
if err != nil {
+ p.logger.Error("failed to get topics", "query", query, "error", err)
return nil, err
}
return topics, nil
diff --git a/storage/migrate.go b/storage/migrate.go
index d97b99d..decfe9c 100644
--- a/storage/migrate.go
+++ b/storage/migrate.go
@@ -27,10 +27,11 @@ func (p *ProviderSQL) Migrate() {
err := p.executeMigration(migrationsDir, file.Name())
if err != nil {
p.logger.Error("Failed to execute migration %s: %v", file.Name(), err)
+ panic(err)
}
}
}
- p.logger.Info("All migrations executed successfully!")
+ p.logger.Debug("All migrations executed successfully!")
}
func (p *ProviderSQL) executeMigration(migrationsDir fs.FS, fileName string) error {
@@ -50,7 +51,7 @@ func (p *ProviderSQL) executeMigration(migrationsDir fs.FS, fileName string) err
}
func (p *ProviderSQL) executeSQL(sqlContent []byte) error {
- // Connect to the database (example using a simple connection)
+ // Execute the migration content using standard database connection
_, err := p.db.Exec(string(sqlContent))
if err != nil {
return fmt.Errorf("failed to execute SQL: %w", err)
diff --git a/storage/migrations/001_init.up.sql b/storage/migrations/001_init.up.sql
index 8980ccf..09bb5e6 100644
--- a/storage/migrations/001_init.up.sql
+++ b/storage/migrations/001_init.up.sql
@@ -2,6 +2,7 @@ CREATE TABLE IF NOT EXISTS chats (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
msgs TEXT NOT NULL,
+ agent TEXT NOT NULL DEFAULT 'assistant',
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
diff --git a/storage/migrations/002_add_vector.down.sql b/storage/migrations/002_add_vector.down.sql
new file mode 100644
index 0000000..a257b11
--- /dev/null
+++ b/storage/migrations/002_add_vector.down.sql
@@ -0,0 +1,34 @@
+-- Drop vector storage tables
+DROP INDEX IF EXISTS idx_embeddings_384_filename;
+DROP INDEX IF EXISTS idx_embeddings_768_filename;
+DROP INDEX IF EXISTS idx_embeddings_1024_filename;
+DROP INDEX IF EXISTS idx_embeddings_1536_filename;
+DROP INDEX IF EXISTS idx_embeddings_2048_filename;
+DROP INDEX IF EXISTS idx_embeddings_3072_filename;
+DROP INDEX IF EXISTS idx_embeddings_4096_filename;
+DROP INDEX IF EXISTS idx_embeddings_5120_filename;
+DROP INDEX IF EXISTS idx_embeddings_384_slug;
+DROP INDEX IF EXISTS idx_embeddings_768_slug;
+DROP INDEX IF EXISTS idx_embeddings_1024_slug;
+DROP INDEX IF EXISTS idx_embeddings_1536_slug;
+DROP INDEX IF EXISTS idx_embeddings_2048_slug;
+DROP INDEX IF EXISTS idx_embeddings_3072_slug;
+DROP INDEX IF EXISTS idx_embeddings_4096_slug;
+DROP INDEX IF EXISTS idx_embeddings_5120_slug;
+DROP INDEX IF EXISTS idx_embeddings_384_created_at;
+DROP INDEX IF EXISTS idx_embeddings_768_created_at;
+DROP INDEX IF EXISTS idx_embeddings_1024_created_at;
+DROP INDEX IF EXISTS idx_embeddings_1536_created_at;
+DROP INDEX IF EXISTS idx_embeddings_2048_created_at;
+DROP INDEX IF EXISTS idx_embeddings_3072_created_at;
+DROP INDEX IF EXISTS idx_embeddings_4096_created_at;
+DROP INDEX IF EXISTS idx_embeddings_5120_created_at;
+
+DROP TABLE IF EXISTS embeddings_384;
+DROP TABLE IF EXISTS embeddings_768;
+DROP TABLE IF EXISTS embeddings_1024;
+DROP TABLE IF EXISTS embeddings_1536;
+DROP TABLE IF EXISTS embeddings_2048;
+DROP TABLE IF EXISTS embeddings_3072;
+DROP TABLE IF EXISTS embeddings_4096;
+DROP TABLE IF EXISTS embeddings_5120; \ No newline at end of file
diff --git a/storage/migrations/002_add_vector.up.sql b/storage/migrations/002_add_vector.up.sql
new file mode 100644
index 0000000..baf703d
--- /dev/null
+++ b/storage/migrations/002_add_vector.up.sql
@@ -0,0 +1,98 @@
+-- Create tables for vector storage (replacing vec0 plugin usage)
+CREATE TABLE IF NOT EXISTS embeddings_384 (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ embeddings BLOB NOT NULL,
+ slug TEXT NOT NULL,
+ raw_text TEXT NOT NULL,
+ filename TEXT NOT NULL DEFAULT '',
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+);
+
+CREATE TABLE IF NOT EXISTS embeddings_768 (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ embeddings BLOB NOT NULL,
+ slug TEXT NOT NULL,
+ raw_text TEXT NOT NULL,
+ filename TEXT NOT NULL DEFAULT '',
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+);
+
+CREATE TABLE IF NOT EXISTS embeddings_1024 (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ embeddings BLOB NOT NULL,
+ slug TEXT NOT NULL,
+ raw_text TEXT NOT NULL,
+ filename TEXT NOT NULL DEFAULT '',
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+);
+
+CREATE TABLE IF NOT EXISTS embeddings_1536 (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ embeddings BLOB NOT NULL,
+ slug TEXT NOT NULL,
+ raw_text TEXT NOT NULL,
+ filename TEXT NOT NULL DEFAULT '',
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+);
+
+CREATE TABLE IF NOT EXISTS embeddings_2048 (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ embeddings BLOB NOT NULL,
+ slug TEXT NOT NULL,
+ raw_text TEXT NOT NULL,
+ filename TEXT NOT NULL DEFAULT '',
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+);
+
+CREATE TABLE IF NOT EXISTS embeddings_3072 (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ embeddings BLOB NOT NULL,
+ slug TEXT NOT NULL,
+ raw_text TEXT NOT NULL,
+ filename TEXT NOT NULL DEFAULT '',
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+);
+
+CREATE TABLE IF NOT EXISTS embeddings_4096 (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ embeddings BLOB NOT NULL,
+ slug TEXT NOT NULL,
+ raw_text TEXT NOT NULL,
+ filename TEXT NOT NULL DEFAULT '',
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+);
+
+CREATE TABLE IF NOT EXISTS embeddings_5120 (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ embeddings BLOB NOT NULL,
+ slug TEXT NOT NULL,
+ raw_text TEXT NOT NULL,
+ filename TEXT NOT NULL DEFAULT '',
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+);
+
+-- Indexes for better performance
+CREATE INDEX IF NOT EXISTS idx_embeddings_384_filename ON embeddings_384(filename);
+CREATE INDEX IF NOT EXISTS idx_embeddings_768_filename ON embeddings_768(filename);
+CREATE INDEX IF NOT EXISTS idx_embeddings_1024_filename ON embeddings_1024(filename);
+CREATE INDEX IF NOT EXISTS idx_embeddings_1536_filename ON embeddings_1536(filename);
+CREATE INDEX IF NOT EXISTS idx_embeddings_2048_filename ON embeddings_2048(filename);
+CREATE INDEX IF NOT EXISTS idx_embeddings_3072_filename ON embeddings_3072(filename);
+CREATE INDEX IF NOT EXISTS idx_embeddings_4096_filename ON embeddings_4096(filename);
+CREATE INDEX IF NOT EXISTS idx_embeddings_5120_filename ON embeddings_5120(filename);
+CREATE INDEX IF NOT EXISTS idx_embeddings_384_slug ON embeddings_384(slug);
+CREATE INDEX IF NOT EXISTS idx_embeddings_768_slug ON embeddings_768(slug);
+CREATE INDEX IF NOT EXISTS idx_embeddings_1024_slug ON embeddings_1024(slug);
+CREATE INDEX IF NOT EXISTS idx_embeddings_1536_slug ON embeddings_1536(slug);
+CREATE INDEX IF NOT EXISTS idx_embeddings_2048_slug ON embeddings_2048(slug);
+CREATE INDEX IF NOT EXISTS idx_embeddings_3072_slug ON embeddings_3072(slug);
+CREATE INDEX IF NOT EXISTS idx_embeddings_4096_slug ON embeddings_4096(slug);
+CREATE INDEX IF NOT EXISTS idx_embeddings_5120_slug ON embeddings_5120(slug);
+CREATE INDEX IF NOT EXISTS idx_embeddings_384_created_at ON embeddings_384(created_at);
+CREATE INDEX IF NOT EXISTS idx_embeddings_768_created_at ON embeddings_768(created_at);
+CREATE INDEX IF NOT EXISTS idx_embeddings_1024_created_at ON embeddings_1024(created_at);
+CREATE INDEX IF NOT EXISTS idx_embeddings_1536_created_at ON embeddings_1536(created_at);
+CREATE INDEX IF NOT EXISTS idx_embeddings_2048_created_at ON embeddings_2048(created_at);
+CREATE INDEX IF NOT EXISTS idx_embeddings_3072_created_at ON embeddings_3072(created_at);
+CREATE INDEX IF NOT EXISTS idx_embeddings_4096_created_at ON embeddings_4096(created_at);
+CREATE INDEX IF NOT EXISTS idx_embeddings_5120_created_at ON embeddings_5120(created_at);
diff --git a/storage/storage.go b/storage/storage.go
index 67b8dd8..a092f8d 100644
--- a/storage/storage.go
+++ b/storage/storage.go
@@ -1,19 +1,28 @@
package storage
import (
- "elefant/models"
+ "gf-lt/models"
"log/slog"
_ "github.com/glebarez/go-sqlite"
"github.com/jmoiron/sqlx"
)
+type FullRepo interface {
+ ChatHistory
+ Memories
+ VectorRepo
+}
+
type ChatHistory interface {
ListChats() ([]models.Chat, error)
GetChatByID(id uint32) (*models.Chat, error)
+ GetChatByChar(char string) ([]models.Chat, error)
GetLastChat() (*models.Chat, error)
+ GetLastChatByAgent(agent string) (*models.Chat, error)
UpsertChat(chat *models.Chat) (*models.Chat, error)
RemoveChat(id uint32) error
+ ChatGetMaxID() (uint32, error)
}
type ProviderSQL struct {
@@ -27,6 +36,12 @@ func (p ProviderSQL) ListChats() ([]models.Chat, error) {
return resp, err
}
+func (p ProviderSQL) GetChatByChar(char string) ([]models.Chat, error) {
+ resp := []models.Chat{}
+ err := p.db.Select(&resp, "SELECT * FROM chats WHERE agent=$1;", char)
+ return resp, err
+}
+
func (p ProviderSQL) GetChatByID(id uint32) (*models.Chat, error) {
resp := models.Chat{}
err := p.db.Get(&resp, "SELECT * FROM chats WHERE id=$1;", id)
@@ -39,16 +54,28 @@ func (p ProviderSQL) GetLastChat() (*models.Chat, error) {
return &resp, err
}
+func (p ProviderSQL) GetLastChatByAgent(agent string) (*models.Chat, error) {
+ resp := models.Chat{}
+ query := "SELECT * FROM chats WHERE agent=$1 ORDER BY updated_at DESC LIMIT 1"
+ err := p.db.Get(&resp, query, agent)
+ return &resp, err
+}
+
+// https://sqlite.org/lang_upsert.html
+// on conflict was added
func (p ProviderSQL) UpsertChat(chat *models.Chat) (*models.Chat, error) {
// Prepare the SQL statement
query := `
- INSERT OR REPLACE INTO chats (id, name, msgs, created_at, updated_at)
- VALUES (:id, :name, :msgs, :created_at, :updated_at)
+ INSERT INTO chats (id, name, msgs, agent, created_at, updated_at)
+ VALUES (:id, :name, :msgs, :agent, :created_at, :updated_at)
+ ON CONFLICT(id) DO UPDATE SET msgs=excluded.msgs,
+ updated_at=excluded.updated_at
RETURNING *;`
stmt, err := p.db.PrepareNamed(query)
if err != nil {
return nil, err
}
+ defer stmt.Close()
// Execute the query and scan the result into a new chat object
var resp models.Chat
err = stmt.Get(&resp, chat)
@@ -61,13 +88,27 @@ func (p ProviderSQL) RemoveChat(id uint32) error {
return err
}
-func NewProviderSQL(dbPath string, logger *slog.Logger) ChatHistory {
+func (p ProviderSQL) ChatGetMaxID() (uint32, error) {
+ query := "SELECT MAX(id) FROM chats;"
+ var id uint32
+ err := p.db.Get(&id, query)
+ return id, err
+}
+
+// opens database connection
+func NewProviderSQL(dbPath string, logger *slog.Logger) FullRepo {
db, err := sqlx.Open("sqlite", dbPath)
if err != nil {
- panic(err)
+ logger.Error("failed to open db connection", "error", err)
+ return nil
}
- // get SQLite version
p := ProviderSQL{db: db, logger: logger}
+
p.Migrate()
return p
}
+
+// DB returns the underlying database connection
+func (p ProviderSQL) DB() *sqlx.DB {
+ return p.db
+}
diff --git a/storage/storage_test.go b/storage/storage_test.go
index ad1f1bf..a4f2bdd 100644
--- a/storage/storage_test.go
+++ b/storage/storage_test.go
@@ -1,8 +1,8 @@
package storage
import (
- "elefant/models"
"fmt"
+ "gf-lt/models"
"log/slog"
"os"
"testing"
@@ -35,22 +35,27 @@ CREATE TABLE IF NOT EXISTS memories (
logger: slog.New(slog.NewJSONHandler(os.Stdout, nil)),
}
// Create a sample memory for testing
- sampleMemory := &models.Memory{
+ sampleMemory := models.Memory{
Agent: "testAgent",
Topic: "testTopic",
Mind: "testMind",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
+ sampleMemoryRewrite := models.Memory{
+ Agent: "testAgent",
+ Topic: "testTopic",
+ Mind: "same topic, new mind",
+ }
cases := []struct {
- memory *models.Memory
+ memories []models.Memory
}{
- {memory: sampleMemory},
+ {memories: []models.Memory{sampleMemory, sampleMemoryRewrite}},
}
for i, tc := range cases {
t.Run(fmt.Sprintf("run_%d", i), func(t *testing.T) {
// Recall topics: get no rows
- topics, err := provider.RecallTopics(tc.memory.Agent)
+ topics, err := provider.RecallTopics(tc.memories[0].Agent)
if err != nil {
t.Fatalf("Failed to recall topics: %v", err)
}
@@ -58,12 +63,12 @@ CREATE TABLE IF NOT EXISTS memories (
t.Fatalf("Expected no topics, got: %v", topics)
}
// Memorise
- _, err = provider.Memorise(tc.memory)
+ _, err = provider.Memorise(&tc.memories[0])
if err != nil {
t.Fatalf("Failed to memorise: %v", err)
}
// Recall topics: has topics
- topics, err = provider.RecallTopics(tc.memory.Agent)
+ topics, err = provider.RecallTopics(tc.memories[0].Agent)
if err != nil {
t.Fatalf("Failed to recall topics: %v", err)
}
@@ -71,12 +76,20 @@ CREATE TABLE IF NOT EXISTS memories (
t.Fatalf("Expected topics, got none")
}
// Recall
- content, err := provider.Recall(tc.memory.Agent, tc.memory.Topic)
+ content, err := provider.Recall(tc.memories[0].Agent, tc.memories[0].Topic)
if err != nil {
t.Fatalf("Failed to recall: %v", err)
}
- if content != tc.memory.Mind {
- t.Fatalf("Expected content: %v, got: %v", tc.memory.Mind, content)
+ if content != tc.memories[0].Mind {
+ t.Fatalf("Expected content: %v, got: %v", tc.memories[0].Mind, content)
+ }
+ // rewrite mind of same agent-topic
+ newMem, err := provider.Memorise(&tc.memories[1])
+ if err != nil {
+ t.Fatalf("Failed to memorise: %v", err)
+ }
+ if newMem.Mind == tc.memories[0].Mind {
+ t.Fatalf("Failed to change mind: %v", newMem.Mind)
}
})
}
@@ -95,6 +108,7 @@ func TestChatHistory(t *testing.T) {
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
msgs TEXT NOT NULL,
+ agent TEXT NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);`)
diff --git a/storage/vector.go b/storage/vector.go
new file mode 100644
index 0000000..32b4731
--- /dev/null
+++ b/storage/vector.go
@@ -0,0 +1,255 @@
+package storage
+
+import (
+ "encoding/binary"
+ "fmt"
+ "gf-lt/models"
+ "unsafe"
+
+ "github.com/jmoiron/sqlx"
+)
+
+type VectorRepo interface {
+ WriteVector(*models.VectorRow) error
+ SearchClosest(q []float32) ([]models.VectorRow, error)
+ ListFiles() ([]string, error)
+ RemoveEmbByFileName(filename string) error
+ DB() *sqlx.DB
+}
+
+// SerializeVector converts []float32 to binary blob
+func SerializeVector(vec []float32) []byte {
+ buf := make([]byte, len(vec)*4) // 4 bytes per float32
+ for i, v := range vec {
+ binary.LittleEndian.PutUint32(buf[i*4:], mathFloat32bits(v))
+ }
+ return buf
+}
+
+// DeserializeVector converts binary blob back to []float32
+func DeserializeVector(data []byte) []float32 {
+ count := len(data) / 4
+ vec := make([]float32, count)
+ for i := 0; i < count; i++ {
+ vec[i] = mathBitsToFloat32(binary.LittleEndian.Uint32(data[i*4:]))
+ }
+ return vec
+}
+
+// mathFloat32bits and mathBitsToFloat32 are helpers to convert between float32 and uint32
+func mathFloat32bits(f float32) uint32 {
+ return binary.LittleEndian.Uint32((*(*[4]byte)(unsafe.Pointer(&f)))[:4])
+}
+
+func mathBitsToFloat32(b uint32) float32 {
+ return *(*float32)(unsafe.Pointer(&b))
+}
+
+func fetchTableName(emb []float32) (string, error) {
+ switch len(emb) {
+ case 384:
+ return "embeddings_384", nil
+ case 768:
+ return "embeddings_768", nil
+ case 1024:
+ return "embeddings_1024", nil
+ case 1536:
+ return "embeddings_1536", nil
+ case 2048:
+ return "embeddings_2048", nil
+ case 3072:
+ return "embeddings_3072", nil
+ case 4096:
+ return "embeddings_4096", nil
+ case 5120:
+ return "embeddings_5120", nil
+ default:
+ return "", fmt.Errorf("no table for the size of %d", len(emb))
+ }
+}
+
+func (p ProviderSQL) WriteVector(row *models.VectorRow) error {
+ tableName, err := fetchTableName(row.Embeddings)
+ if err != nil {
+ return err
+ }
+
+ serializedEmbeddings := SerializeVector(row.Embeddings)
+
+ query := fmt.Sprintf("INSERT INTO %s(embeddings, slug, raw_text, filename) VALUES (?, ?, ?, ?)", tableName)
+ _, err = p.db.Exec(query, serializedEmbeddings, row.Slug, row.RawText, row.FileName)
+
+ return err
+}
+
+func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) {
+ tableName, err := fetchTableName(q)
+ if err != nil {
+ return nil, err
+ }
+
+ querySQL := "SELECT embeddings, slug, raw_text, filename FROM " + tableName
+ rows, err := p.db.Query(querySQL)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ type SearchResult struct {
+ vector models.VectorRow
+ distance float32
+ }
+
+ var topResults []SearchResult
+
+ for rows.Next() {
+ var (
+ embeddingsBlob []byte
+ slug, rawText, fileName string
+ )
+
+ if err := rows.Scan(&embeddingsBlob, &slug, &rawText, &fileName); err != nil {
+ continue
+ }
+
+ storedEmbeddings := DeserializeVector(embeddingsBlob)
+
+ // Calculate cosine similarity (returns value between -1 and 1, where 1 is most similar)
+ similarity := cosineSimilarity(q, storedEmbeddings)
+ distance := 1 - similarity // Convert to distance where 0 is most similar
+
+ result := SearchResult{
+ vector: models.VectorRow{
+ Embeddings: storedEmbeddings,
+ Slug: slug,
+ RawText: rawText,
+ FileName: fileName,
+ },
+ distance: distance,
+ }
+
+ // Add to top results and maintain only top results
+ topResults = append(topResults, result)
+
+ // Sort and keep only top results
+ // We'll keep the top 3 closest vectors
+ if len(topResults) > 3 {
+ // Simple sort and truncate to maintain only 3 best matches
+ for i := 0; i < len(topResults); i++ {
+ for j := i + 1; j < len(topResults); j++ {
+ if topResults[i].distance > topResults[j].distance {
+ topResults[i], topResults[j] = topResults[j], topResults[i]
+ }
+ }
+ }
+ topResults = topResults[:3]
+ }
+ }
+
+ // Convert back to VectorRow slice
+ results := make([]models.VectorRow, len(topResults))
+ for i, result := range topResults {
+ result.vector.Distance = result.distance
+ results[i] = result.vector
+ }
+
+ return results, nil
+}
+
+// cosineSimilarity calculates the cosine similarity between two vectors
+func cosineSimilarity(a, b []float32) float32 {
+ if len(a) != len(b) {
+ return 0.0
+ }
+
+ var dotProduct, normA, normB float32
+ for i := 0; i < len(a); i++ {
+ dotProduct += a[i] * b[i]
+ normA += a[i] * a[i]
+ normB += b[i] * b[i]
+ }
+
+ if normA == 0 || normB == 0 {
+ return 0.0
+ }
+
+ return dotProduct / (sqrt(normA) * sqrt(normB))
+}
+
+// sqrt returns the square root of a float32
+func sqrt(f float32) float32 {
+ // A simple implementation of square root using Newton's method
+ if f == 0 {
+ return 0
+ }
+ guess := f / 2
+ for i := 0; i < 10; i++ { // 10 iterations should be enough for good precision
+ guess = (guess + f/guess) / 2
+ }
+ return guess
+}
+
+func (p ProviderSQL) ListFiles() ([]string, error) {
+ fileLists := make([][]string, 0)
+
+ // Query all supported tables and combine results
+ tableNames := []string{
+ "embeddings_384", "embeddings_768", "embeddings_1024", "embeddings_1536",
+ "embeddings_2048", "embeddings_3072", "embeddings_4096", "embeddings_5120",
+ }
+ for _, table := range tableNames {
+ query := "SELECT DISTINCT filename FROM " + table
+ rows, err := p.db.Query(query)
+ if err != nil {
+ // Continue if one table doesn't exist
+ continue
+ }
+
+ var files []string
+ for rows.Next() {
+ var filename string
+ if err := rows.Scan(&filename); err != nil {
+ continue
+ }
+ files = append(files, filename)
+ }
+ rows.Close()
+
+ fileLists = append(fileLists, files)
+ }
+
+ // Combine and deduplicate
+ fileSet := make(map[string]bool)
+ var allFiles []string
+ for _, files := range fileLists {
+ for _, file := range files {
+ if !fileSet[file] {
+ fileSet[file] = true
+ allFiles = append(allFiles, file)
+ }
+ }
+ }
+
+ return allFiles, nil
+}
+
+func (p ProviderSQL) RemoveEmbByFileName(filename string) error {
+ var errors []string
+
+ tableNames := []string{
+ "embeddings_384", "embeddings_768", "embeddings_1024", "embeddings_1536",
+ "embeddings_2048", "embeddings_3072", "embeddings_4096", "embeddings_5120",
+ }
+ for _, table := range tableNames {
+ query := fmt.Sprintf("DELETE FROM %s WHERE filename = ?", table)
+ if _, err := p.db.Exec(query, filename); err != nil {
+ errors = append(errors, err.Error())
+ }
+ }
+
+ if len(errors) > 0 {
+ return fmt.Errorf("errors occurred: %v", errors)
+ }
+
+ return nil
+}