diff options
Diffstat (limited to 'storage/storage.go')
-rw-r--r-- | storage/storage.go | 54 |
1 files changed, 48 insertions, 6 deletions
diff --git a/storage/storage.go b/storage/storage.go index 67b8dd8..7911e13 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -1,23 +1,34 @@ package storage import ( - "elefant/models" + "gf-lt/models" "log/slog" _ "github.com/glebarez/go-sqlite" "github.com/jmoiron/sqlx" + "github.com/ncruces/go-sqlite3" ) +type FullRepo interface { + ChatHistory + Memories + VectorRepo +} + type ChatHistory interface { ListChats() ([]models.Chat, error) GetChatByID(id uint32) (*models.Chat, error) + GetChatByChar(char string) ([]models.Chat, error) GetLastChat() (*models.Chat, error) + GetLastChatByAgent(agent string) (*models.Chat, error) UpsertChat(chat *models.Chat) (*models.Chat, error) RemoveChat(id uint32) error + ChatGetMaxID() (uint32, error) } type ProviderSQL struct { db *sqlx.DB + s3Conn *sqlite3.Conn logger *slog.Logger } @@ -27,6 +38,12 @@ func (p ProviderSQL) ListChats() ([]models.Chat, error) { return resp, err } +func (p ProviderSQL) GetChatByChar(char string) ([]models.Chat, error) { + resp := []models.Chat{} + err := p.db.Select(&resp, "SELECT * FROM chats WHERE agent=$1;", char) + return resp, err +} + func (p ProviderSQL) GetChatByID(id uint32) (*models.Chat, error) { resp := models.Chat{} err := p.db.Get(&resp, "SELECT * FROM chats WHERE id=$1;", id) @@ -39,16 +56,28 @@ func (p ProviderSQL) GetLastChat() (*models.Chat, error) { return &resp, err } +func (p ProviderSQL) GetLastChatByAgent(agent string) (*models.Chat, error) { + resp := models.Chat{} + query := "SELECT * FROM chats WHERE agent=$1 ORDER BY updated_at DESC LIMIT 1" + err := p.db.Get(&resp, query, agent) + return &resp, err +} + +// https://sqlite.org/lang_upsert.html +// on conflict was added func (p ProviderSQL) UpsertChat(chat *models.Chat) (*models.Chat, error) { // Prepare the SQL statement query := ` - INSERT OR REPLACE INTO chats (id, name, msgs, created_at, updated_at) - VALUES (:id, :name, :msgs, :created_at, :updated_at) + INSERT INTO chats (id, name, msgs, agent, created_at, updated_at) + VALUES (:id, :name, :msgs, :agent, :created_at, :updated_at) + ON CONFLICT(id) DO UPDATE SET msgs=excluded.msgs, + updated_at=excluded.updated_at RETURNING *;` stmt, err := p.db.PrepareNamed(query) if err != nil { return nil, err } + defer stmt.Close() // Execute the query and scan the result into a new chat object var resp models.Chat err = stmt.Get(&resp, chat) @@ -61,13 +90,26 @@ func (p ProviderSQL) RemoveChat(id uint32) error { return err } -func NewProviderSQL(dbPath string, logger *slog.Logger) ChatHistory { +func (p ProviderSQL) ChatGetMaxID() (uint32, error) { + query := "SELECT MAX(id) FROM chats;" + var id uint32 + err := p.db.Get(&id, query) + return id, err +} + +// opens two connections +func NewProviderSQL(dbPath string, logger *slog.Logger) FullRepo { db, err := sqlx.Open("sqlite", dbPath) if err != nil { - panic(err) + logger.Error("failed to open db connection", "error", err) + return nil } - // get SQLite version p := ProviderSQL{db: db, logger: logger} + p.s3Conn, err = sqlite3.Open(dbPath) + if err != nil { + logger.Error("failed to open vecdb connection", "error", err) + return nil + } p.Migrate() return p } |