From f32375488f5127c910021f627d83e017c5c7a10f Mon Sep 17 00:00:00 2001 From: Grail Finder Date: Tue, 19 Nov 2024 17:15:02 +0300 Subject: Feat: add storage interface; add sqlite impl --- .gitignore | 1 + README.md | 6 ++- bot.go | 16 ++++++- go.mod | 10 ++++ go.sum | 31 +++++++++++++ main.go | 18 +++++--- models/db.go | 11 +++++ storage/migrations/001_init.down.sql | 1 + storage/migrations/001_init.up.sql | 7 +++ storage/storage.go | 65 ++++++++++++++++++++++++++ storage/storage_test.go | 88 ++++++++++++++++++++++++++++++++++++ 11 files changed, 244 insertions(+), 10 deletions(-) create mode 100644 models/db.go create mode 100644 storage/migrations/001_init.down.sql create mode 100644 storage/migrations/001_init.up.sql create mode 100644 storage/storage.go create mode 100644 storage/storage_test.go diff --git a/.gitignore b/.gitignore index a149479..9c37cd9 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ testlog elefant history/ +*.db diff --git a/README.md b/README.md index db53835..036d9ab 100644 --- a/README.md +++ b/README.md @@ -12,8 +12,10 @@ - fullscreen textarea option (for long prompt); - tab to switch selection between textview and textarea (input and chat); + - basic tools: memorize and recall; -- stop stream from the bot; +- stop stream from the bot; + +- sqlitedb instead of chatfiles; +- sqlite for the bot memory; ### FIX: - bot responding (or haninging) blocks everything; + -- programm requires history folder, but it is .gitignore; +- programm requires history folder, but it is .gitignore; + diff --git a/bot.go b/bot.go index be35a09..3a4fb6d 100644 --- a/bot.go +++ b/bot.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "elefant/models" + "elefant/storage" "encoding/json" "fmt" "io" @@ -37,12 +38,14 @@ var ( chunkChan = make(chan string, 10) streamDone = make(chan bool, 1) chatBody *models.ChatBody + store storage.ChatHistory defaultFirstMsg = "Hello! What can I do for you?" defaultStarter = []models.MessagesStory{ {Role: "system", Content: systemMsg}, {Role: assistantRole, Content: defaultFirstMsg}, } - systemMsg = `You're a helpful assistant. + interruptResp = false + systemMsg = `You're a helpful assistant. # Tools You can do functions call if needed. Your current tools: @@ -116,11 +119,17 @@ func sendMsgToLLM(body io.Reader) (any, error) { logger.Error("llamacpp api", "error", err) return nil, err } + defer resp.Body.Close() llmResp := []models.LLMRespChunk{} // chunkChan <- assistantIcon reader := bufio.NewReader(resp.Body) counter := 0 for { + if interruptResp { + interruptResp = false + logger.Info("interrupted bot response") + break + } llmchunk := models.LLMRespChunk{} if counter > 2000 { streamDone <- true @@ -340,6 +349,10 @@ func init() { if err != nil { panic(err) } + // create dir if does not exist + if err := os.MkdirAll(historyDir, os.ModePerm); err != nil { + panic(err) + } // defer file.Close() logger = slog.New(slog.NewTextHandler(file, nil)) logger.Info("test msg") @@ -351,4 +364,5 @@ func init() { Stream: true, Messages: lastChat, } + store = storage.NewProviderSQL("test.db") } diff --git a/go.mod b/go.mod index 3d84b7f..687511d 100644 --- a/go.mod +++ b/go.mod @@ -4,15 +4,25 @@ go 1.23.2 require ( github.com/gdamore/tcell/v2 v2.7.4 + github.com/glebarez/go-sqlite v1.22.0 + github.com/jmoiron/sqlx v1.4.0 github.com/rivo/tview v0.0.0-20241103174730-c76f7879f592 ) require ( + github.com/dustin/go-humanize v1.0.1 // indirect github.com/gdamore/encoding v1.0.0 // indirect + github.com/google/uuid v1.5.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.15 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rivo/uniseg v0.4.7 // indirect golang.org/x/sys v0.17.0 // indirect golang.org/x/term v0.17.0 // indirect golang.org/x/text v0.14.0 // indirect + modernc.org/libc v1.37.6 // indirect + modernc.org/mathutil v1.6.0 // indirect + modernc.org/memory v1.7.2 // indirect + modernc.org/sqlite v1.28.0 // indirect ) diff --git a/go.sum b/go.sum index 74a393e..e4a23b5 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,33 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/gdamore/encoding v1.0.0 h1:+7OoQ1Bc6eTm5niUzBa0Ctsh6JbMW6Ra+YNuAtDBdko= github.com/gdamore/encoding v1.0.0/go.mod h1:alR0ol34c49FCSBLjhosxzcPHQbf2trDkoo5dl+VrEg= github.com/gdamore/tcell/v2 v2.7.4 h1:sg6/UnTM9jGpZU+oFYAsDahfchWAFW8Xx2yFinNSAYU= github.com/gdamore/tcell/v2 v2.7.4/go.mod h1:dSXtXTSK0VsW1biw65DZLZ2NKr7j0qP/0J7ONmsraWg= +github.com/glebarez/go-sqlite v1.22.0 h1:uAcMJhaA6r3LHMTFgP0SifzgXg46yJkgxqyuyec+ruQ= +github.com/glebarez/go-sqlite v1.22.0/go.mod h1:PlBIdHe0+aUEFn+r2/uthrWq4FxbzugL0L8Li6yQJbc= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= +github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= +github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU= +github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o= +github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rivo/tview v0.0.0-20241103174730-c76f7879f592 h1:YIJ+B1hePP6AgynC5TcqpO0H9k3SSoZa2BGyL6vDUzM= github.com/rivo/tview v0.0.0-20241103174730-c76f7879f592/go.mod h1:02iFIz7K/A9jGCvrizLPvoqr4cEIx7q54RH5Qudkrss= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= @@ -30,6 +52,7 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= @@ -48,3 +71,11 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +modernc.org/libc v1.37.6 h1:orZH3c5wmhIQFTXF+Nt+eeauyd+ZIt2BX6ARe+kD+aw= +modernc.org/libc v1.37.6/go.mod h1:YAXkAZ8ktnkCKaN9sw/UDeUVkGYJ/YquGO4FTi5nmHE= +modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4= +modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo= +modernc.org/memory v1.7.2 h1:Klh90S215mmH8c9gO98QxQFsY+W451E8AnzjoE2ee1E= +modernc.org/memory v1.7.2/go.mod h1:NO4NVCQy0N7ln+T9ngWqOQfi7ley4vpwvARR+Hjw95E= +modernc.org/sqlite v1.28.0 h1:Zx+LyDDmXczNnEQdvPuEfcFVA2ZPyaD7UCZDjef3BHQ= +modernc.org/sqlite v1.28.0/go.mod h1:Qxpazz0zH8Z1xCFyi5GSL3FzbtZ3fvbjmywNogldEW0= diff --git a/main.go b/main.go index f808902..7b90d28 100644 --- a/main.go +++ b/main.go @@ -12,12 +12,11 @@ import ( ) var ( - normalMode = false botRespMode = false editMode = false botMsg = "no" selectedIndex = int(-1) - indexLine = "Esc: send msg; Tab: switch focus; F1: manage chats; F2: regen last; F3:delete msg menu; F4: edit msg; F5: toggle system; Row: [yellow]%d[white], Column: [yellow]%d; normal mode: %v" + indexLine = "Esc: send msg; Tab: switch focus; F1: manage chats; F2: regen last; F3:delete last msg; F4: edit msg; F5: toggle system; F6: interrupt bot resp; Row: [yellow]%d[white], Column: [yellow]%d; bot resp mode: %v" focusSwitcher = map[tview.Primitive]tview.Primitive{} ) @@ -55,9 +54,9 @@ func main() { updateStatusLine := func() { fromRow, fromColumn, toRow, toColumn := textArea.GetCursor() if fromRow == toRow && fromColumn == toColumn { - position.SetText(fmt.Sprintf(indexLine, fromRow, fromColumn, normalMode)) + position.SetText(fmt.Sprintf(indexLine, fromRow, fromColumn, botRespMode)) } else { - position.SetText(fmt.Sprintf("Esc: send msg; Tab: switch focus; F1: manage chats; F2: regen last; F3:delete msg menu; F4: edit msg; F5: toggle system; Row: [yellow]%d[white], Column: [yellow]%d[white] - [red]To[white] Row: [yellow]%d[white], To Column: [yellow]%d; normal mode: %v", fromRow, fromColumn, toRow, toColumn, normalMode)) + position.SetText(fmt.Sprintf("Esc: send msg; Tab: switch focus; F1: manage chats; F2: regen last; F3:delete last msg; F4: edit msg; F5: toggle system; F6: interrupt bot resp; Row: [yellow]%d[white], Column: [yellow]%d[white] - [red]To[white] Row: [yellow]%d[white], To Column: [yellow]%d; bot resp mode: %v", fromRow, fromColumn, toRow, toColumn, botRespMode)) } } chatOpts := []string{"cancel", "new"} @@ -186,10 +185,15 @@ func main() { showSystemMsgs = !showSystemMsgs textView.SetText(chatToText(showSystemMsgs)) } + if event.Key() == tcell.KeyF6 { + interruptResp = true + botRespMode = false + return nil + } // cannot send msg in editMode or botRespMode if event.Key() == tcell.KeyEscape && !editMode && !botRespMode { fromRow, fromColumn, _, _ := textArea.GetCursor() - position.SetText(fmt.Sprintf(indexLine, fromRow, fromColumn, normalMode)) + position.SetText(fmt.Sprintf(indexLine, fromRow, fromColumn, botRespMode)) // read all text into buffer msgText := textArea.GetText() if msgText != "" { @@ -206,9 +210,9 @@ func main() { app.SetFocus(focusSwitcher[currentF]) } if isASCII(string(event.Rune())) && !botRespMode { - // normalMode = false + // botRespMode = false // fromRow, fromColumn, _, _ := textArea.GetCursor() - // position.SetText(fmt.Sprintf(indexLine, fromRow, fromColumn, normalMode)) + // position.SetText(fmt.Sprintf(indexLine, fromRow, fromColumn, botRespMode)) return event } return event diff --git a/models/db.go b/models/db.go new file mode 100644 index 0000000..24bef41 --- /dev/null +++ b/models/db.go @@ -0,0 +1,11 @@ +package models + +import "time" + +type Chat struct { + ID uint32 `db:"id" json:"id"` + Name string `db:"name" json:"name"` + Msgs string `db:"msgs" json:"msgs"` // []MessagesStory to string json + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` +} diff --git a/storage/migrations/001_init.down.sql b/storage/migrations/001_init.down.sql new file mode 100644 index 0000000..0ef183f --- /dev/null +++ b/storage/migrations/001_init.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS chat; diff --git a/storage/migrations/001_init.up.sql b/storage/migrations/001_init.up.sql new file mode 100644 index 0000000..287f3d1 --- /dev/null +++ b/storage/migrations/001_init.up.sql @@ -0,0 +1,7 @@ +CREATE TABLE chat ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + msgs TEXT NOT NULL, -- Store messages as a comma-separated string + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); diff --git a/storage/storage.go b/storage/storage.go new file mode 100644 index 0000000..11cbb4a --- /dev/null +++ b/storage/storage.go @@ -0,0 +1,65 @@ +package storage + +import ( + "elefant/models" + "fmt" + + _ "github.com/glebarez/go-sqlite" + "github.com/jmoiron/sqlx" +) + +type ChatHistory interface { + ListChats() ([]models.Chat, error) + GetChatByID(id uint32) (*models.Chat, error) + UpsertChat(chat *models.Chat) (*models.Chat, error) + RemoveChat(id uint32) error +} + +type ProviderSQL struct { + db *sqlx.DB +} + +func (p ProviderSQL) ListChats() ([]models.Chat, error) { + resp := []models.Chat{} + err := p.db.Select(&resp, "SELECT * FROM chat;") + return resp, err +} + +func (p ProviderSQL) GetChatByID(id uint32) (*models.Chat, error) { + resp := models.Chat{} + err := p.db.Get(&resp, "SELECT * FROM chat WHERE id=$1;", id) + return &resp, err +} + +func (p ProviderSQL) UpsertChat(chat *models.Chat) (*models.Chat, error) { + // Prepare the SQL statement + query := ` + INSERT OR REPLACE INTO chat (id, name, msgs, created_at, updated_at) + VALUES (:id, :name, :msgs, :created_at, :updated_at) + RETURNING *;` + stmt, err := p.db.PrepareNamed(query) + if err != nil { + return nil, err + } + // Execute the query and scan the result into a new chat object + var resp models.Chat + err = stmt.Get(&resp, chat) + return &resp, err +} + +func (p ProviderSQL) RemoveChat(id uint32) error { + query := "DELETE FROM chat WHERE ID = $1;" + _, err := p.db.Exec(query, id) + return err +} + +func NewProviderSQL(dbPath string) ChatHistory { + db, err := sqlx.Open("sqlite", dbPath) + if err != nil { + panic(err) + } + // get SQLite version + res := db.QueryRow("select sqlite_version()") + fmt.Println(res) + return ProviderSQL{db: db} +} diff --git a/storage/storage_test.go b/storage/storage_test.go new file mode 100644 index 0000000..0bf1fd6 --- /dev/null +++ b/storage/storage_test.go @@ -0,0 +1,88 @@ +package storage + +import ( + "elefant/models" + "testing" + "time" + + _ "github.com/glebarez/go-sqlite" + "github.com/jmoiron/sqlx" +) + +func TestChatHistory(t *testing.T) { + // Create an in-memory SQLite database + db, err := sqlx.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("Failed to open SQLite in-memory database: %v", err) + } + defer db.Close() + // Create the chat table + _, err = db.Exec(` + CREATE TABLE chat ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + msgs TEXT NOT NULL, -- Store messages as a comma-separated string + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + );`) + if err != nil { + t.Fatalf("Failed to create chat table: %v", err) + } + // Initialize the ProviderSQL struct + provider := ProviderSQL{db: db} + // List chats (should be empty) + chats, err := provider.ListChats() + if err != nil { + t.Fatalf("Failed to list chats: %v", err) + } + if len(chats) != 0 { + t.Errorf("Expected 0 chats, got %d", len(chats)) + } + // Upsert a chat + chat := &models.Chat{ + ID: 1, + Name: "Test Chat", + Msgs: "Hello World", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + updatedChat, err := provider.UpsertChat(chat) + if err != nil { + t.Fatalf("Failed to upsert chat: %v", err) + } + if updatedChat == nil { + t.Errorf("Expected non-nil chat after upsert") + } + // Get chat by ID + fetchedChat, err := provider.GetChatByID(chat.ID) + if err != nil { + t.Fatalf("Failed to get chat by ID: %v", err) + } + if fetchedChat == nil { + t.Errorf("Expected non-nil chat after get") + } + if fetchedChat.Name != chat.Name { + t.Errorf("Expected chat name %s, got %s", chat.Name, fetchedChat.Name) + } + // List chats (should contain the upserted chat) + chats, err = provider.ListChats() + if err != nil { + t.Fatalf("Failed to list chats: %v", err) + } + if len(chats) != 1 { + t.Errorf("Expected 1 chat, got %d", len(chats)) + } + // Remove chat + err = provider.RemoveChat(chat.ID) + if err != nil { + t.Fatalf("Failed to remove chat: %v", err) + } + // List chats (should be empty again) + chats, err = provider.ListChats() + if err != nil { + t.Fatalf("Failed to list chats: %v", err) + } + if len(chats) != 0 { + t.Errorf("Expected 0 chats, got %d", len(chats)) + } +} -- cgit v1.2.3