summaryrefslogtreecommitdiff
path: root/storage/storage.go
diff options
context:
space:
mode:
authorGrail Finder <wohilas@gmail.com>2024-11-19 17:15:02 +0300
committerGrail Finder <wohilas@gmail.com>2024-11-19 17:15:02 +0300
commitf32375488f5127c910021f627d83e017c5c7a10f (patch)
tree5a697ca8a42e3651fa57ea346af223b67f3e3a3a /storage/storage.go
parentd3cc8774b13a0c8e9fbf11947b9caca216595a8d (diff)
Feat: add storage interface; add sqlite impl
Diffstat (limited to 'storage/storage.go')
-rw-r--r--storage/storage.go65
1 files changed, 65 insertions, 0 deletions
diff --git a/storage/storage.go b/storage/storage.go
new file mode 100644
index 0000000..11cbb4a
--- /dev/null
+++ b/storage/storage.go
@@ -0,0 +1,65 @@
+package storage
+
+import (
+ "elefant/models"
+ "fmt"
+
+ _ "github.com/glebarez/go-sqlite"
+ "github.com/jmoiron/sqlx"
+)
+
+type ChatHistory interface {
+ ListChats() ([]models.Chat, error)
+ GetChatByID(id uint32) (*models.Chat, error)
+ UpsertChat(chat *models.Chat) (*models.Chat, error)
+ RemoveChat(id uint32) error
+}
+
+type ProviderSQL struct {
+ db *sqlx.DB
+}
+
+func (p ProviderSQL) ListChats() ([]models.Chat, error) {
+ resp := []models.Chat{}
+ err := p.db.Select(&resp, "SELECT * FROM chat;")
+ return resp, err
+}
+
+func (p ProviderSQL) GetChatByID(id uint32) (*models.Chat, error) {
+ resp := models.Chat{}
+ err := p.db.Get(&resp, "SELECT * FROM chat WHERE id=$1;", id)
+ return &resp, err
+}
+
+func (p ProviderSQL) UpsertChat(chat *models.Chat) (*models.Chat, error) {
+ // Prepare the SQL statement
+ query := `
+ INSERT OR REPLACE INTO chat (id, name, msgs, created_at, updated_at)
+ VALUES (:id, :name, :msgs, :created_at, :updated_at)
+ RETURNING *;`
+ stmt, err := p.db.PrepareNamed(query)
+ if err != nil {
+ return nil, err
+ }
+ // Execute the query and scan the result into a new chat object
+ var resp models.Chat
+ err = stmt.Get(&resp, chat)
+ return &resp, err
+}
+
+func (p ProviderSQL) RemoveChat(id uint32) error {
+ query := "DELETE FROM chat WHERE ID = $1;"
+ _, err := p.db.Exec(query, id)
+ return err
+}
+
+func NewProviderSQL(dbPath string) ChatHistory {
+ db, err := sqlx.Open("sqlite", dbPath)
+ if err != nil {
+ panic(err)
+ }
+ // get SQLite version
+ res := db.QueryRow("select sqlite_version()")
+ fmt.Println(res)
+ return ProviderSQL{db: db}
+}