summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGrail Finder <wohilas@gmail.com>2024-11-20 13:21:51 +0300
committerGrail Finder <wohilas@gmail.com>2024-11-20 13:21:51 +0300
commit8ae4d075c4d043eae604af9cad0cf5e571420a61 (patch)
tree3b0eb7689a598e25dac67d889e037f3bc08c19a1
parent74669b58fe7b58b3d2fd4ad88c03890bc53a7a1a (diff)
Feat: migration on startup
-rw-r--r--README.md1
-rw-r--r--bot.go3
-rw-r--r--storage/migrate.go59
-rw-r--r--storage/migrations/001_init.up.sql2
-rw-r--r--storage/storage.go13
5 files changed, 69 insertions, 9 deletions
diff --git a/README.md b/README.md
index 7c7292f..5d071f5 100644
--- a/README.md
+++ b/README.md
@@ -19,3 +19,4 @@
### FIX:
- bot responding (or haninging) blocks everything; +
- programm requires history folder, but it is .gitignore; +
+- at first run chat table does not exist; run migrations sql on startup; +
diff --git a/bot.go b/bot.go
index 31354af..4587c08 100644
--- a/bot.go
+++ b/bot.go
@@ -276,10 +276,9 @@ func init() {
if err := os.MkdirAll(historyDir, os.ModePerm); err != nil {
panic(err)
}
- store = storage.NewProviderSQL("test.db")
- // defer file.Close()
logger = slog.New(slog.NewTextHandler(file, nil))
logger.Info("test msg")
+ store = storage.NewProviderSQL("test.db", logger)
// https://github.com/coreydaley/ggerganov-llama.cpp/blob/master/examples/server/README.md
// load all chats in memory
loadHistoryChats()
diff --git a/storage/migrate.go b/storage/migrate.go
new file mode 100644
index 0000000..d97b99d
--- /dev/null
+++ b/storage/migrate.go
@@ -0,0 +1,59 @@
+package storage
+
+import (
+ "embed"
+ "fmt"
+ "io/fs"
+ "strings"
+)
+
+//go:embed migrations/*
+var migrationsFS embed.FS
+
+func (p *ProviderSQL) Migrate() {
+ // Get the embedded filesystem
+ migrationsDir, err := fs.Sub(migrationsFS, "migrations")
+ if err != nil {
+ p.logger.Error("Failed to get embedded migrations directory;", "error", err)
+ }
+ // List all .up.sql files
+ files, err := migrationsFS.ReadDir("migrations")
+ if err != nil {
+ p.logger.Error("Failed to read migrations directory;", "error", err)
+ }
+ // Execute each .up.sql file
+ for _, file := range files {
+ if strings.HasSuffix(file.Name(), ".up.sql") {
+ err := p.executeMigration(migrationsDir, file.Name())
+ if err != nil {
+ p.logger.Error("Failed to execute migration %s: %v", file.Name(), err)
+ }
+ }
+ }
+ p.logger.Info("All migrations executed successfully!")
+}
+
+func (p *ProviderSQL) executeMigration(migrationsDir fs.FS, fileName string) error {
+ // Open the migration file
+ migrationFile, err := migrationsDir.Open(fileName)
+ if err != nil {
+ return fmt.Errorf("failed to open migration file %s: %w", fileName, err)
+ }
+ defer migrationFile.Close()
+ // Read the migration file content
+ migrationContent, err := fs.ReadFile(migrationsDir, fileName)
+ if err != nil {
+ return fmt.Errorf("failed to read migration file %s: %w", fileName, err)
+ }
+ // Execute the migration content
+ return p.executeSQL(migrationContent)
+}
+
+func (p *ProviderSQL) executeSQL(sqlContent []byte) error {
+ // Connect to the database (example using a simple connection)
+ _, err := p.db.Exec(string(sqlContent))
+ if err != nil {
+ return fmt.Errorf("failed to execute SQL: %w", err)
+ }
+ return nil
+}
diff --git a/storage/migrations/001_init.up.sql b/storage/migrations/001_init.up.sql
index 287f3d1..1b3e63d 100644
--- a/storage/migrations/001_init.up.sql
+++ b/storage/migrations/001_init.up.sql
@@ -1,4 +1,4 @@
-CREATE TABLE chat (
+CREATE TABLE IF NOT EXISTS chat (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
msgs TEXT NOT NULL, -- Store messages as a comma-separated string
diff --git a/storage/storage.go b/storage/storage.go
index 43162c8..edbd393 100644
--- a/storage/storage.go
+++ b/storage/storage.go
@@ -2,7 +2,7 @@ package storage
import (
"elefant/models"
- "fmt"
+ "log/slog"
_ "github.com/glebarez/go-sqlite"
"github.com/jmoiron/sqlx"
@@ -17,7 +17,8 @@ type ChatHistory interface {
}
type ProviderSQL struct {
- db *sqlx.DB
+ db *sqlx.DB
+ logger *slog.Logger
}
func (p ProviderSQL) ListChats() ([]models.Chat, error) {
@@ -60,13 +61,13 @@ func (p ProviderSQL) RemoveChat(id uint32) error {
return err
}
-func NewProviderSQL(dbPath string) ChatHistory {
+func NewProviderSQL(dbPath string, logger *slog.Logger) ChatHistory {
db, err := sqlx.Open("sqlite", dbPath)
if err != nil {
panic(err)
}
// get SQLite version
- res := db.QueryRow("select sqlite_version()")
- fmt.Println(res)
- return ProviderSQL{db: db}
+ p := ProviderSQL{db: db, logger: logger}
+ p.Migrate()
+ return p
}