diff options
Diffstat (limited to 'storage')
-rw-r--r-- | storage/memory.go | 2 | ||||
-rw-r--r-- | storage/storage.go | 2 | ||||
-rw-r--r-- | storage/storage_test.go | 173 | ||||
-rw-r--r-- | storage/vector.go | 2 |
4 files changed, 88 insertions, 91 deletions
diff --git a/storage/memory.go b/storage/memory.go index c9fc853..406182f 100644 --- a/storage/memory.go +++ b/storage/memory.go @@ -1,6 +1,6 @@ package storage -import "elefant/models" +import "gf-lt/models" type Memories interface { Memorise(m *models.Memory) (*models.Memory, error) diff --git a/storage/storage.go b/storage/storage.go index f759700..7911e13 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -1,7 +1,7 @@ package storage import ( - "elefant/models" + "gf-lt/models" "log/slog" _ "github.com/glebarez/go-sqlite" diff --git a/storage/storage_test.go b/storage/storage_test.go index ff3b5e6..a1c4cf4 100644 --- a/storage/storage_test.go +++ b/storage/storage_test.go @@ -1,18 +1,15 @@ package storage import ( - "elefant/models" + "gf-lt/models" "fmt" - "log" "log/slog" "os" "testing" "time" - sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/ncruces" _ "github.com/glebarez/go-sqlite" "github.com/jmoiron/sqlx" - "github.com/ncruces/go-sqlite3" ) func TestMemories(t *testing.T) { @@ -177,87 +174,87 @@ func TestChatHistory(t *testing.T) { } } -func TestVecTable(t *testing.T) { - // healthcheck - db, err := sqlite3.Open(":memory:") - if err != nil { - t.Fatal(err) - } - stmt, _, err := db.Prepare(`SELECT sqlite_version(), vec_version()`) - if err != nil { - t.Fatal(err) - } - stmt.Step() - log.Printf("sqlite_version=%s, vec_version=%s\n", stmt.ColumnText(0), stmt.ColumnText(1)) - stmt.Close() - // migration - err = db.Exec("CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[4], chat_name TEXT NOT NULL)") - if err != nil { - t.Fatal(err) - } - // data prep and insert - items := map[int][]float32{ - 1: {0.1, 0.1, 0.1, 0.1}, - 2: {0.2, 0.2, 0.2, 0.2}, - 3: {0.3, 0.3, 0.3, 0.3}, - 4: {0.4, 0.4, 0.4, 0.4}, - 5: {0.5, 0.5, 0.5, 0.5}, - } - q := []float32{0.28, 0.3, 0.3, 0.3} - stmt, _, err = db.Prepare("INSERT INTO vec_items(rowid, embedding, chat_name) VALUES (?, ?, ?)") - if err != nil { - t.Fatal(err) - } - for id, values := range items { - v, err := sqlite_vec.SerializeFloat32(values) - if err != nil { - t.Fatal(err) - } - stmt.BindInt(1, id) - stmt.BindBlob(2, v) - stmt.BindText(3, "some_chat") - err = stmt.Exec() - if err != nil { - t.Fatal(err) - } - stmt.Reset() - } - stmt.Close() - // select | vec search - stmt, _, err = db.Prepare(` - SELECT - rowid, - distance, - embedding - FROM vec_items - WHERE embedding MATCH ? - ORDER BY distance - LIMIT 3 - `) - if err != nil { - t.Fatal(err) - } - query, err := sqlite_vec.SerializeFloat32(q) - if err != nil { - t.Fatal(err) - } - stmt.BindBlob(1, query) - for stmt.Step() { - rowid := stmt.ColumnInt64(0) - distance := stmt.ColumnFloat(1) - emb := stmt.ColumnRawText(2) - floats := decodeUnsafe(emb) - log.Printf("rowid=%d, distance=%f, floats=%v\n", rowid, distance, floats) - } - if err := stmt.Err(); err != nil { - t.Fatal(err) - } - err = stmt.Close() - if err != nil { - t.Fatal(err) - } - err = db.Close() - if err != nil { - t.Fatal(err) - } -} +// func TestVecTable(t *testing.T) { +// // healthcheck +// db, err := sqlite3.Open(":memory:") +// if err != nil { +// t.Fatal(err) +// } +// stmt, _, err := db.Prepare(`SELECT sqlite_version(), vec_version()`) +// if err != nil { +// t.Fatal(err) +// } +// stmt.Step() +// log.Printf("sqlite_version=%s, vec_version=%s\n", stmt.ColumnText(0), stmt.ColumnText(1)) +// stmt.Close() +// // migration +// err = db.Exec("CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[4], chat_name TEXT NOT NULL)") +// if err != nil { +// t.Fatal(err) +// } +// // data prep and insert +// items := map[int][]float32{ +// 1: {0.1, 0.1, 0.1, 0.1}, +// 2: {0.2, 0.2, 0.2, 0.2}, +// 3: {0.3, 0.3, 0.3, 0.3}, +// 4: {0.4, 0.4, 0.4, 0.4}, +// 5: {0.5, 0.5, 0.5, 0.5}, +// } +// q := []float32{0.4, 0.3, 0.3, 0.3} +// stmt, _, err = db.Prepare("INSERT INTO vec_items(rowid, embedding, chat_name) VALUES (?, ?, ?)") +// if err != nil { +// t.Fatal(err) +// } +// for id, values := range items { +// v, err := sqlite_vec.SerializeFloat32(values) +// if err != nil { +// t.Fatal(err) +// } +// stmt.BindInt(1, id) +// stmt.BindBlob(2, v) +// stmt.BindText(3, "some_chat") +// err = stmt.Exec() +// if err != nil { +// t.Fatal(err) +// } +// stmt.Reset() +// } +// stmt.Close() +// // select | vec search +// stmt, _, err = db.Prepare(` +// SELECT +// rowid, +// distance, +// embedding +// FROM vec_items +// WHERE embedding MATCH ? +// ORDER BY distance +// LIMIT 3 +// `) +// if err != nil { +// t.Fatal(err) +// } +// query, err := sqlite_vec.SerializeFloat32(q) +// if err != nil { +// t.Fatal(err) +// } +// stmt.BindBlob(1, query) +// for stmt.Step() { +// rowid := stmt.ColumnInt64(0) +// distance := stmt.ColumnFloat(1) +// emb := stmt.ColumnRawText(2) +// floats := decodeUnsafe(emb) +// log.Printf("rowid=%d, distance=%f, floats=%v\n", rowid, distance, floats) +// } +// if err := stmt.Err(); err != nil { +// t.Fatal(err) +// } +// err = stmt.Close() +// if err != nil { +// t.Fatal(err) +// } +// err = db.Close() +// if err != nil { +// t.Fatal(err) +// } +// } diff --git a/storage/vector.go b/storage/vector.go index 5e9069c..71005e4 100644 --- a/storage/vector.go +++ b/storage/vector.go @@ -1,7 +1,7 @@ package storage import ( - "elefant/models" + "gf-lt/models" "errors" "fmt" "unsafe" |