summaryrefslogtreecommitdiff
path: root/storage
diff options
context:
space:
mode:
Diffstat (limited to 'storage')
-rw-r--r--storage/migrations/001_init.down.sql1
-rw-r--r--storage/migrations/001_init.up.sql7
-rw-r--r--storage/storage.go65
-rw-r--r--storage/storage_test.go88
4 files changed, 161 insertions, 0 deletions
diff --git a/storage/migrations/001_init.down.sql b/storage/migrations/001_init.down.sql
new file mode 100644
index 0000000..0ef183f
--- /dev/null
+++ b/storage/migrations/001_init.down.sql
@@ -0,0 +1 @@
+DROP TABLE IF EXISTS chat;
diff --git a/storage/migrations/001_init.up.sql b/storage/migrations/001_init.up.sql
new file mode 100644
index 0000000..287f3d1
--- /dev/null
+++ b/storage/migrations/001_init.up.sql
@@ -0,0 +1,7 @@
+CREATE TABLE chat (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ name TEXT NOT NULL,
+ msgs TEXT NOT NULL, -- Store messages as a comma-separated string
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
+);
diff --git a/storage/storage.go b/storage/storage.go
new file mode 100644
index 0000000..11cbb4a
--- /dev/null
+++ b/storage/storage.go
@@ -0,0 +1,65 @@
+package storage
+
+import (
+ "elefant/models"
+ "fmt"
+
+ _ "github.com/glebarez/go-sqlite"
+ "github.com/jmoiron/sqlx"
+)
+
+type ChatHistory interface {
+ ListChats() ([]models.Chat, error)
+ GetChatByID(id uint32) (*models.Chat, error)
+ UpsertChat(chat *models.Chat) (*models.Chat, error)
+ RemoveChat(id uint32) error
+}
+
+type ProviderSQL struct {
+ db *sqlx.DB
+}
+
+func (p ProviderSQL) ListChats() ([]models.Chat, error) {
+ resp := []models.Chat{}
+ err := p.db.Select(&resp, "SELECT * FROM chat;")
+ return resp, err
+}
+
+func (p ProviderSQL) GetChatByID(id uint32) (*models.Chat, error) {
+ resp := models.Chat{}
+ err := p.db.Get(&resp, "SELECT * FROM chat WHERE id=$1;", id)
+ return &resp, err
+}
+
+func (p ProviderSQL) UpsertChat(chat *models.Chat) (*models.Chat, error) {
+ // Prepare the SQL statement
+ query := `
+ INSERT OR REPLACE INTO chat (id, name, msgs, created_at, updated_at)
+ VALUES (:id, :name, :msgs, :created_at, :updated_at)
+ RETURNING *;`
+ stmt, err := p.db.PrepareNamed(query)
+ if err != nil {
+ return nil, err
+ }
+ // Execute the query and scan the result into a new chat object
+ var resp models.Chat
+ err = stmt.Get(&resp, chat)
+ return &resp, err
+}
+
+func (p ProviderSQL) RemoveChat(id uint32) error {
+ query := "DELETE FROM chat WHERE ID = $1;"
+ _, err := p.db.Exec(query, id)
+ return err
+}
+
+func NewProviderSQL(dbPath string) ChatHistory {
+ db, err := sqlx.Open("sqlite", dbPath)
+ if err != nil {
+ panic(err)
+ }
+ // get SQLite version
+ res := db.QueryRow("select sqlite_version()")
+ fmt.Println(res)
+ return ProviderSQL{db: db}
+}
diff --git a/storage/storage_test.go b/storage/storage_test.go
new file mode 100644
index 0000000..0bf1fd6
--- /dev/null
+++ b/storage/storage_test.go
@@ -0,0 +1,88 @@
+package storage
+
+import (
+ "elefant/models"
+ "testing"
+ "time"
+
+ _ "github.com/glebarez/go-sqlite"
+ "github.com/jmoiron/sqlx"
+)
+
+func TestChatHistory(t *testing.T) {
+ // Create an in-memory SQLite database
+ db, err := sqlx.Open("sqlite", ":memory:")
+ if err != nil {
+ t.Fatalf("Failed to open SQLite in-memory database: %v", err)
+ }
+ defer db.Close()
+ // Create the chat table
+ _, err = db.Exec(`
+ CREATE TABLE chat (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ name TEXT NOT NULL,
+ msgs TEXT NOT NULL, -- Store messages as a comma-separated string
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
+ );`)
+ if err != nil {
+ t.Fatalf("Failed to create chat table: %v", err)
+ }
+ // Initialize the ProviderSQL struct
+ provider := ProviderSQL{db: db}
+ // List chats (should be empty)
+ chats, err := provider.ListChats()
+ if err != nil {
+ t.Fatalf("Failed to list chats: %v", err)
+ }
+ if len(chats) != 0 {
+ t.Errorf("Expected 0 chats, got %d", len(chats))
+ }
+ // Upsert a chat
+ chat := &models.Chat{
+ ID: 1,
+ Name: "Test Chat",
+ Msgs: "Hello World",
+ CreatedAt: time.Now(),
+ UpdatedAt: time.Now(),
+ }
+ updatedChat, err := provider.UpsertChat(chat)
+ if err != nil {
+ t.Fatalf("Failed to upsert chat: %v", err)
+ }
+ if updatedChat == nil {
+ t.Errorf("Expected non-nil chat after upsert")
+ }
+ // Get chat by ID
+ fetchedChat, err := provider.GetChatByID(chat.ID)
+ if err != nil {
+ t.Fatalf("Failed to get chat by ID: %v", err)
+ }
+ if fetchedChat == nil {
+ t.Errorf("Expected non-nil chat after get")
+ }
+ if fetchedChat.Name != chat.Name {
+ t.Errorf("Expected chat name %s, got %s", chat.Name, fetchedChat.Name)
+ }
+ // List chats (should contain the upserted chat)
+ chats, err = provider.ListChats()
+ if err != nil {
+ t.Fatalf("Failed to list chats: %v", err)
+ }
+ if len(chats) != 1 {
+ t.Errorf("Expected 1 chat, got %d", len(chats))
+ }
+ // Remove chat
+ err = provider.RemoveChat(chat.ID)
+ if err != nil {
+ t.Fatalf("Failed to remove chat: %v", err)
+ }
+ // List chats (should be empty again)
+ chats, err = provider.ListChats()
+ if err != nil {
+ t.Fatalf("Failed to list chats: %v", err)
+ }
+ if len(chats) != 0 {
+ t.Errorf("Expected 0 chats, got %d", len(chats))
+ }
+}