summaryrefslogtreecommitdiff
path: root/internal/database/sql/main.go
blob: 6ca887929104974b97c87193e709136cfc019d62 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
package database

import (
	"os"
	"path/filepath"
	"time"

	"log/slog"

	"github.com/jmoiron/sqlx"
	_ "github.com/mattn/go-sqlite3"
	"github.com/pkg/errors"
	"demoon/internal/models"
)

// log is the package-wide structured logger. It emits JSON records to
// standard output using the handler's default options.
var log = slog.New(slog.NewJSONHandler(os.Stdout, nil))

// dbDriver is the database/sql driver name handed to openDBConnection.
var dbDriver = "sqlite3"

// VerifyDB checks that the schema this package relies on is present.
//
// It currently looks for a table named "questions" in sqlite_master.
// A failing query is returned wrapped with "database verification failed".
// A missing table yields a "required tables not found in database" error.
func (d *DB) VerifyDB() error {
	var count int
	// COUNT(*) always produces exactly one row. A missing table therefore
	// shows up as count == 0 rather than as an sql.ErrNoRows error from Get,
	// which previously masked the "tables not found" case.
	err := d.Conn.Get(&count, `SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='questions'`)
	if err != nil {
		return errors.Wrap(err, "database verification failed")
	}
	if count == 0 {
		return errors.New("required tables not found in database")
	}
	return nil
}

// DB bundles the package's sqlx database handle with the location it was
// opened from.
type DB struct {
	// Conn is the handle used for queries. PingRoutine may replace it with a
	// freshly opened handle after a failed ping.
	Conn *sqlx.DB
	// URI is the path (or ":memory:") passed to the driver by Init; it is
	// reused when reconnecting.
	URI  string
}

// CloseAll closes every connection held by d and returns the first error
// encountered. DB currently owns exactly one connection, Conn, so this is
// the result of closing it.
func (d *DB) CloseAll() error {
	return closeConn(d.Conn)
}

// closeConn closes conn and returns any error from Close.
//
// A nil conn is treated as already closed and yields nil. Without this
// guard, tearing down a DB whose connection was never (re)established
// would panic on a nil pointer.
func closeConn(conn *sqlx.DB) error {
	if conn == nil {
		return nil
	}
	return conn.Close()
}

func Init(dbPath string) (*DB, error) {
	var result DB
	var err error
	
	// Default to in-memory DB if no path specified
	if dbPath == "" {
		dbPath = ":memory:"
	} else if !filepath.IsAbs(dbPath) {
		// Convert relative paths to absolute
		dbPath = filepath.Join(".", dbPath)
	}

	result.Conn, err = openDBConnection(dbPath, dbDriver)
	if err != nil {
		return nil, err
	}
	result.URI = dbPath
	if err := testConnection(result.Conn); err != nil {
		return nil, err
	}

	if err := result.VerifyDB(); err != nil {
		return nil, errors.Wrap(err, "database schema verification failed")
	}

	return &result, nil
}

// openDBConnection prepares a sqlx handle for dbPath using the named driver.
//
// It makes sure the parent directory of dbPath exists (created with mode
// 0755) before opening. The pool is then restricted to a single open
// connection with no maximum lifetime.
//
// Like sql.Open, sqlx.Open may only validate its arguments without dialing.
// Init follows up with a ping.
func openDBConnection(dbPath, driver string) (*sqlx.DB, error) {
	parent := filepath.Dir(dbPath)
	if mkErr := os.MkdirAll(parent, 0755); mkErr != nil {
		return nil, errors.Wrap(mkErr, "failed to create db directory")
	}

	handle, openErr := sqlx.Open(driver, dbPath)
	if openErr != nil {
		return nil, openErr
	}

	// SQLite allows only one writer at a time, so cap the pool at one.
	handle.SetMaxOpenConns(1)
	// Zero means connections are never retired for age.
	handle.SetConnMaxLifetime(0)

	return handle, nil
}

// testConnection pings conn to confirm the database is reachable.
//
// Ping failures are wrapped with "can't ping database". A nil conn is
// reported as an error instead of panicking on a nil pointer. Before this
// guard, the periodic ping crashed after a failed reconnect left the handle
// nil.
func testConnection(conn *sqlx.DB) error {
	if conn == nil {
		return errors.New("can't ping database: no connection")
	}
	// errors.Wrap returns nil when given a nil error, so success maps to nil.
	return errors.Wrap(conn.Ping(), "can't ping database")
}

// PingRoutine pings d.Conn every interval. When a ping fails, it tries to
// replace the connection with a freshly opened one for d.URI.
//
// It never returns, so run it in its own goroutine. Use PingRoutineUntil
// when the loop must be stoppable. interval must be positive, because
// time.NewTicker panics otherwise.
func (d *DB) PingRoutine(interval time.Duration) {
	// Receiving from a nil channel blocks forever, so this loops
	// indefinitely exactly as before.
	d.PingRoutineUntil(interval, nil)
}

// PingRoutineUntil behaves like PingRoutine but returns once done is closed
// (or receives a value). The ticker is stopped on return.
//
// NOTE(review): d.Conn is reassigned here without synchronization. If other
// goroutines read d.Conn concurrently, this is a data race; confirm usage.
func (d *DB) PingRoutineUntil(interval time.Duration, done <-chan struct{}) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-done:
			return
		case t := <-ticker.C:
			if err := testConnection(d.Conn); err != nil {
				log.Error("failed to ping db", "error", err, "ping_at", t)
				d.reconnect(t)
			}
		}
	}
}

// reconnect opens a new connection to d.URI and swaps it in only if that
// succeeds; the old connection is closed after the swap.
//
// On failure d.Conn is left untouched (never nil), so the next tick pings
// again and retries instead of dereferencing a nil handle. t is the tick
// time, used only for logging.
func (d *DB) reconnect(t time.Time) {
	conn, err := openDBConnection(d.URI, dbDriver)
	if err != nil {
		log.Error("failed to reconnect", "error", err, "ping_at", t)
		return
	}
	old := d.Conn
	d.Conn = conn
	if old != nil {
		if err := closeConn(old); err != nil {
			log.Error("failed to close db connection", "error", err, "ping_at", t)
		}
	}
}