From 5ccad20bd680dc443b30f0decc8fca13427dc70d Mon Sep 17 00:00:00 2001 From: Grail Finder Date: Wed, 20 Nov 2024 20:47:49 +0300 Subject: Feat: add memory [wip] --- storage/memory.go | 44 ++++++++++++++++++++++ storage/migrations/001_init.up.sql | 13 ++++++- storage/storage.go | 10 ++--- storage/storage_test.go | 77 +++++++++++++++++++++++++++++++++++++- 4 files changed, 135 insertions(+), 9 deletions(-) create mode 100644 storage/memory.go (limited to 'storage') diff --git a/storage/memory.go b/storage/memory.go new file mode 100644 index 0000000..a7bf8cc --- /dev/null +++ b/storage/memory.go @@ -0,0 +1,44 @@ +package storage + +import "elefant/models" + +type Memories interface { + Memorise(m *models.Memory) (*models.Memory, error) + Recall(agent, topic string) (string, error) + RecallTopics(agent string) ([]string, error) +} + +func (p ProviderSQL) Memorise(m *models.Memory) (*models.Memory, error) { + query := "INSERT INTO memories (agent, topic, mind) VALUES (:agent, :topic, :mind) RETURNING *;" + stmt, err := p.db.PrepareNamed(query) + if err != nil { + return nil, err + } + defer stmt.Close() + var memory models.Memory + err = stmt.Get(&memory, m) + if err != nil { + return nil, err + } + return &memory, nil +} + +func (p ProviderSQL) Recall(agent, topic string) (string, error) { + query := "SELECT mind FROM memories WHERE agent = $1 AND topic = $2" + var mind string + err := p.db.Get(&mind, query, agent, topic) + if err != nil { + return "", err + } + return mind, nil +} + +func (p ProviderSQL) RecallTopics(agent string) ([]string, error) { + query := "SELECT DISTINCT topic FROM memories WHERE agent = $1" + var topics []string + err := p.db.Select(&topics, query, agent) + if err != nil { + return nil, err + } + return topics, nil +} diff --git a/storage/migrations/001_init.up.sql b/storage/migrations/001_init.up.sql index 1b3e63d..8980ccf 100644 --- a/storage/migrations/001_init.up.sql +++ b/storage/migrations/001_init.up.sql @@ -1,7 +1,16 @@ -CREATE TABLE IF NOT EXISTS chat ( +CREATE TABLE IF NOT EXISTS chats ( id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, - msgs TEXT NOT NULL, -- Store messages as a comma-separated string + msgs TEXT NOT NULL, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ); + +CREATE TABLE IF NOT EXISTS memories ( + agent TEXT NOT NULL, + topic TEXT NOT NULL, + mind TEXT NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (agent, topic) +); diff --git a/storage/storage.go b/storage/storage.go index edbd393..67b8dd8 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -23,26 +23,26 @@ type ProviderSQL struct { func (p ProviderSQL) ListChats() ([]models.Chat, error) { resp := []models.Chat{} - err := p.db.Select(&resp, "SELECT * FROM chat;") + err := p.db.Select(&resp, "SELECT * FROM chats;") return resp, err } func (p ProviderSQL) GetChatByID(id uint32) (*models.Chat, error) { resp := models.Chat{} - err := p.db.Get(&resp, "SELECT * FROM chat WHERE id=$1;", id) + err := p.db.Get(&resp, "SELECT * FROM chats WHERE id=$1;", id) return &resp, err } func (p ProviderSQL) GetLastChat() (*models.Chat, error) { resp := models.Chat{} - err := p.db.Get(&resp, "SELECT * FROM chat ORDER BY updated_at DESC LIMIT 1") + err := p.db.Get(&resp, "SELECT * FROM chats ORDER BY updated_at DESC LIMIT 1") return &resp, err } func (p ProviderSQL) UpsertChat(chat *models.Chat) (*models.Chat, error) { // Prepare the SQL statement query := ` - INSERT OR REPLACE INTO chat (id, name, msgs, created_at, updated_at) + INSERT OR REPLACE INTO chats (id, name, msgs, created_at, updated_at) VALUES (:id, :name, :msgs, :created_at, :updated_at) RETURNING *;` stmt, err := p.db.PrepareNamed(query) @@ -56,7 +56,7 @@ func (p ProviderSQL) UpsertChat(chat *models.Chat) (*models.Chat, error) { } func (p ProviderSQL) RemoveChat(id uint32) error { - query := "DELETE FROM chat WHERE ID = $1;" + query := "DELETE FROM chats WHERE ID = $1;" _, err := p.db.Exec(query, id) return err } diff --git a/storage/storage_test.go b/storage/storage_test.go index 0bf1fd6..ad1f1bf 100644 --- a/storage/storage_test.go +++ b/storage/storage_test.go @@ -2,6 +2,9 @@ package storage import ( "elefant/models" + "fmt" + "log/slog" + "os" "testing" "time" @@ -9,6 +12,76 @@ import ( "github.com/jmoiron/sqlx" ) +func TestMemories(t *testing.T) { + db, err := sqlx.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("Failed to open SQLite in-memory database: %v", err) + } + defer db.Close() + _, err = db.Exec(` +CREATE TABLE IF NOT EXISTS memories ( + agent TEXT NOT NULL, + topic TEXT NOT NULL, + mind TEXT NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (agent, topic) +);`) + if err != nil { + t.Fatalf("Failed to create chat table: %v", err) + } + provider := ProviderSQL{ + db: db, + logger: slog.New(slog.NewJSONHandler(os.Stdout, nil)), + } + // Create a sample memory for testing + sampleMemory := &models.Memory{ + Agent: "testAgent", + Topic: "testTopic", + Mind: "testMind", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + cases := []struct { + memory *models.Memory + }{ + {memory: sampleMemory}, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("run_%d", i), func(t *testing.T) { + // Recall topics: get no rows + topics, err := provider.RecallTopics(tc.memory.Agent) + if err != nil { + t.Fatalf("Failed to recall topics: %v", err) + } + if len(topics) != 0 { + t.Fatalf("Expected no topics, got: %v", topics) + } + // Memorise + _, err = provider.Memorise(tc.memory) + if err != nil { + t.Fatalf("Failed to memorise: %v", err) + } + // Recall topics: has topics + topics, err = provider.RecallTopics(tc.memory.Agent) + if err != nil { + t.Fatalf("Failed to recall topics: %v", err) + } + if len(topics) == 0 { + t.Fatalf("Expected topics, got none") + } + // Recall + content, err := provider.Recall(tc.memory.Agent, tc.memory.Topic) + if err != nil { + t.Fatalf("Failed to recall: %v", err) + } + if content != tc.memory.Mind { + t.Fatalf("Expected content: %v, got: %v", tc.memory.Mind, content) + } + }) + } +} + func TestChatHistory(t *testing.T) { // Create an in-memory SQLite database db, err := sqlx.Open("sqlite", ":memory:") @@ -18,10 +91,10 @@ func TestChatHistory(t *testing.T) { defer db.Close() // Create the chat table _, err = db.Exec(` - CREATE TABLE chat ( + CREATE TABLE chats ( id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, - msgs TEXT NOT NULL, -- Store messages as a comma-separated string + msgs TEXT NOT NULL, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP );`) -- cgit v1.2.3