summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGrail Finder <wohilas@gmail.com>2024-11-20 06:51:40 +0300
committerGrail Finder <wohilas@gmail.com>2024-11-20 06:51:40 +0300
commitaaf056663628f15bb6e4f23c899b6fd31bac5bf7 (patch)
treea0129fbb531e31d1a9d6465549a44b823906aaae
parentf32375488f5127c910021f627d83e017c5c7a10f (diff)
Enha: db chat management
-rw-r--r--README.md2
-rw-r--r--bot.go93
-rw-r--r--main.go16
-rw-r--r--models/db.go13
-rw-r--r--session.go92
-rw-r--r--storage/storage.go7
6 files changed, 127 insertions, 96 deletions
diff --git a/README.md b/README.md
index 036d9ab..7c7292f 100644
--- a/README.md
+++ b/README.md
@@ -13,7 +13,7 @@
- tab to switch selection between textview and textarea (input and chat); +
- basic tools: memorize and recall;
- stop stream from the bot; +
-- sqlitedb instead of chatfiles;
+- sqlitedb instead of chatfiles; +
- sqlite for the bot memory;
### FIX:
diff --git a/bot.go b/bot.go
index 3a4fb6d..31354af 100644
--- a/bot.go
+++ b/bot.go
@@ -11,7 +11,6 @@ import (
"log/slog"
"net/http"
"os"
- "path"
"strings"
"time"
@@ -34,7 +33,7 @@ var (
historyDir = "./history/"
// TODO: pass as an cli arg
showSystemMsgs bool
- chatFileLoaded string
+ activeChatName string
chunkChan = make(chan string, 10)
streamDone = make(chan bool, 1)
chatBody *models.ChatBody
@@ -89,14 +88,12 @@ var fnMap = map[string]fnSig{
// ====
func getUserInput(userPrompt string) string {
- // fmt.Printf("<🤖>: %s\n<user>:", botMsg)
fmt.Printf(userPrompt)
reader := bufio.NewReader(os.Stdin)
line, err := reader.ReadString('\n')
if err != nil {
panic(err) // think about it
}
- // fmt.Printf("read line: %s-\n", line)
return line
}
@@ -152,7 +149,7 @@ func sendMsgToLLM(body io.Reader) (any, error) {
return nil, err
}
llmResp = append(llmResp, llmchunk)
- logger.Info("streamview", "chunk", llmchunk)
+ // logger.Info("streamview", "chunk", llmchunk)
// if llmchunk.Choices[len(llmchunk.Choices)-1].FinishReason != "chat.completion.chunk" {
if llmchunk.Choices[len(llmchunk.Choices)-1].FinishReason == "stop" {
streamDone <- true
@@ -192,18 +189,12 @@ out:
})
// bot msg is done;
// now check it for func call
- logChat(chatFileLoaded, chatBody.Messages)
- findCall(respText.String(), tv)
-}
-
-func logChat(fname string, msgs []models.MessagesStory) {
- data, err := json.MarshalIndent(msgs, "", " ")
+ // logChat(activeChatName, chatBody.Messages)
+ err := updateStorageChat(activeChatName, chatBody.Messages)
if err != nil {
- logger.Error("failed to marshal", "error", err)
- }
- if err := os.WriteFile(fname, data, 0666); err != nil {
- logger.Error("failed to write log", "error", err)
+ logger.Warn("failed to update storage", "error", err, "name", activeChatName)
}
+ findCall(respText.String(), tv)
}
func findCall(msg string, tv *tview.TextView) {
@@ -235,74 +226,6 @@ func findCall(msg string, tv *tview.TextView) {
// return func result to the llm
}
-func listHistoryFiles(dir string) ([]string, error) {
- files, err := os.ReadDir(dir)
- if err != nil {
- logger.Error("failed to readdir", "error", err)
- return nil, err
- }
- resp := make([]string, len(files))
- for i, f := range files {
- resp[i] = path.Join(dir, f.Name())
- }
- return resp, nil
-}
-
-func findLatestChat(dir string) string {
- files, err := listHistoryFiles(dir)
- if err != nil {
- panic(err)
- }
- var (
- latestF string
- newestTime int64
- )
- logger.Info("filelist", "list", files)
- for _, fn := range files {
- fi, err := os.Stat(fn)
- if err != nil {
- logger.Error("failed to get stat", "error", err, "name", fn)
- panic(err)
- }
- currTime := fi.ModTime().Unix()
- if currTime > newestTime {
- newestTime = currTime
- latestF = fn
- }
- }
- return latestF
-}
-
-func readHistoryChat(fn string) ([]models.MessagesStory, error) {
- content, err := os.ReadFile(fn)
- if err != nil {
- logger.Error("failed to read file", "error", err, "name", fn)
- return nil, err
- }
- resp := []models.MessagesStory{}
- if err := json.Unmarshal(content, &resp); err != nil {
- logger.Error("failed to unmarshal", "error", err, "name", fn)
- return nil, err
- }
- chatFileLoaded = fn
- return resp, nil
-}
-
-func loadOldChatOrGetNew(fns ...string) []models.MessagesStory {
- // find last chat
- fn := findLatestChat(historyDir)
- if len(fns) > 0 {
- fn = fns[0]
- }
- logger.Info("reading history from file", "filename", fn)
- history, err := readHistoryChat(fn)
- if err != nil {
- logger.Warn("faield to load history chat", "error", err)
- return defaultStarter
- }
- return history
-}
-
func chatToTextSlice(showSys bool) []string {
resp := make([]string, len(chatBody.Messages))
for i, msg := range chatBody.Messages {
@@ -353,10 +276,13 @@ func init() {
if err := os.MkdirAll(historyDir, os.ModePerm); err != nil {
panic(err)
}
+ store = storage.NewProviderSQL("test.db")
// defer file.Close()
logger = slog.New(slog.NewTextHandler(file, nil))
logger.Info("test msg")
// https://github.com/coreydaley/ggerganov-llama.cpp/blob/master/examples/server/README.md
+ // load all chats in memory
+ loadHistoryChats()
lastChat := loadOldChatOrGetNew()
logger.Info("loaded history", "chat", lastChat)
chatBody = &models.ChatBody{
@@ -364,5 +290,4 @@ func init() {
Stream: true,
Messages: lastChat,
}
- store = storage.NewProviderSQL("test.db")
}
diff --git a/main.go b/main.go
index 7b90d28..9415f1a 100644
--- a/main.go
+++ b/main.go
@@ -60,7 +60,7 @@ func main() {
}
}
chatOpts := []string{"cancel", "new"}
- fList, err := listHistoryFiles(historyDir)
+ fList, err := loadHistoryChats()
if err != nil {
panic(err)
}
@@ -74,18 +74,16 @@ func main() {
// set chat body
chatBody.Messages = defaultStarter
textView.SetText(chatToText(showSystemMsgs))
- chatFileLoaded = path.Join(historyDir, fmt.Sprintf("%d_chat.json", time.Now().Unix()))
+ activeChatName = path.Join(historyDir, fmt.Sprintf("%d_chat.json", time.Now().Unix()))
pages.RemovePage("history")
return
// set text
case "cancel":
pages.RemovePage("history")
- // pages.ShowPage("main")
return
default:
- // fn := path.Join(historyDir, buttonLabel)
fn := buttonLabel
- history, err := readHistoryChat(fn)
+ history, err := loadHistoryChat(fn)
if err != nil {
logger.Error("failed to read history file", "filename", fn)
pages.RemovePage("history")
@@ -93,7 +91,7 @@ func main() {
}
chatBody.Messages = history
textView.SetText(chatToText(showSystemMsgs))
- chatFileLoaded = fn
+ activeChatName = fn
pages.RemovePage("history")
return
}
@@ -104,14 +102,11 @@ func main() {
editArea.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
if event.Key() == tcell.KeyEscape && editMode {
editedMsg := editArea.GetText()
- // TODO: trim msg number and icon
chatBody.Messages[selectedIndex].Content = editedMsg
// change textarea
textView.SetText(chatToText(showSystemMsgs))
pages.RemovePage("editArea")
editMode = false
- // panic("do we get here?")
- // pages.ShowPage("main")
return nil
}
return event
@@ -151,7 +146,8 @@ func main() {
textView.ScrollToEnd()
app.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
if event.Key() == tcell.KeyF1 {
- fList, err := listHistoryFiles(historyDir)
+ // fList, err := listHistoryFiles(historyDir)
+ fList, err := loadHistoryChats()
if err != nil {
panic(err)
}
diff --git a/models/db.go b/models/db.go
index 24bef41..afd4b46 100644
--- a/models/db.go
+++ b/models/db.go
@@ -1,6 +1,9 @@
package models
-import "time"
+import (
+ "encoding/json"
+ "time"
+)
type Chat struct {
ID uint32 `db:"id" json:"id"`
@@ -9,3 +12,11 @@ type Chat struct {
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
+
+func (c Chat) ToHistory() ([]MessagesStory, error) {
+ resp := []MessagesStory{}
+ if err := json.Unmarshal([]byte(c.Msgs), &resp); err != nil {
+ return nil, err
+ }
+ return resp, nil
+}
diff --git a/session.go b/session.go
new file mode 100644
index 0000000..769fe90
--- /dev/null
+++ b/session.go
@@ -0,0 +1,92 @@
+package main
+
+import (
+ "elefant/models"
+ "encoding/json"
+ "fmt"
+ "time"
+)
+
+var (
+ chatMap = make(map[string]*models.Chat)
+)
+
+func historyToSJSON(msgs []models.MessagesStory) (string, error) {
+ data, err := json.Marshal(msgs)
+ if err != nil {
+ return "", err
+ }
+ if data == nil {
+ return "", fmt.Errorf("nil data")
+ }
+ return string(data), nil
+}
+
+func updateStorageChat(name string, msgs []models.MessagesStory) error {
+ var err error
+ chat, ok := chatMap[name]
+ if !ok {
+ err = fmt.Errorf("failed to find active chat; map:%v; key:%s", chatMap, name)
+ logger.Error("failed to find active chat", "map", chatMap, "key", name)
+ return err
+ }
+ chat.Msgs, err = historyToSJSON(msgs)
+ if err != nil {
+ return err
+ }
+ chat.UpdatedAt = time.Now()
+ _, err = store.UpsertChat(chat)
+ return err
+}
+
+func loadHistoryChats() ([]string, error) {
+ chats, err := store.ListChats()
+ if err != nil {
+ return nil, err
+ }
+ resp := []string{}
+ for _, chat := range chats {
+ if chat.Name == "" {
+ chat.Name = fmt.Sprintf("%d_%v", chat.ID, chat.CreatedAt.Unix())
+ }
+ resp = append(resp, chat.Name)
+ chatMap[chat.Name] = &chat
+ }
+ return resp, nil
+}
+
+func loadHistoryChat(chatName string) ([]models.MessagesStory, error) {
+ chat, ok := chatMap[chatName]
+ if !ok {
+ err := fmt.Errorf("failed to read chat")
+ logger.Error("failed to read chat", "name", chatName)
+ return nil, err
+ }
+ activeChatName = chatName
+ return chat.ToHistory()
+}
+
+func loadOldChatOrGetNew() []models.MessagesStory {
+ // find last chat
+ chat, err := store.GetLastChat()
+ newChat := &models.Chat{
+ ID: 0,
+ CreatedAt: time.Now(),
+ UpdatedAt: time.Now(),
+ }
+ newChat.Name = fmt.Sprintf("%d_%v", chat.ID, chat.CreatedAt.Unix())
+ if err != nil {
+ logger.Warn("failed to load history chat", "error", err)
+ activeChatName = newChat.Name
+ chatMap[newChat.Name] = newChat
+ return defaultStarter
+ }
+ history, err := chat.ToHistory()
+ if err != nil {
+ logger.Warn("failed to load history chat", "error", err)
+ activeChatName = newChat.Name
+ chatMap[newChat.Name] = newChat
+ return defaultStarter
+ }
+ return history
+}
diff --git a/storage/storage.go b/storage/storage.go
index 11cbb4a..43162c8 100644
--- a/storage/storage.go
+++ b/storage/storage.go
@@ -11,6 +11,7 @@ import (
type ChatHistory interface {
ListChats() ([]models.Chat, error)
GetChatByID(id uint32) (*models.Chat, error)
+ GetLastChat() (*models.Chat, error)
UpsertChat(chat *models.Chat) (*models.Chat, error)
RemoveChat(id uint32) error
}
@@ -31,6 +32,12 @@ func (p ProviderSQL) GetChatByID(id uint32) (*models.Chat, error) {
return &resp, err
}
+func (p ProviderSQL) GetLastChat() (*models.Chat, error) {
+ resp := models.Chat{}
+ err := p.db.Get(&resp, "SELECT * FROM chat ORDER BY updated_at DESC LIMIT 1")
+ return &resp, err
+}
+
func (p ProviderSQL) UpsertChat(chat *models.Chat) (*models.Chat, error) {
// Prepare the SQL statement
query := `