diff options
Diffstat (limited to 'storage')
-rw-r--r-- | storage/memory.go | 10 | ||||
-rw-r--r-- | storage/storage_test.go | 31 |
2 files changed, 30 insertions, 11 deletions
diff --git a/storage/memory.go b/storage/memory.go index 088ce1c..c9fc853 100644 --- a/storage/memory.go +++ b/storage/memory.go @@ -9,7 +9,13 @@ type Memories interface { } func (p ProviderSQL) Memorise(m *models.Memory) (*models.Memory, error) { - query := "INSERT INTO memories (agent, topic, mind) VALUES (:agent, :topic, :mind) RETURNING *;" + query := ` + INSERT INTO memories (agent, topic, mind) + VALUES (:agent, :topic, :mind) + ON CONFLICT (agent, topic) DO UPDATE + SET mind = excluded.mind, + updated_at = CURRENT_TIMESTAMP + RETURNING *;` stmt, err := p.db.PrepareNamed(query) if err != nil { p.logger.Error("failed to prepare stmt", "query", query, "error", err) @@ -19,7 +25,7 @@ func (p ProviderSQL) Memorise(m *models.Memory) (*models.Memory, error) { var memory models.Memory err = stmt.Get(&memory, m) if err != nil { - p.logger.Error("failed to insert memory", "query", query, "error", err) + p.logger.Error("failed to upsert memory", "query", query, "error", err) return nil, err } return &memory, nil diff --git a/storage/storage_test.go b/storage/storage_test.go index f6af4f5..ff3b5e6 100644 --- a/storage/storage_test.go +++ b/storage/storage_test.go @@ -38,22 +38,27 @@ CREATE TABLE IF NOT EXISTS memories ( logger: slog.New(slog.NewJSONHandler(os.Stdout, nil)), } // Create a sample memory for testing - sampleMemory := &models.Memory{ + sampleMemory := models.Memory{ Agent: "testAgent", Topic: "testTopic", Mind: "testMind", CreatedAt: time.Now(), UpdatedAt: time.Now(), } + sampleMemoryRewrite := models.Memory{ + Agent: "testAgent", + Topic: "testTopic", + Mind: "same topic, new mind", + } cases := []struct { - memory *models.Memory + memories []models.Memory }{ - {memory: sampleMemory}, + {memories: []models.Memory{sampleMemory, sampleMemoryRewrite}}, } for i, tc := range cases { t.Run(fmt.Sprintf("run_%d", i), func(t *testing.T) { // Recall topics: get no rows - topics, err := provider.RecallTopics(tc.memory.Agent) + topics, err := provider.RecallTopics(tc.memories[0].Agent) if err != nil { t.Fatalf("Failed to recall topics: %v", err) } @@ -61,12 +66,12 @@ CREATE TABLE IF NOT EXISTS memories ( t.Fatalf("Expected no topics, got: %v", topics) } // Memorise - _, err = provider.Memorise(tc.memory) + _, err = provider.Memorise(&tc.memories[0]) if err != nil { t.Fatalf("Failed to memorise: %v", err) } // Recall topics: has topics - topics, err = provider.RecallTopics(tc.memory.Agent) + topics, err = provider.RecallTopics(tc.memories[0].Agent) if err != nil { t.Fatalf("Failed to recall topics: %v", err) } @@ -74,12 +79,20 @@ CREATE TABLE IF NOT EXISTS memories ( t.Fatalf("Expected topics, got none") } // Recall - content, err := provider.Recall(tc.memory.Agent, tc.memory.Topic) + content, err := provider.Recall(tc.memories[0].Agent, tc.memories[0].Topic) if err != nil { t.Fatalf("Failed to recall: %v", err) } - if content != tc.memory.Mind { - t.Fatalf("Expected content: %v, got: %v", tc.memory.Mind, content) + if content != tc.memories[0].Mind { + t.Fatalf("Expected content: %v, got: %v", tc.memories[0].Mind, content) + } + // rewrite mind of same agent-topic + newMem, err := provider.Memorise(&tc.memories[1]) + if err != nil { + t.Fatalf("Failed to memorise: %v", err) + } + if newMem.Mind == tc.memories[0].Mind { + t.Fatalf("Failed to change mind: %v", newMem.Mind) } }) } |