From 5ccad20bd680dc443b30f0decc8fca13427dc70d Mon Sep 17 00:00:00 2001 From: Grail Finder Date: Wed, 20 Nov 2024 20:47:49 +0300 Subject: Feat: add memory [wip] --- storage/storage_test.go | 77 +++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 75 insertions(+), 2 deletions(-) (limited to 'storage/storage_test.go') diff --git a/storage/storage_test.go b/storage/storage_test.go index 0bf1fd6..ad1f1bf 100644 --- a/storage/storage_test.go +++ b/storage/storage_test.go @@ -2,6 +2,9 @@ package storage import ( "elefant/models" + "fmt" + "log/slog" + "os" "testing" "time" @@ -9,6 +12,76 @@ import ( "github.com/jmoiron/sqlx" ) +func TestMemories(t *testing.T) { + db, err := sqlx.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("Failed to open SQLite in-memory database: %v", err) + } + defer db.Close() + _, err = db.Exec(` +CREATE TABLE IF NOT EXISTS memories ( + agent TEXT NOT NULL, + topic TEXT NOT NULL, + mind TEXT NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (agent, topic) +);`) + if err != nil { + t.Fatalf("Failed to create chat table: %v", err) + } + provider := ProviderSQL{ + db: db, + logger: slog.New(slog.NewJSONHandler(os.Stdout, nil)), + } + // Create a sample memory for testing + sampleMemory := &models.Memory{ + Agent: "testAgent", + Topic: "testTopic", + Mind: "testMind", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + cases := []struct { + memory *models.Memory + }{ + {memory: sampleMemory}, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("run_%d", i), func(t *testing.T) { + // Recall topics: get no rows + topics, err := provider.RecallTopics(tc.memory.Agent) + if err != nil { + t.Fatalf("Failed to recall topics: %v", err) + } + if len(topics) != 0 { + t.Fatalf("Expected no topics, got: %v", topics) + } + // Memorise + _, err = provider.Memorise(tc.memory) + if err != nil { + t.Fatalf("Failed to memorise: %v", err) + } + // Recall topics: has topics + topics, err = provider.RecallTopics(tc.memory.Agent) + if err != nil { + t.Fatalf("Failed to recall topics: %v", err) + } + if len(topics) == 0 { + t.Fatalf("Expected topics, got none") + } + // Recall + content, err := provider.Recall(tc.memory.Agent, tc.memory.Topic) + if err != nil { + t.Fatalf("Failed to recall: %v", err) + } + if content != tc.memory.Mind { + t.Fatalf("Expected content: %v, got: %v", tc.memory.Mind, content) + } + }) + } +} + func TestChatHistory(t *testing.T) { // Create an in-memory SQLite database db, err := sqlx.Open("sqlite", ":memory:") @@ -18,10 +91,10 @@ func TestChatHistory(t *testing.T) { defer db.Close() // Create the chat table _, err = db.Exec(` - CREATE TABLE chat ( + CREATE TABLE chats ( id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, - msgs TEXT NOT NULL, -- Store messages as a comma-separated string + msgs TEXT NOT NULL, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP );`) -- cgit v1.2.3