summaryrefslogtreecommitdiff
path: root/storage/storage.go
blob: 57631da8488a3b946d02747b0da707be3b83204f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
package storage

import (
	"gf-lt/models"
	"log/slog"

	_ "github.com/glebarez/go-sqlite"
	"github.com/jmoiron/sqlx"
)

type FullRepo interface {
	ChatHistory
	Memories
	VectorRepo
}

// ChatHistory is the persistence contract for chat sessions. In this file it
// is implemented by ProviderSQL on top of the chats table.
type ChatHistory interface {
	// ListChats returns every stored chat.
	ListChats() ([]models.Chat, error)
	// GetChatByID returns the chat with the given id.
	GetChatByID(id uint32) (*models.Chat, error)
	// GetChatByChar returns all chats belonging to the given agent.
	GetChatByChar(char string) ([]models.Chat, error)
	// GetLastChat returns the most recently updated chat.
	GetLastChat() (*models.Chat, error)
	// GetLastChatByAgent returns the most recently updated chat of one agent.
	GetLastChatByAgent(agent string) (*models.Chat, error)
	// ChatGetMaxID returns the largest chat id in storage.
	ChatGetMaxID() (uint32, error)

	// UpsertChat inserts or updates a chat and returns the stored row.
	UpsertChat(chat *models.Chat) (*models.Chat, error)
	// RemoveChat deletes the chat with the given id.
	RemoveChat(id uint32) error
}

// ProviderSQL is the SQLite-backed repository; its ChatHistory methods are
// defined in this file and it is constructed by NewProviderSQL.
// All methods use value receivers; the struct only holds shared handles.
type ProviderSQL struct {
	db     *sqlx.DB     // connection pool opened by NewProviderSQL
	logger *slog.Logger // logger passed to NewProviderSQL
}

// ListChats returns every row of the chats table. The returned slice is
// non-nil (empty rather than nil) when the table has no rows.
func (p ProviderSQL) ListChats() ([]models.Chat, error) {
	chats := make([]models.Chat, 0)
	if err := p.db.Select(&chats, "SELECT * FROM chats;"); err != nil {
		return chats, err
	}
	return chats, nil
}

// GetChatByChar returns all chats whose agent column equals char. The result
// is an empty (non-nil) slice when nothing matches.
func (p ProviderSQL) GetChatByChar(char string) ([]models.Chat, error) {
	const byAgent = "SELECT * FROM chats WHERE agent=$1;"
	chats := make([]models.Chat, 0)
	err := p.db.Select(&chats, byAgent, char)
	return chats, err
}

// GetChatByID fetches the chat with the given id.
// NOTE: a non-nil pointer is returned even when err != nil (for instance when
// no row matches); callers must check err before trusting the result.
func (p ProviderSQL) GetChatByID(id uint32) (*models.Chat, error) {
	chat := new(models.Chat)
	err := p.db.Get(chat, "SELECT * FROM chats WHERE id=$1;", id)
	return chat, err
}

// GetLastChat returns the chat with the greatest updated_at value.
// NOTE: a non-nil pointer is returned even when err != nil (e.g. on an empty
// table); check err first.
func (p ProviderSQL) GetLastChat() (*models.Chat, error) {
	latest := new(models.Chat)
	err := p.db.Get(latest, "SELECT * FROM chats ORDER BY updated_at DESC LIMIT 1")
	return latest, err
}

// GetLastChatByAgent returns the most recently updated chat (greatest
// updated_at) whose agent column equals agent.
// NOTE: a non-nil pointer is returned even when err != nil; check err first.
func (p ProviderSQL) GetLastChatByAgent(agent string) (*models.Chat, error) {
	latest := new(models.Chat)
	err := p.db.Get(
		latest,
		"SELECT * FROM chats WHERE agent=$1 ORDER BY updated_at DESC LIMIT 1",
		agent,
	)
	return latest, err
}

// UpsertChat inserts chat into the chats table. If a row with the same id
// already exists, only its msgs and updated_at columns are overwritten; the
// existing name, agent and created_at are kept. The stored row is read back
// with RETURNING (requires SQLite 3.35+) and returned.
// Upsert syntax: https://sqlite.org/lang_upsert.html
// NOTE: on a query error a non-nil (zero-valued) pointer is returned together
// with err; only a prepare failure yields a nil pointer.
func (p ProviderSQL) UpsertChat(chat *models.Chat) (*models.Chat, error) {
	const upsert = `
        INSERT INTO chats (id, name, msgs, agent, created_at, updated_at)
	VALUES (:id, :name, :msgs, :agent, :created_at, :updated_at)
	ON CONFLICT(id) DO UPDATE SET msgs=excluded.msgs,
	updated_at=excluded.updated_at
        RETURNING *;`
	named, err := p.db.PrepareNamed(upsert)
	if err != nil {
		return nil, err
	}
	defer named.Close()
	// Named parameters (:id, :name, ...) are bound from the fields of chat.
	saved := new(models.Chat)
	err = named.Get(saved, chat)
	return saved, err
}

// RemoveChat deletes the chat row with the given id. The statement result is
// not inspected, so deleting a non-existent id is not reported as an error.
func (p ProviderSQL) RemoveChat(id uint32) error {
	if _, err := p.db.Exec("DELETE FROM chats WHERE ID = $1;", id); err != nil {
		return err
	}
	return nil
}

// ChatGetMaxID returns the largest chat id currently stored, or 0 when the
// chats table is empty.
func (p ProviderSQL) ChatGetMaxID() (uint32, error) {
	// MAX over an empty table yields NULL. database/sql refuses to scan NULL
	// into a uint32 ("converting NULL to uint32 is unsupported"), so without
	// COALESCE a fresh database would report an error instead of 0.
	const query = "SELECT COALESCE(MAX(id), 0) FROM chats;"
	var id uint32
	if err := p.db.Get(&id, query); err != nil {
		return 0, err
	}
	return id, nil
}

// NewProviderSQL opens the SQLite database at dbPath, applies connection
// pragmas, runs schema migrations and returns the resulting repository.
//
// It returns nil (after logging the cause) if the database cannot be opened
// or if migration fails; in the latter case the connection pool is closed so
// it is not leaked. Pragma failures are logged as warnings and are not fatal.
func NewProviderSQL(dbPath string, logger *slog.Logger) FullRepo {
	db, err := sqlx.Open("sqlite", dbPath)
	if err != nil {
		logger.Error("failed to open db connection", "error", err)
		return nil
	}
	// Enable WAL mode for better concurrency and performance. journal_mode=WAL
	// is persisted in the database file, so it applies to every connection.
	if _, err := db.Exec("PRAGMA journal_mode = WAL;"); err != nil {
		logger.Warn("failed to enable WAL mode", "error", err)
	}
	// NOTE(review): synchronous and cache_size are per-connection settings in
	// SQLite; executed through the pool they only affect whichever connection
	// ran them. Consider setting them via the DSN if they must apply globally.
	if _, err := db.Exec("PRAGMA synchronous = NORMAL;"); err != nil {
		logger.Warn("failed to set synchronous mode", "error", err)
	}
	// A negative cache_size is in KiB: -2000 is ~2 MB, which equals SQLite's
	// built-in default, so this pins the value rather than enlarging it.
	if _, err := db.Exec("PRAGMA cache_size = -2000;"); err != nil {
		logger.Warn("failed to set cache size", "error", err)
	}
	// Log actual journal mode for debugging
	var journalMode string
	if err := db.QueryRow("PRAGMA journal_mode;").Scan(&journalMode); err == nil {
		logger.Debug("SQLite journal mode", "mode", journalMode)
	}
	p := ProviderSQL{db: db, logger: logger}
	if err := p.Migrate(); err != nil {
		logger.Error("migration failed, app cannot start", "error", err)
		// Nothing will reference db after we return nil, so release it here.
		if cerr := db.Close(); cerr != nil {
			logger.Warn("failed to close db after migration failure", "error", cerr)
		}
		return nil
	}
	return p
}

// DB returns the *sqlx.DB connection pool held by this provider, giving
// callers direct access to the database handle.
func (p ProviderSQL) DB() *sqlx.DB {
	handle := p.db
	return handle
}