summaryrefslogtreecommitdiff
path: root/storage
diff options
context:
space:
mode:
Diffstat (limited to 'storage')
-rw-r--r--storage/memory.go14
-rw-r--r--storage/migrate.go7
-rw-r--r--storage/migrations/001_init.up.sql1
-rw-r--r--storage/migrations/002_add_vector.up.sql12
-rw-r--r--storage/storage.go54
-rw-r--r--storage/storage_test.go119
-rw-r--r--storage/vector.go163
7 files changed, 350 insertions, 20 deletions
diff --git a/storage/memory.go b/storage/memory.go
index a7bf8cc..406182f 100644
--- a/storage/memory.go
+++ b/storage/memory.go
@@ -1,6 +1,6 @@
package storage
-import "elefant/models"
+import "gf-lt/models"
type Memories interface {
Memorise(m *models.Memory) (*models.Memory, error)
@@ -9,15 +9,23 @@ type Memories interface {
}
func (p ProviderSQL) Memorise(m *models.Memory) (*models.Memory, error) {
- query := "INSERT INTO memories (agent, topic, mind) VALUES (:agent, :topic, :mind) RETURNING *;"
+ query := `
+ INSERT INTO memories (agent, topic, mind)
+ VALUES (:agent, :topic, :mind)
+ ON CONFLICT (agent, topic) DO UPDATE
+ SET mind = excluded.mind,
+ updated_at = CURRENT_TIMESTAMP
+ RETURNING *;`
stmt, err := p.db.PrepareNamed(query)
if err != nil {
+ p.logger.Error("failed to prepare stmt", "query", query, "error", err)
return nil, err
}
defer stmt.Close()
var memory models.Memory
err = stmt.Get(&memory, m)
if err != nil {
+ p.logger.Error("failed to upsert memory", "query", query, "error", err)
return nil, err
}
return &memory, nil
@@ -28,6 +36,7 @@ func (p ProviderSQL) Recall(agent, topic string) (string, error) {
var mind string
err := p.db.Get(&mind, query, agent, topic)
if err != nil {
+ p.logger.Error("failed to get memory", "query", query, "error", err)
return "", err
}
return mind, nil
@@ -38,6 +47,7 @@ func (p ProviderSQL) RecallTopics(agent string) ([]string, error) {
var topics []string
err := p.db.Select(&topics, query, agent)
if err != nil {
+ p.logger.Error("failed to get topics", "query", query, "error", err)
return nil, err
}
return topics, nil
diff --git a/storage/migrate.go b/storage/migrate.go
index d97b99d..b05dddc 100644
--- a/storage/migrate.go
+++ b/storage/migrate.go
@@ -5,6 +5,8 @@ import (
"fmt"
"io/fs"
"strings"
+
+ _ "github.com/asg017/sqlite-vec-go-bindings/ncruces"
)
//go:embed migrations/*
@@ -27,10 +29,11 @@ func (p *ProviderSQL) Migrate() {
err := p.executeMigration(migrationsDir, file.Name())
if err != nil {
p.logger.Error("Failed to execute migration %s: %v", file.Name(), err)
+ panic(err)
}
}
}
- p.logger.Info("All migrations executed successfully!")
+ p.logger.Debug("All migrations executed successfully!")
}
func (p *ProviderSQL) executeMigration(migrationsDir fs.FS, fileName string) error {
@@ -51,7 +54,7 @@ func (p *ProviderSQL) executeMigration(migrationsDir fs.FS, fileName string) err
func (p *ProviderSQL) executeSQL(sqlContent []byte) error {
// Connect to the database (example using a simple connection)
- _, err := p.db.Exec(string(sqlContent))
+ err := p.s3Conn.Exec(string(sqlContent))
if err != nil {
return fmt.Errorf("failed to execute SQL: %w", err)
}
diff --git a/storage/migrations/001_init.up.sql b/storage/migrations/001_init.up.sql
index 8980ccf..09bb5e6 100644
--- a/storage/migrations/001_init.up.sql
+++ b/storage/migrations/001_init.up.sql
@@ -2,6 +2,7 @@ CREATE TABLE IF NOT EXISTS chats (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
msgs TEXT NOT NULL,
+ agent TEXT NOT NULL DEFAULT 'assistant',
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
diff --git a/storage/migrations/002_add_vector.up.sql b/storage/migrations/002_add_vector.up.sql
new file mode 100644
index 0000000..2ac4621
--- /dev/null
+++ b/storage/migrations/002_add_vector.up.sql
@@ -0,0 +1,12 @@
+--CREATE VIRTUAL TABLE IF NOT EXISTS embeddings_5120 USING vec0(
+-- embedding FLOAT[5120],
+-- slug TEXT NOT NULL,
+-- raw_text TEXT PRIMARY KEY,
+--);
+
+-- Vector table for 384-dimensional embeddings, backed by the vec0 module.
+-- Each row stores the embedding together with its slug, the raw text
+-- (declared as the primary key) and the filename it came from ('' by default).
+CREATE VIRTUAL TABLE IF NOT EXISTS embeddings_384 USING vec0(
+ embedding FLOAT[384],
+ slug TEXT NOT NULL,
+ raw_text TEXT PRIMARY KEY,
+ filename TEXT NOT NULL DEFAULT ''
+);
diff --git a/storage/storage.go b/storage/storage.go
index 67b8dd8..7911e13 100644
--- a/storage/storage.go
+++ b/storage/storage.go
@@ -1,23 +1,34 @@
package storage
import (
- "elefant/models"
+ "gf-lt/models"
"log/slog"
_ "github.com/glebarez/go-sqlite"
"github.com/jmoiron/sqlx"
+ "github.com/ncruces/go-sqlite3"
)
+// FullRepo is the union of every storage capability in this package:
+// chat history, agent memories and vector (embedding) storage.
+type FullRepo interface {
+ ChatHistory
+ Memories
+ VectorRepo
+}
+
+// ChatHistory describes persistence operations on the chats table.
type ChatHistory interface {
ListChats() ([]models.Chat, error)
+ // GetChatByID returns the chat with the given id.
GetChatByID(id uint32) (*models.Chat, error)
+ // GetChatByChar returns all chats whose agent column equals char.
+ GetChatByChar(char string) ([]models.Chat, error)
GetLastChat() (*models.Chat, error)
+ // GetLastChatByAgent returns the most recently updated chat of the agent.
+ GetLastChatByAgent(agent string) (*models.Chat, error)
+ // UpsertChat inserts the chat or updates the existing row with the same id.
UpsertChat(chat *models.Chat) (*models.Chat, error)
RemoveChat(id uint32) error
+ // ChatGetMaxID returns the result of SELECT MAX(id) over the chats table.
+ ChatGetMaxID() (uint32, error)
}
+// ProviderSQL implements the storage interfaces on top of SQLite.
+// It holds two handles: an sqlx handle and a raw sqlite3 connection.
type ProviderSQL struct {
+ // db serves the chat and memory queries.
db *sqlx.DB
+ // s3Conn serves the vector-table statements and executes migrations.
+ s3Conn *sqlite3.Conn
logger *slog.Logger
}
@@ -27,6 +38,12 @@ func (p ProviderSQL) ListChats() ([]models.Chat, error) {
return resp, err
}
+func (p ProviderSQL) GetChatByChar(char string) ([]models.Chat, error) {
+ resp := []models.Chat{}
+ err := p.db.Select(&resp, "SELECT * FROM chats WHERE agent=$1;", char)
+ return resp, err
+}
+
func (p ProviderSQL) GetChatByID(id uint32) (*models.Chat, error) {
resp := models.Chat{}
err := p.db.Get(&resp, "SELECT * FROM chats WHERE id=$1;", id)
@@ -39,16 +56,28 @@ func (p ProviderSQL) GetLastChat() (*models.Chat, error) {
return &resp, err
}
+func (p ProviderSQL) GetLastChatByAgent(agent string) (*models.Chat, error) {
+ resp := models.Chat{}
+ query := "SELECT * FROM chats WHERE agent=$1 ORDER BY updated_at DESC LIMIT 1"
+ err := p.db.Get(&resp, query, agent)
+ return &resp, err
+}
+
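+// UpsertChat inserts a chat or, when a row with the same id already exists,
+// updates its msgs and updated_at via SQLite's ON CONFLICT clause.
+// See https://sqlite.org/lang_upsert.html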
func (p ProviderSQL) UpsertChat(chat *models.Chat) (*models.Chat, error) {
// Prepare the SQL statement
query := `
- INSERT OR REPLACE INTO chats (id, name, msgs, created_at, updated_at)
- VALUES (:id, :name, :msgs, :created_at, :updated_at)
+ INSERT INTO chats (id, name, msgs, agent, created_at, updated_at)
+ VALUES (:id, :name, :msgs, :agent, :created_at, :updated_at)
+ ON CONFLICT(id) DO UPDATE SET msgs=excluded.msgs,
+ updated_at=excluded.updated_at
RETURNING *;`
stmt, err := p.db.PrepareNamed(query)
if err != nil {
return nil, err
}
+ defer stmt.Close()
// Execute the query and scan the result into a new chat object
var resp models.Chat
err = stmt.Get(&resp, chat)
@@ -61,13 +90,26 @@ func (p ProviderSQL) RemoveChat(id uint32) error {
return err
}
-func NewProviderSQL(dbPath string, logger *slog.Logger) ChatHistory {
+// ChatGetMaxID returns the largest id present in the chats table,
+// or 0 when the table holds no rows.
+func (p ProviderSQL) ChatGetMaxID() (uint32, error) {
+	// MAX() over zero rows yields NULL, which cannot be scanned into a uint32;
+	// COALESCE turns the empty-table case into 0 instead of a scan error.
+	query := "SELECT COALESCE(MAX(id), 0) FROM chats;"
+	var id uint32
+	if err := p.db.Get(&id, query); err != nil {
+		return 0, err
+	}
+	return id, nil
+}
+
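+// NewProviderSQL opens two connections to dbPath (an sqlx handle and a raw
+// sqlite3 connection) and runs the migrations; it returns nil when either
+// connection cannot be opened.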
+func NewProviderSQL(dbPath string, logger *slog.Logger) FullRepo {
db, err := sqlx.Open("sqlite", dbPath)
if err != nil {
- panic(err)
+ logger.Error("failed to open db connection", "error", err)
+ return nil
}
- // get SQLite version
p := ProviderSQL{db: db, logger: logger}
+ p.s3Conn, err = sqlite3.Open(dbPath)
+ if err != nil {
+ logger.Error("failed to open vecdb connection", "error", err)
+ return nil
+ }
p.Migrate()
return p
}
diff --git a/storage/storage_test.go b/storage/storage_test.go
index ad1f1bf..a1c4cf4 100644
--- a/storage/storage_test.go
+++ b/storage/storage_test.go
@@ -1,7 +1,7 @@
package storage
import (
- "elefant/models"
+ "gf-lt/models"
"fmt"
"log/slog"
"os"
@@ -35,22 +35,27 @@ CREATE TABLE IF NOT EXISTS memories (
logger: slog.New(slog.NewJSONHandler(os.Stdout, nil)),
}
// Create a sample memory for testing
- sampleMemory := &models.Memory{
+ sampleMemory := models.Memory{
Agent: "testAgent",
Topic: "testTopic",
Mind: "testMind",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
+ sampleMemoryRewrite := models.Memory{
+ Agent: "testAgent",
+ Topic: "testTopic",
+ Mind: "same topic, new mind",
+ }
cases := []struct {
- memory *models.Memory
+ memories []models.Memory
}{
- {memory: sampleMemory},
+ {memories: []models.Memory{sampleMemory, sampleMemoryRewrite}},
}
for i, tc := range cases {
t.Run(fmt.Sprintf("run_%d", i), func(t *testing.T) {
// Recall topics: get no rows
- topics, err := provider.RecallTopics(tc.memory.Agent)
+ topics, err := provider.RecallTopics(tc.memories[0].Agent)
if err != nil {
t.Fatalf("Failed to recall topics: %v", err)
}
@@ -58,12 +63,12 @@ CREATE TABLE IF NOT EXISTS memories (
t.Fatalf("Expected no topics, got: %v", topics)
}
// Memorise
- _, err = provider.Memorise(tc.memory)
+ _, err = provider.Memorise(&tc.memories[0])
if err != nil {
t.Fatalf("Failed to memorise: %v", err)
}
// Recall topics: has topics
- topics, err = provider.RecallTopics(tc.memory.Agent)
+ topics, err = provider.RecallTopics(tc.memories[0].Agent)
if err != nil {
t.Fatalf("Failed to recall topics: %v", err)
}
@@ -71,12 +76,20 @@ CREATE TABLE IF NOT EXISTS memories (
t.Fatalf("Expected topics, got none")
}
// Recall
- content, err := provider.Recall(tc.memory.Agent, tc.memory.Topic)
+ content, err := provider.Recall(tc.memories[0].Agent, tc.memories[0].Topic)
if err != nil {
t.Fatalf("Failed to recall: %v", err)
}
- if content != tc.memory.Mind {
- t.Fatalf("Expected content: %v, got: %v", tc.memory.Mind, content)
+ if content != tc.memories[0].Mind {
+ t.Fatalf("Expected content: %v, got: %v", tc.memories[0].Mind, content)
+ }
+ // rewrite mind of same agent-topic
+ newMem, err := provider.Memorise(&tc.memories[1])
+ if err != nil {
+ t.Fatalf("Failed to memorise: %v", err)
+ }
+ if newMem.Mind == tc.memories[0].Mind {
+ t.Fatalf("Failed to change mind: %v", newMem.Mind)
}
})
}
@@ -95,6 +108,7 @@ func TestChatHistory(t *testing.T) {
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
msgs TEXT NOT NULL,
+ agent TEXT NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);`)
@@ -159,3 +173,88 @@ func TestChatHistory(t *testing.T) {
t.Errorf("Expected 0 chats, got %d", len(chats))
}
}
+
+// func TestVecTable(t *testing.T) {
+// // healthcheck
+// db, err := sqlite3.Open(":memory:")
+// if err != nil {
+// t.Fatal(err)
+// }
+// stmt, _, err := db.Prepare(`SELECT sqlite_version(), vec_version()`)
+// if err != nil {
+// t.Fatal(err)
+// }
+// stmt.Step()
+// log.Printf("sqlite_version=%s, vec_version=%s\n", stmt.ColumnText(0), stmt.ColumnText(1))
+// stmt.Close()
+// // migration
+// err = db.Exec("CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[4], chat_name TEXT NOT NULL)")
+// if err != nil {
+// t.Fatal(err)
+// }
+// // data prep and insert
+// items := map[int][]float32{
+// 1: {0.1, 0.1, 0.1, 0.1},
+// 2: {0.2, 0.2, 0.2, 0.2},
+// 3: {0.3, 0.3, 0.3, 0.3},
+// 4: {0.4, 0.4, 0.4, 0.4},
+// 5: {0.5, 0.5, 0.5, 0.5},
+// }
+// q := []float32{0.4, 0.3, 0.3, 0.3}
+// stmt, _, err = db.Prepare("INSERT INTO vec_items(rowid, embedding, chat_name) VALUES (?, ?, ?)")
+// if err != nil {
+// t.Fatal(err)
+// }
+// for id, values := range items {
+// v, err := sqlite_vec.SerializeFloat32(values)
+// if err != nil {
+// t.Fatal(err)
+// }
+// stmt.BindInt(1, id)
+// stmt.BindBlob(2, v)
+// stmt.BindText(3, "some_chat")
+// err = stmt.Exec()
+// if err != nil {
+// t.Fatal(err)
+// }
+// stmt.Reset()
+// }
+// stmt.Close()
+// // select | vec search
+// stmt, _, err = db.Prepare(`
+// SELECT
+// rowid,
+// distance,
+// embedding
+// FROM vec_items
+// WHERE embedding MATCH ?
+// ORDER BY distance
+// LIMIT 3
+// `)
+// if err != nil {
+// t.Fatal(err)
+// }
+// query, err := sqlite_vec.SerializeFloat32(q)
+// if err != nil {
+// t.Fatal(err)
+// }
+// stmt.BindBlob(1, query)
+// for stmt.Step() {
+// rowid := stmt.ColumnInt64(0)
+// distance := stmt.ColumnFloat(1)
+// emb := stmt.ColumnRawText(2)
+// floats := decodeUnsafe(emb)
+// log.Printf("rowid=%d, distance=%f, floats=%v\n", rowid, distance, floats)
+// }
+// if err := stmt.Err(); err != nil {
+// t.Fatal(err)
+// }
+// err = stmt.Close()
+// if err != nil {
+// t.Fatal(err)
+// }
+// err = db.Close()
+// if err != nil {
+// t.Fatal(err)
+// }
+// }
diff --git a/storage/vector.go b/storage/vector.go
new file mode 100644
index 0000000..71005e4
--- /dev/null
+++ b/storage/vector.go
@@ -0,0 +1,163 @@
+package storage
+
+import (
+ "gf-lt/models"
+ "errors"
+ "fmt"
+ "unsafe"
+
+ sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/ncruces"
+)
+
+// VectorRepo describes storage and similarity search over text embeddings.
+type VectorRepo interface {
+ // WriteVector stores one row in the table matching the embedding length.
+ WriteVector(*models.VectorRow) error
+ // SearchClosest returns the rows closest to q, ordered by distance.
+ SearchClosest(q []float32) ([]models.VectorRow, error)
+ // ListFiles returns the distinct filenames that have stored embeddings.
+ ListFiles() ([]string, error)
+ // RemoveEmbByFileName deletes the embeddings stored for filename.
+ RemoveEmbByFileName(filename string) error
+}
+
+// Names of the vector tables, one per supported embedding length.
+// NOTE(review): in migration 002 only embeddings_384 is created; the
+// embeddings_5120 statement is commented out there.
+var (
+ vecTableName5120 = "embeddings_5120"
+ vecTableName384 = "embeddings_384"
+)
+
+// fetchTableName maps an embedding to the table that stores vectors of its
+// length. Only lengths 5120 and 384 are known; any other length is an error.
+func fetchTableName(emb []float32) (string, error) {
+	dim := len(emb)
+	if dim == 5120 {
+		return vecTableName5120, nil
+	}
+	if dim == 384 {
+		return vecTableName384, nil
+	}
+	return "", fmt.Errorf("no table for the size of %d", dim)
+}
+
+// WriteVector inserts row into the vector table that matches the length of
+// row.Embeddings. The embedding is serialized to a blob and bound together
+// with the slug, raw text and filename. Any failure is returned to the
+// caller; prepare, serialize and bind failures are also logged.
+func (p ProviderSQL) WriteVector(row *models.VectorRow) error {
+	tableName, err := fetchTableName(row.Embeddings)
+	if err != nil {
+		return err
+	}
+	insertQ := fmt.Sprintf(
+		"INSERT INTO %s(embedding, slug, raw_text, filename) VALUES (?, ?, ?, ?)", tableName)
+	stmt, _, err := p.s3Conn.Prepare(insertQ)
+	if err != nil {
+		p.logger.Error("failed to prep a stmt", "error", err)
+		return err
+	}
+	defer stmt.Close()
+	blob, err := sqlite_vec.SerializeFloat32(row.Embeddings)
+	if err != nil {
+		p.logger.Error("failed to serialize vector",
+			"emb-len", len(row.Embeddings), "error", err)
+		return err
+	}
+	if blob == nil {
+		err = errors.New("empty vector after serialization")
+		p.logger.Error("empty vector after serialization",
+			"emb-len", len(row.Embeddings), "text", row.RawText, "error", err)
+		return err
+	}
+	if err := stmt.BindBlob(1, blob); err != nil {
+		p.logger.Error("failed to bind", "error", err)
+		return err
+	}
+	// The text columns follow the blob, in placeholder order 2..4.
+	textArgs := [...]string{row.Slug, row.RawText, row.FileName}
+	for i, arg := range textArgs {
+		if err := stmt.BindText(i+2, arg); err != nil {
+			p.logger.Error("failed to bind", "error", err)
+			return err
+		}
+	}
+	return stmt.Exec()
+}
+
+// decodeUnsafe converts a blob of packed float32 values (host byte order)
+// into a freshly allocated []float32. Trailing bytes that do not form a whole
+// float32 are ignored; input shorter than four bytes yields nil.
+//
+// The result never aliases bs, so it stays valid after the caller's buffer is
+// reused (e.g. memory owned by the SQLite statement).
+func decodeUnsafe(bs []byte) []float32 {
+	n := len(bs) / 4
+	if n == 0 {
+		return nil
+	}
+	out := make([]float32, n)
+	// View the float32 storage as bytes and copy into it: casting the
+	// destination (not the source) avoids misaligned *float32 pointers.
+	dst := unsafe.Slice((*byte)(unsafe.Pointer(&out[0])), n*4)
+	copy(dst, bs)
+	return out
+}
+
+// SearchClosest returns up to three rows whose embeddings are closest to q,
+// ordered by ascending distance. The table searched is chosen by len(q); an
+// unsupported length is reported as an error. The returned slice is non-nil
+// (possibly empty) on success.
+func (p ProviderSQL) SearchClosest(q []float32) ([]models.VectorRow, error) {
+	tableName, err := fetchTableName(q)
+	if err != nil {
+		return nil, err
+	}
+	stmt, _, err := p.s3Conn.Prepare(
+		fmt.Sprintf(`SELECT
+			distance,
+			embedding,
+			slug,
+			raw_text,
+			filename
+		FROM %s
+		WHERE embedding MATCH ?
+		ORDER BY distance
+		LIMIT 3
+	`, tableName))
+	if err != nil {
+		return nil, err
+	}
+	// Release the statement on every exit path, including the error returns below.
+	defer stmt.Close()
+	query, err := sqlite_vec.SerializeFloat32(q)
+	if err != nil {
+		return nil, err
+	}
+	if err := stmt.BindBlob(1, query); err != nil {
+		p.logger.Error("failed to bind", "error", err)
+		return nil, err
+	}
+	resp := []models.VectorRow{}
+	for stmt.Step() {
+		res := models.VectorRow{}
+		res.Distance = float32(stmt.ColumnFloat(0))
+		// The raw column bytes are owned by SQLite and only valid until the
+		// next statement call, so copy them before decoding and storing.
+		emb := append([]byte(nil), stmt.ColumnRawText(1)...)
+		// Guard against an empty column: there is nothing to decode.
+		if len(emb) >= 4 {
+			res.Embeddings = decodeUnsafe(emb)
+		}
+		res.Slug = stmt.ColumnText(2)
+		res.RawText = stmt.ColumnText(3)
+		res.FileName = stmt.ColumnText(4)
+		resp = append(resp, res)
+	}
+	if err := stmt.Err(); err != nil {
+		return nil, err
+	}
+	return resp, nil
+}
+
+// ListFiles returns the distinct filenames stored in the 384-dimension
+// embeddings table. The result is an empty, non-nil slice when the table
+// holds no rows.
+func (p ProviderSQL) ListFiles() ([]string, error) {
+	stmt, _, err := p.s3Conn.Prepare(
+		fmt.Sprintf("SELECT filename FROM %s GROUP BY filename", vecTableName384))
+	if err != nil {
+		return nil, err
+	}
+	defer stmt.Close()
+	names := make([]string, 0)
+	for stmt.Step() {
+		name := stmt.ColumnText(0)
+		names = append(names, name)
+	}
+	if stepErr := stmt.Err(); stepErr != nil {
+		return nil, stepErr
+	}
+	return names, nil
+}
+
+// RemoveEmbByFileName deletes every row of the 384-dimension embeddings
+// table whose filename column equals filename.
+func (p ProviderSQL) RemoveEmbByFileName(filename string) error {
+	stmt, _, err := p.s3Conn.Prepare(
+		fmt.Sprintf("DELETE FROM %s WHERE filename = ?", vecTableName384))
+	if err != nil {
+		return err
+	}
+	defer stmt.Close()
+	err = stmt.BindText(1, filename)
+	if err != nil {
+		return err
+	}
+	return stmt.Exec()
+}