package database

import (
	"log/slog"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/jmoiron/sqlx"
	_ "github.com/mattn/go-sqlite3" // registers the "sqlite3" database/sql driver
	"github.com/pkg/errors"
)

var (
	// log is the package logger; it writes JSON records to stdout.
	log = slog.New(slog.NewJSONHandler(os.Stdout, nil))
	// dbDriver is the database/sql driver name used for every pool opened here.
	dbDriver = "sqlite3"
)

// DB bundles the sqlx connection pool with the DSN it was opened from, so
// the pool can be re-created later by PingRoutine.
type DB struct {
	Conn *sqlx.DB
	URI  string
}

// VerifyDB checks that the required tables exist in the database. Currently
// only the "questions" table is checked.
//
// It returns an error if the pool is nil, if the query fails, or if the
// table is missing.
func (d *DB) VerifyDB() error {
	if d.Conn == nil {
		return errors.New("database verification failed: connection is nil")
	}
	// COUNT(*) always yields exactly one row, so a missing table shows up as
	// exists == 0 rather than as sql.ErrNoRows from Get.
	var exists int
	err := d.Conn.Get(&exists, `SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='questions'`)
	if err != nil {
		return errors.Wrap(err, "database verification failed")
	}
	if exists != 1 {
		return errors.New("required tables not found in database")
	}
	return nil
}

// CloseAll closes every connection pool held by d, stopping at the first
// error. A nil pool is skipped.
func (d *DB) CloseAll() error {
	for _, conn := range []*sqlx.DB{d.Conn} {
		if err := closeConn(conn); err != nil {
			return err
		}
	}
	return nil
}

// closeConn closes conn. Closing a nil pool is a no-op.
func closeConn(conn *sqlx.DB) error {
	if conn == nil {
		return nil
	}
	return conn.Close()
}

// isSpecialDSN reports whether dsn is an SQLite DSN that is not a plain
// filesystem path: the in-memory database or a "file:" URI.
func isSpecialDSN(dsn string) bool {
	return dsn == ":memory:" || strings.HasPrefix(dsn, "file:")
}

// Init opens the SQLite database at dbPath, checks that it answers a ping,
// and verifies the expected schema (see VerifyDB).
//
// An empty dbPath selects ":memory:". Relative file paths are resolved to
// absolute paths so that PingRoutine's reconnects reopen the same file even
// if the working directory changes; ":memory:" and "file:" URIs are used
// unchanged.
//
// NOTE(review): a freshly created in-memory database has no tables, so
// VerifyDB rejects it unless the schema is created elsewhere first —
// confirm the empty-path default is actually usable.
//
// On any failure the opened pool is closed and a nil *DB is returned.
func Init(dbPath string) (*DB, error) {
	switch {
	case dbPath == "":
		dbPath = ":memory:"
	case isSpecialDSN(dbPath) || filepath.IsAbs(dbPath):
		// Already a special DSN or an absolute path; use it as-is.
	default:
		abs, err := filepath.Abs(dbPath)
		if err != nil {
			return nil, errors.Wrap(err, "failed to resolve db path")
		}
		dbPath = abs
	}

	conn, err := openDBConnection(dbPath, dbDriver)
	if err != nil {
		return nil, err
	}
	result := &DB{Conn: conn, URI: dbPath}

	if err := testConnection(conn); err != nil {
		_ = conn.Close() // best effort; the ping error is the one worth reporting
		return nil, err
	}
	if err := result.VerifyDB(); err != nil {
		_ = conn.Close() // best effort; the verification error is the one worth reporting
		return nil, errors.Wrap(err, "database schema verification failed")
	}
	return result, nil
}

// openDBConnection opens a connection pool for dbPath with the given driver.
// For plain filesystem paths it first creates the parent directory; for
// ":memory:" and "file:" URIs it does not, since their filepath.Dir is not a
// real directory.
//
// database/sql's Open may only validate its arguments without dialing, so
// callers should Ping to confirm the database is reachable.
func openDBConnection(dbPath, driver string) (*sqlx.DB, error) {
	if !isSpecialDSN(dbPath) {
		if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
			return nil, errors.Wrap(err, "failed to create db directory")
		}
	}

	conn, err := sqlx.Open(driver, dbPath)
	if err != nil {
		return nil, errors.Wrap(err, "failed to open database")
	}

	// SQLite-specific settings.
	conn.SetMaxOpenConns(1)    // SQLite only supports one writer at a time
	conn.SetConnMaxLifetime(0) // connections are reused indefinitely
	return conn, nil
}

// testConnection pings conn and wraps any failure. A nil pool is reported
// as an error instead of panicking.
func testConnection(conn *sqlx.DB) error {
	if conn == nil {
		return errors.New("can't ping database: connection is nil")
	}
	if err := conn.Ping(); err != nil {
		return errors.Wrap(err, "can't ping database")
	}
	return nil
}

// PingRoutine pings the database every interval. When a ping fails it opens
// a fresh pool for d.URI and, only if that succeeds, swaps it into d.Conn and
// closes the old pool. A failed reopen therefore leaves the previous pool in
// place instead of setting d.Conn to nil.
//
// The loop has no exit condition, so the call blocks for the life of the
// process; start it in its own goroutine. A non-positive interval is logged
// and the routine returns immediately (time.NewTicker would panic).
//
// NOTE(review): d.Conn is reassigned without synchronization; if other
// goroutines read d.Conn concurrently this is a data race. Also, for a
// ":memory:" database a reconnect yields a new, empty database — confirm
// that is acceptable.
func (d *DB) PingRoutine(interval time.Duration) {
	if interval <= 0 {
		log.Error("invalid ping interval", "interval", interval)
		return
	}
	ticker := time.NewTicker(interval)

	for t := range ticker.C {
		if err := testConnection(d.Conn); err == nil {
			continue
		} else {
			log.Error("failed to ping db", "error", err, "ping_at", t)
		}

		// Open the replacement first so d.Conn is never left nil.
		newConn, err := openDBConnection(d.URI, dbDriver)
		if err != nil {
			log.Error("failed to reconnect", "error", err, "ping_at", t)
			continue
		}
		oldConn := d.Conn
		d.Conn = newConn
		if err := closeConn(oldConn); err != nil {
			log.Error("failed to close db connection", "error", err, "ping_at", t)
		}
	}
}