diff options
Diffstat (limited to 'storage')
-rw-r--r-- | storage/memory.go | 14 | ||||
-rw-r--r-- | storage/migrate.go | 7 | ||||
-rw-r--r-- | storage/migrations/001_init.up.sql | 1 | ||||
-rw-r--r-- | storage/migrations/002_add_vector.up.sql | 12 | ||||
-rw-r--r-- | storage/storage.go | 54 | ||||
-rw-r--r-- | storage/storage_test.go | 119 | ||||
-rw-r--r-- | storage/vector.go | 163 |
7 files changed, 350 insertions, 20 deletions
diff --git a/storage/memory.go b/storage/memory.go index a7bf8cc..406182f 100644 --- a/storage/memory.go +++ b/storage/memory.go @@ -1,6 +1,6 @@ package storage -import "elefant/models" +import "gf-lt/models" type Memories interface { Memorise(m *models.Memory) (*models.Memory, error) @@ -9,15 +9,23 @@ type Memories interface { } func (p ProviderSQL) Memorise(m *models.Memory) (*models.Memory, error) { - query := "INSERT INTO memories (agent, topic, mind) VALUES (:agent, :topic, :mind) RETURNING *;" + query := ` + INSERT INTO memories (agent, topic, mind) + VALUES (:agent, :topic, :mind) + ON CONFLICT (agent, topic) DO UPDATE + SET mind = excluded.mind, + updated_at = CURRENT_TIMESTAMP + RETURNING *;` stmt, err := p.db.PrepareNamed(query) if err != nil { + p.logger.Error("failed to prepare stmt", "query", query, "error", err) return nil, err } defer stmt.Close() var memory models.Memory err = stmt.Get(&memory, m) if err != nil { + p.logger.Error("failed to upsert memory", "query", query, "error", err) return nil, err } return &memory, nil @@ -28,6 +36,7 @@ func (p ProviderSQL) Recall(agent, topic string) (string, error) { var mind string err := p.db.Get(&mind, query, agent, topic) if err != nil { + p.logger.Error("failed to get memory", "query", query, "error", err) return "", err } return mind, nil @@ -38,6 +47,7 @@ func (p ProviderSQL) RecallTopics(agent string) ([]string, error) { var topics []string err := p.db.Select(&topics, query, agent) if err != nil { + p.logger.Error("failed to get topics", "query", query, "error", err) return nil, err } return topics, nil diff --git a/storage/migrate.go b/storage/migrate.go index d97b99d..b05dddc 100644 --- a/storage/migrate.go +++ b/storage/migrate.go @@ -5,6 +5,8 @@ import ( "fmt" "io/fs" "strings" + + _ "github.com/asg017/sqlite-vec-go-bindings/ncruces" ) //go:embed migrations/* @@ -27,10 +29,11 @@ func (p *ProviderSQL) Migrate() { err := p.executeMigration(migrationsDir, file.Name()) if err != nil { p.logger.Error("Failed to execute migration %s: %v", file.Name(), err) + panic(err) } } } - p.logger.Info("All migrations executed successfully!") + p.logger.Debug("All migrations executed successfully!") } func (p *ProviderSQL) executeMigration(migrationsDir fs.FS, fileName string) error { @@ -51,7 +54,7 @@ func (p *ProviderSQL) executeMigration(migrationsDir fs.FS, fileName string) err func (p *ProviderSQL) executeSQL(sqlContent []byte) error { // Connect to the database (example using a simple connection) - _, err := p.db.Exec(string(sqlContent)) + err := p.s3Conn.Exec(string(sqlContent)) if err != nil { return fmt.Errorf("failed to execute SQL: %w", err) } diff --git a/storage/migrations/001_init.up.sql b/storage/migrations/001_init.up.sql index 8980ccf..09bb5e6 100644 --- a/storage/migrations/001_init.up.sql +++ b/storage/migrations/001_init.up.sql @@ -2,6 +2,7 @@ CREATE TABLE IF NOT EXISTS chats ( id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, msgs TEXT NOT NULL, + agent TEXT NOT NULL DEFAULT 'assistant', created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ); diff --git a/storage/migrations/002_add_vector.up.sql b/storage/migrations/002_add_vector.up.sql new file mode 100644 index 0000000..2ac4621 --- /dev/null +++ b/storage/migrations/002_add_vector.up.sql @@ -0,0 +1,12 @@ +--CREATE VIRTUAL TABLE IF NOT EXISTS embeddings_5120 USING vec0( +-- embedding FLOAT[5120], +-- slug TEXT NOT NULL, +-- raw_text TEXT PRIMARY KEY, +--); + +CREATE VIRTUAL TABLE IF NOT EXISTS embeddings_384 USING vec0( + embedding FLOAT[384], + slug TEXT NOT NULL, + raw_text TEXT PRIMARY KEY, + filename TEXT NOT NULL DEFAULT '' +); diff --git a/storage/storage.go b/storage/storage.go index 67b8dd8..7911e13 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -1,23 +1,34 @@ package storage import ( - "elefant/models" + "gf-lt/models" "log/slog" _ "github.com/glebarez/go-sqlite" "github.com/jmoiron/sqlx" + "github.com/ncruces/go-sqlite3" ) +type FullRepo interface { + ChatHistory + Memories + VectorRepo +} + type ChatHistory interface { ListChats() ([]models.Chat, error) GetChatByID(id uint32) (*models.Chat, error) + GetChatByChar(char string) ([]models.Chat, error) GetLastChat() (*models.Chat, error) + GetLastChatByAgent(agent string) (*models.Chat, error) UpsertChat(chat *models.Chat) (*models.Chat, error) RemoveChat(id uint32) error + ChatGetMaxID() (uint32, error) } type ProviderSQL struct { db *sqlx.DB + s3Conn *sqlite3.Conn logger *slog.Logger } @@ -27,6 +38,12 @@ func (p ProviderSQL) ListChats() ([]models.Chat, error) { return resp, err } +func (p ProviderSQL) GetChatByChar(char string) ([]models.Chat, error) { + resp := []models.Chat{} + err := p.db.Select(&resp, "SELECT * FROM chats WHERE agent=$1;", char) + return resp, err +} + func (p ProviderSQL) GetChatByID(id uint32) (*models.Chat, error) { resp := models.Chat{} err := p.db.Get(&resp, "SELECT * FROM chats WHERE id=$1;", id) @@ -39,16 +56,28 @@ func (p ProviderSQL) GetLastChat() (*models.Chat, error) { return &resp, err } +func (p ProviderSQL) GetLastChatByAgent(agent string) (*models.Chat, error) { + resp := models.Chat{} + query := "SELECT * FROM chats WHERE agent=$1 ORDER BY updated_at DESC LIMIT 1" + err := p.db.Get(&resp, query, agent) + return &resp, err +} + +// https://sqlite.org/lang_upsert.html +// on conflict was added func (p ProviderSQL) UpsertChat(chat *models.Chat) (*models.Chat, error) { // Prepare the SQL statement query := ` - INSERT OR REPLACE INTO chats (id, name, msgs, created_at, updated_at) - VALUES (:id, :name, :msgs, :created_at, :updated_at) + INSERT INTO chats (id, name, msgs, agent, created_at, updated_at) + VALUES (:id, :name, :msgs, :agent, :created_at, :updated_at) + ON CONFLICT(id) DO UPDATE SET msgs=excluded.msgs, + updated_at=excluded.updated_at RETURNING *;` stmt, err := p.db.PrepareNamed(query) if err != nil { return nil, err } + defer stmt.Close() // Execute the query and scan the result into a new chat object var resp models.Chat err = stmt.Get(&resp, chat) @@ -61,13 +90,26 @@ func (p ProviderSQL) RemoveChat(id uint32) error { return err } -func NewProviderSQL(dbPath string, logger *slog.Logger) ChatHistory { +func (p ProviderSQL) ChatGetMaxID() (uint32, error) { + query := "SELECT MAX(id) FROM chats;" + var id uint32 + err := p.db.Get(&id, query) + return id, err +} + +// opens two connections +func NewProviderSQL(dbPath string, logger *slog.Logger) FullRepo { db, err := sqlx.Open("sqlite", dbPath) if err != nil { - panic(err) + logger.Error("failed to open db connection", "error", err) + return nil } - // get SQLite version p := ProviderSQL{db: db, logger: logger} + p.s3Conn, err = sqlite3.Open(dbPath) + if err != nil { + logger.Error("failed to open vecdb connection", "error", err) + return nil + } p.Migrate() return p } diff --git a/storage/storage_test.go b/storage/storage_test.go index ad1f1bf..a1c4cf4 100644 --- a/storage/storage_test.go +++ b/storage/storage_test.go @@ -1,7 +1,7 @@ package storage import ( - "elefant/models" + "gf-lt/models" "fmt" "log/slog" "os" @@ -35,22 +35,27 @@ CREATE TABLE IF NOT EXISTS memories ( logger: slog.New(slog.NewJSONHandler(os.Stdout, nil)), } // Create a sample memory for testing - sampleMemory := &models.Memory{ + sampleMemory := models.Memory{ Agent: "testAgent", Topic: "testTopic", Mind: "testMind", CreatedAt: time.Now(), UpdatedAt: time.Now(), } + sampleMemoryRewrite := models.Memory{ + Agent: "testAgent", + Topic: "testTopic", + Mind: "same topic, new mind", + } cases := []struct { - memory *models.Memory + memories []models.Memory }{ - {memory: sampleMemory}, + {memories: []models.Memory{sampleMemory, sampleMemoryRewrite}}, } for i, tc := range cases { t.Run(fmt.Sprintf("run_%d", i), func(t *testing.T) { // Recall topics: get no rows - topics, err := provider.RecallTopics(tc.memory.Agent) + topics, err := provider.RecallTopics(tc.memories[0].Agent) if err != nil { t.Fatalf("Failed to recall topics: %v", err) } @@ -58,12 +63,12 @@ CREATE TABLE IF NOT EXISTS memories ( t.Fatalf("Expected no topics, got: %v", topics) } // Memorise - _, err = provider.Memorise(tc.memory) + _, err = provider.Memorise(&tc.memories[0]) if err != nil { t.Fatalf("Failed to memorise: %v", err) } // Recall topics: has topics - topics, err = provider.RecallTopics(tc.memory.Agent) + topics, err = provider.RecallTopics(tc.memories[0].Agent) if err != nil { t.Fatalf("Failed to recall topics: %v", err) } @@ -71,12 +76,20 @@ CREATE TABLE IF NOT EXISTS memories ( t.Fatalf("Expected topics, got none") } // Recall - content, err := provider.Recall(tc.memory.Agent, tc.memory.Topic) + content, err := provider.Recall(tc.memories[0].Agent, tc.memories[0].Topic) if err != nil { t.Fatalf("Failed to recall: %v", err) } - if content != tc.memory.Mind { - t.Fatalf("Expected content: %v, got: %v", tc.memory.Mind, content) + if content != tc.memories[0].Mind { + t.Fatalf("Expected content: %v, got: %v", tc.memories[0].Mind, content) + } + // rewrite mind of same agent-topic + newMem, err := provider.Memorise(&tc.memories[1]) + if err != nil { + t.Fatalf("Failed to memorise: %v", err) + } + if newMem.Mind == tc.memories[0].Mind { + t.Fatalf("Failed to change mind: %v", newMem.Mind) } }) } @@ -95,6 +108,7 @@ func TestChatHistory(t *testing.T) { id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, msgs TEXT NOT NULL, + agent TEXT NOT NULL, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP );`) @@ -159,3 +173,88 @@ func TestChatHistory(t *testing.T) { t.Errorf("Expected 0 chats, got %d", len(chats)) } } + +// func TestVecTable(t *testing.T) { +// // healthcheck +// db, err := sqlite3.Open(":memory:") +// if err != nil { +// t.Fatal(err) +// } +// stmt, _, err := db.Prepare(`SELECT sqlite_version(), vec_version()`) +// if err != nil { +// t.Fatal(err) +// } +// stmt.Step() +// log.Printf("sqlite_version=%s, vec_version=%s\n", stmt.ColumnText(0), stmt.ColumnText(1)) +// stmt.Close() +// // migration +// err = db.Exec("CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[4], chat_name TEXT NOT NULL)") +// if err != nil { +// t.Fatal(err) +// } +// // data prep and insert +// items := map[int][]float32{ +// 1: {0.1, 0.1, 0.1, 0.1}, +// 2: {0.2, 0.2, 0.2, 0.2}, +// 3: {0.3, 0.3, 0.3, 0.3}, +// 4: {0.4, 0.4, 0.4, 0.4}, +// 5: {0.5, 0.5, 0.5, 0.5}, +// } +// q := []float32{0.4, 0.3, 0.3, 0.3} +// stmt, _, err = db.Prepare("INSERT INTO vec_items(rowid, embedding, chat_name) VALUES (?, ?, ?)") +// if err != nil { +// t.Fatal(err) +// } +// for id, values := range items { +// v, err := sqlite_vec.SerializeFloat32(values) +// if err != nil { +// t.Fatal(err) +// } +// stmt.BindInt(1, id) +// stmt.BindBlob(2, v) +// stmt.BindText(3, "some_chat") +// err = stmt.Exec() +// if err != nil { +// t.Fatal(err) +// } +// stmt.Reset() +// } +// stmt.Close() +// // select | vec search +// stmt, _, err = db.Prepare(` +// SELECT +// rowid, +// distance, +// embedding +// FROM vec_items +// WHERE embedding MATCH ? +// ORDER BY distance +// LIMIT 3 +// `) +// if err != nil { +// t.Fatal(err) +// } +// query, err := sqlite_vec.SerializeFloat32(q) +// if err != nil { +// t.Fatal(err) +// } +// stmt.BindBlob(1, query) +// for stmt.Step() { +// rowid := stmt.ColumnInt64(0) +// distance := stmt.ColumnFloat(1) +// emb := stmt.ColumnRawText(2) +// floats := decodeUnsafe(emb) +// log.Printf("rowid=%d, distance=%f, floats=%v\n", rowid, distance, floats) +// } +// if err := stmt.Err(); err != nil { +// t.Fatal(err) +// } +// err = stmt.Close() +// if err != nil { +// t.Fatal(err) +// } +// err = db.Close() +// if err != nil { +// t.Fatal(err) +// } +// } diff --git a/storage/vector.go b/storage/vector.go new file mode 100644 index 0000000..71005e4 --- /dev/null +++ b/storage/vector.go @@ -0,0 +1,163 @@ +package storage + +import ( + "gf-lt/models" + "errors" + "fmt" + "unsafe" + + sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/ncruces" +) + +type VectorRepo interface { + WriteVector(*models.VectorRow) error + SearchClosest(q []float32) ([]models.VectorRow, error) + ListFiles() ([]string, error) + RemoveEmbByFileName(filename string) error +} + +var ( + vecTableName5120 = "embeddings_5120" + vecTableName384 = "embeddings_384" +) + +func fetchTableName(emb []float32) (string, error) { + switch len(emb) { + case 5120: + return vecTableName5120, nil + case 384: + return vecTableName384, nil + default: + return "", fmt.Errorf("no table for the size of %d", len(emb)) + } +} + +func (p ProviderSQL) WriteVector(row *models.VectorRow) error { + tableName, err := fetchTableName(row.Embeddings) + if err != nil { + return err + } + stmt, _, err := p.s3Conn.Prepare( + fmt.Sprintf("INSERT INTO %s(embedding, slug, raw_text, filename) VALUES (?, ?, ?, ?)", tableName)) + if err != nil { + p.logger.Error("failed to prep a stmt", "error", err) + return err + } + defer stmt.Close() + v, err := sqlite_vec.SerializeFloat32(row.Embeddings) + if err != nil { + p.logger.Error("failed to serialize vector", + "emb-len", len(row.Embeddings), "error", err) + return err + } + if v == nil { + err = errors.New("empty vector after serialization") + p.logger.Error("empty vector after serialization", + "emb-len", len(row.Embeddings), "text", row.RawText, "error", err) + return err + } + if err := stmt.BindBlob(1, v); err != nil { + p.logger.Error("failed to bind", "error", err) + return err + } + if err := stmt.BindText(2, row.Slug); err != nil { + p.logger.Error("failed to bind", "error", err) + return err + } + if err := stmt.BindText(3, row.RawText); err != nil { + p.logger.Error("failed to bind", "error", err) + return err + } + if err := stmt.BindText(4, row.FileName); err != nil { + p.logger.Error("failed to bind", "error", err) + return err + } + err = stmt.Exec() + if err != nil { + return err + } + return nil +} + +func decodeUnsafe(bs []byte) []float32 { + return unsafe.Slice((*float32)(unsafe.Pointer(&bs[0])), len(bs)/4) +} + +func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) { + tableName, err := fetchTableName(q) + if err != nil { + return nil, err + } + stmt, _, err := p.s3Conn.Prepare( + fmt.Sprintf(`SELECT + distance, + embedding, + slug, + raw_text, + filename + FROM %s + WHERE embedding MATCH ? + ORDER BY distance + LIMIT 3 + `, tableName)) + if err != nil { + return nil, err + } + query, err := sqlite_vec.SerializeFloat32(q[:]) + if err != nil { + return nil, err + } + if err := stmt.BindBlob(1, query); err != nil { + p.logger.Error("failed to bind", "error", err) + return nil, err + } + resp := []models.VectorRow{} + for stmt.Step() { + res := models.VectorRow{} + res.Distance = float32(stmt.ColumnFloat(0)) + emb := stmt.ColumnRawText(1) + res.Embeddings = decodeUnsafe(emb) + res.Slug = stmt.ColumnText(2) + res.RawText = stmt.ColumnText(3) + res.FileName = stmt.ColumnText(4) + resp = append(resp, res) + } + if err := stmt.Err(); err != nil { + return nil, err + } + err = stmt.Close() + if err != nil { + return nil, err + } + return resp, nil +} + +func (p ProviderSQL) ListFiles() ([]string, error) { + q := fmt.Sprintf("SELECT filename FROM %s GROUP BY filename", vecTableName384) + stmt, _, err := p.s3Conn.Prepare(q) + if err != nil { + return nil, err + } + defer stmt.Close() + resp := []string{} + for stmt.Step() { + resp = append(resp, stmt.ColumnText(0)) + } + if err := stmt.Err(); err != nil { + return nil, err + } + return resp, nil +} + +func (p ProviderSQL) RemoveEmbByFileName(filename string) error { + q := fmt.Sprintf("DELETE FROM %s WHERE filename = ?", vecTableName384) + stmt, _, err := p.s3Conn.Prepare(q) + if err != nil { + return err + } + defer stmt.Close() + if err := stmt.BindText(1, filename); err != nil { + return err + } + return stmt.Exec() +} |