diff options
Diffstat (limited to 'storage')
-rw-r--r-- | storage/memory.go | 4 | ||||
-rw-r--r-- | storage/storage.go | 19 |
2 files changed, 20 insertions, 3 deletions
diff --git a/storage/memory.go b/storage/memory.go index a7bf8cc..088ce1c 100644 --- a/storage/memory.go +++ b/storage/memory.go @@ -12,12 +12,14 @@ func (p ProviderSQL) Memorise(m *models.Memory) (*models.Memory, error) { query := "INSERT INTO memories (agent, topic, mind) VALUES (:agent, :topic, :mind) RETURNING *;" stmt, err := p.db.PrepareNamed(query) if err != nil { + p.logger.Error("failed to prepare stmt", "query", query, "error", err) return nil, err } defer stmt.Close() var memory models.Memory err = stmt.Get(&memory, m) if err != nil { + p.logger.Error("failed to insert memory", "query", query, "error", err) return nil, err } return &memory, nil @@ -28,6 +30,7 @@ func (p ProviderSQL) Recall(agent, topic string) (string, error) { var mind string err := p.db.Get(&mind, query, agent, topic) if err != nil { + p.logger.Error("failed to get memory", "query", query, "error", err) return "", err } return mind, nil @@ -38,6 +41,7 @@ func (p ProviderSQL) RecallTopics(agent string) ([]string, error) { var topics []string err := p.db.Select(&topics, query, agent) if err != nil { + p.logger.Error("failed to get topics", "query", query, "error", err) return nil, err } return topics, nil diff --git a/storage/storage.go b/storage/storage.go index 67b8dd8..c863799 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -8,12 +8,18 @@ import ( "github.com/jmoiron/sqlx" ) +type FullRepo interface { + ChatHistory + Memories +} + type ChatHistory interface { ListChats() ([]models.Chat, error) GetChatByID(id uint32) (*models.Chat, error) GetLastChat() (*models.Chat, error) UpsertChat(chat *models.Chat) (*models.Chat, error) RemoveChat(id uint32) error + ChatGetMaxID() (uint32, error) } type ProviderSQL struct { @@ -61,12 +67,19 @@ func (p ProviderSQL) RemoveChat(id uint32) error { return err } -func NewProviderSQL(dbPath string, logger *slog.Logger) ChatHistory { +func (p ProviderSQL) ChatGetMaxID() (uint32, error) { + query := "SELECT MAX(id) FROM chats;" + var id uint32 + err := p.db.Get(&id, query) + return id, err +} + +func NewProviderSQL(dbPath string, logger *slog.Logger) FullRepo { db, err := sqlx.Open("sqlite", dbPath) if err != nil { - panic(err) + logger.Error("failed to open db connection", "error", err) + return nil } - // get SQLite version p := ProviderSQL{db: db, logger: logger} p.Migrate() return p |