diff options
Diffstat (limited to 'storage/storage_test.go')
-rw-r--r-- | storage/storage_test.go | 119 |
1 files changed, 109 insertions, 10 deletions
diff --git a/storage/storage_test.go b/storage/storage_test.go index ad1f1bf..a1c4cf4 100644 --- a/storage/storage_test.go +++ b/storage/storage_test.go @@ -1,7 +1,7 @@ package storage import ( - "elefant/models" + "gf-lt/models" "fmt" "log/slog" "os" @@ -35,22 +35,27 @@ CREATE TABLE IF NOT EXISTS memories ( logger: slog.New(slog.NewJSONHandler(os.Stdout, nil)), } // Create a sample memory for testing - sampleMemory := &models.Memory{ + sampleMemory := models.Memory{ Agent: "testAgent", Topic: "testTopic", Mind: "testMind", CreatedAt: time.Now(), UpdatedAt: time.Now(), } + sampleMemoryRewrite := models.Memory{ + Agent: "testAgent", + Topic: "testTopic", + Mind: "same topic, new mind", + } cases := []struct { - memory *models.Memory + memories []models.Memory }{ - {memory: sampleMemory}, + {memories: []models.Memory{sampleMemory, sampleMemoryRewrite}}, } for i, tc := range cases { t.Run(fmt.Sprintf("run_%d", i), func(t *testing.T) { // Recall topics: get no rows - topics, err := provider.RecallTopics(tc.memory.Agent) + topics, err := provider.RecallTopics(tc.memories[0].Agent) if err != nil { t.Fatalf("Failed to recall topics: %v", err) } @@ -58,12 +63,12 @@ CREATE TABLE IF NOT EXISTS memories ( t.Fatalf("Expected no topics, got: %v", topics) } // Memorise - _, err = provider.Memorise(tc.memory) + _, err = provider.Memorise(&tc.memories[0]) if err != nil { t.Fatalf("Failed to memorise: %v", err) } // Recall topics: has topics - topics, err = provider.RecallTopics(tc.memory.Agent) + topics, err = provider.RecallTopics(tc.memories[0].Agent) if err != nil { t.Fatalf("Failed to recall topics: %v", err) } @@ -71,12 +76,20 @@ CREATE TABLE IF NOT EXISTS memories ( t.Fatalf("Expected topics, got none") } // Recall - content, err := provider.Recall(tc.memory.Agent, tc.memory.Topic) + content, err := provider.Recall(tc.memories[0].Agent, tc.memories[0].Topic) if err != nil { t.Fatalf("Failed to recall: %v", err) } - if content != tc.memory.Mind { - t.Fatalf("Expected content: %v, got: %v", tc.memory.Mind, content) + if content != tc.memories[0].Mind { + t.Fatalf("Expected content: %v, got: %v", tc.memories[0].Mind, content) + } + // rewrite mind of same agent-topic + newMem, err := provider.Memorise(&tc.memories[1]) + if err != nil { + t.Fatalf("Failed to memorise: %v", err) + } + if newMem.Mind == tc.memories[0].Mind { + t.Fatalf("Failed to change mind: %v", newMem.Mind) } }) } @@ -95,6 +108,7 @@ func TestChatHistory(t *testing.T) { id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, msgs TEXT NOT NULL, + agent TEXT NOT NULL, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP );`) @@ -159,3 +173,88 @@ func TestChatHistory(t *testing.T) { t.Errorf("Expected 0 chats, got %d", len(chats)) } } + +// func TestVecTable(t *testing.T) { +// // healthcheck +// db, err := sqlite3.Open(":memory:") +// if err != nil { +// t.Fatal(err) +// } +// stmt, _, err := db.Prepare(`SELECT sqlite_version(), vec_version()`) +// if err != nil { +// t.Fatal(err) +// } +// stmt.Step() +// log.Printf("sqlite_version=%s, vec_version=%s\n", stmt.ColumnText(0), stmt.ColumnText(1)) +// stmt.Close() +// // migration +// err = db.Exec("CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[4], chat_name TEXT NOT NULL)") +// if err != nil { +// t.Fatal(err) +// } +// // data prep and insert +// items := map[int][]float32{ +// 1: {0.1, 0.1, 0.1, 0.1}, +// 2: {0.2, 0.2, 0.2, 0.2}, +// 3: {0.3, 0.3, 0.3, 0.3}, +// 4: {0.4, 0.4, 0.4, 0.4}, +// 5: {0.5, 0.5, 0.5, 0.5}, +// } +// q := []float32{0.4, 0.3, 0.3, 0.3} +// stmt, _, err = db.Prepare("INSERT INTO vec_items(rowid, embedding, chat_name) VALUES (?, ?, ?)") +// if err != nil { +// t.Fatal(err) +// } +// for id, values := range items { +// v, err := sqlite_vec.SerializeFloat32(values) +// if err != nil { +// t.Fatal(err) +// } +// stmt.BindInt(1, id) +// stmt.BindBlob(2, v) +// stmt.BindText(3, "some_chat") +// err = stmt.Exec() +// if err != nil { +// t.Fatal(err) +// } +// stmt.Reset() +// } +// stmt.Close() +// // select | vec search +// stmt, _, err = db.Prepare(` +// SELECT +// rowid, +// distance, +// embedding +// FROM vec_items +// WHERE embedding MATCH ? +// ORDER BY distance +// LIMIT 3 +// `) +// if err != nil { +// t.Fatal(err) +// } +// query, err := sqlite_vec.SerializeFloat32(q) +// if err != nil { +// t.Fatal(err) +// } +// stmt.BindBlob(1, query) +// for stmt.Step() { +// rowid := stmt.ColumnInt64(0) +// distance := stmt.ColumnFloat(1) +// emb := stmt.ColumnRawText(2) +// floats := decodeUnsafe(emb) +// log.Printf("rowid=%d, distance=%f, floats=%v\n", rowid, distance, floats) +// } +// if err := stmt.Err(); err != nil { +// t.Fatal(err) +// } +// err = stmt.Close() +// if err != nil { +// t.Fatal(err) +// } +// err = db.Close() +// if err != nil { +// t.Fatal(err) +// } +// } |