From 461d19aa2512fea7ac07e50c3178609850ef07c3 Mon Sep 17 00:00:00 2001 From: Grail Finder Date: Tue, 31 Dec 2024 13:25:13 +0300 Subject: Feat: add rag [wip; skip-ci] --- storage/storage_test.go | 88 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) (limited to 'storage/storage_test.go') diff --git a/storage/storage_test.go b/storage/storage_test.go index 8373ab0..f6af4f5 100644 --- a/storage/storage_test.go +++ b/storage/storage_test.go @@ -3,13 +3,16 @@ package storage import ( "elefant/models" "fmt" + "log" "log/slog" "os" "testing" "time" + sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/ncruces" _ "github.com/glebarez/go-sqlite" "github.com/jmoiron/sqlx" + "github.com/ncruces/go-sqlite3" ) func TestMemories(t *testing.T) { @@ -160,3 +163,88 @@ func TestChatHistory(t *testing.T) { t.Errorf("Expected 0 chats, got %d", len(chats)) } } + +func TestVecTable(t *testing.T) { + // healthcheck + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + stmt, _, err := db.Prepare(`SELECT sqlite_version(), vec_version()`) + if err != nil { + t.Fatal(err) + } + stmt.Step() + log.Printf("sqlite_version=%s, vec_version=%s\n", stmt.ColumnText(0), stmt.ColumnText(1)) + stmt.Close() + // migration + err = db.Exec("CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[4], chat_name TEXT NOT NULL)") + if err != nil { + t.Fatal(err) + } + // data prep and insert + items := map[int][]float32{ + 1: {0.1, 0.1, 0.1, 0.1}, + 2: {0.2, 0.2, 0.2, 0.2}, + 3: {0.3, 0.3, 0.3, 0.3}, + 4: {0.4, 0.4, 0.4, 0.4}, + 5: {0.5, 0.5, 0.5, 0.5}, + } + q := []float32{0.28, 0.3, 0.3, 0.3} + stmt, _, err = db.Prepare("INSERT INTO vec_items(rowid, embedding, chat_name) VALUES (?, ?, ?)") + if err != nil { + t.Fatal(err) + } + for id, values := range items { + v, err := sqlite_vec.SerializeFloat32(values) + if err != nil { + t.Fatal(err) + } + stmt.BindInt(1, id) + stmt.BindBlob(2, v) + stmt.BindText(3, "some_chat") + err = stmt.Exec() + if err != nil { + t.Fatal(err) + } + stmt.Reset() + } + stmt.Close() + // select | vec search + stmt, _, err = db.Prepare(` + SELECT + rowid, + distance, + embedding + FROM vec_items + WHERE embedding MATCH ? + ORDER BY distance + LIMIT 3 + `) + if err != nil { + t.Fatal(err) + } + query, err := sqlite_vec.SerializeFloat32(q) + if err != nil { + t.Fatal(err) + } + stmt.BindBlob(1, query) + for stmt.Step() { + rowid := stmt.ColumnInt64(0) + distance := stmt.ColumnFloat(1) + emb := stmt.ColumnRawText(2) + floats := decodeUnsafe(emb) + log.Printf("rowid=%d, distance=%f, floats=%v\n", rowid, distance, floats) + } + if err := stmt.Err(); err != nil { + t.Fatal(err) + } + err = stmt.Close() + if err != nil { + t.Fatal(err) + } + err = db.Close() + if err != nil { + t.Fatal(err) + } +} -- cgit v1.2.3