diff options
author | Grail Finder <wohilas@gmail.com> | 2024-11-20 06:51:40 +0300 |
---|---|---|
committer | Grail Finder <wohilas@gmail.com> | 2024-11-20 06:51:40 +0300 |
commit | aaf056663628f15bb6e4f23c899b6fd31bac5bf7 (patch) | |
tree | a0129fbb531e31d1a9d6465549a44b823906aaae | |
parent | f32375488f5127c910021f627d83e017c5c7a10f (diff) |
Enha: db chat management
-rw-r--r-- | README.md | 2 | ||||
-rw-r--r-- | bot.go | 93 | ||||
-rw-r--r-- | main.go | 16 | ||||
-rw-r--r-- | models/db.go | 13 | ||||
-rw-r--r-- | session.go | 92 | ||||
-rw-r--r-- | storage/storage.go | 7 |
6 files changed, 127 insertions, 96 deletions
@@ -13,7 +13,7 @@ - tab to switch selection between textview and textarea (input and chat); + - basic tools: memorize and recall; - stop stream from the bot; + -- sqlitedb instead of chatfiles; +- sqlitedb instead of chatfiles; + - sqlite for the bot memory; ### FIX: @@ -11,7 +11,6 @@ import ( "log/slog" "net/http" "os" - "path" "strings" "time" @@ -34,7 +33,7 @@ var ( historyDir = "./history/" // TODO: pass as an cli arg showSystemMsgs bool - chatFileLoaded string + activeChatName string chunkChan = make(chan string, 10) streamDone = make(chan bool, 1) chatBody *models.ChatBody @@ -89,14 +88,12 @@ var fnMap = map[string]fnSig{ // ==== func getUserInput(userPrompt string) string { - // fmt.Printf("<🤖>: %s\n<user>:", botMsg) fmt.Printf(userPrompt) reader := bufio.NewReader(os.Stdin) line, err := reader.ReadString('\n') if err != nil { panic(err) // think about it } - // fmt.Printf("read line: %s-\n", line) return line } @@ -152,7 +149,7 @@ func sendMsgToLLM(body io.Reader) (any, error) { return nil, err } llmResp = append(llmResp, llmchunk) - logger.Info("streamview", "chunk", llmchunk) + // logger.Info("streamview", "chunk", llmchunk) // if llmchunk.Choices[len(llmchunk.Choices)-1].FinishReason != "chat.completion.chunk" { if llmchunk.Choices[len(llmchunk.Choices)-1].FinishReason == "stop" { streamDone <- true @@ -192,18 +189,12 @@ out: }) // bot msg is done; // now check it for func call - logChat(chatFileLoaded, chatBody.Messages) - findCall(respText.String(), tv) -} - -func logChat(fname string, msgs []models.MessagesStory) { - data, err := json.MarshalIndent(msgs, "", " ") + // logChat(activeChatName, chatBody.Messages) + err := updateStorageChat(activeChatName, chatBody.Messages) if err != nil { - logger.Error("failed to marshal", "error", err) - } - if err := os.WriteFile(fname, data, 0666); err != nil { - logger.Error("failed to write log", "error", err) + logger.Warn("failed to update storage", "error", err, "name", activeChatName) } + findCall(respText.String(), tv) } func findCall(msg string, tv *tview.TextView) { @@ -235,74 +226,6 @@ func findCall(msg string, tv *tview.TextView) { // return func result to the llm } -func listHistoryFiles(dir string) ([]string, error) { - files, err := os.ReadDir(dir) - if err != nil { - logger.Error("failed to readdir", "error", err) - return nil, err - } - resp := make([]string, len(files)) - for i, f := range files { - resp[i] = path.Join(dir, f.Name()) - } - return resp, nil -} - -func findLatestChat(dir string) string { - files, err := listHistoryFiles(dir) - if err != nil { - panic(err) - } - var ( - latestF string - newestTime int64 - ) - logger.Info("filelist", "list", files) - for _, fn := range files { - fi, err := os.Stat(fn) - if err != nil { - logger.Error("failed to get stat", "error", err, "name", fn) - panic(err) - } - currTime := fi.ModTime().Unix() - if currTime > newestTime { - newestTime = currTime - latestF = fn - } - } - return latestF -} - -func readHistoryChat(fn string) ([]models.MessagesStory, error) { - content, err := os.ReadFile(fn) - if err != nil { - logger.Error("failed to read file", "error", err, "name", fn) - return nil, err - } - resp := []models.MessagesStory{} - if err := json.Unmarshal(content, &resp); err != nil { - logger.Error("failed to unmarshal", "error", err, "name", fn) - return nil, err - } - chatFileLoaded = fn - return resp, nil -} - -func loadOldChatOrGetNew(fns ...string) []models.MessagesStory { - // find last chat - fn := findLatestChat(historyDir) - if len(fns) > 0 { - fn = fns[0] - } - logger.Info("reading history from file", "filename", fn) - history, err := readHistoryChat(fn) - if err != nil { - logger.Warn("faield to load history chat", "error", err) - return defaultStarter - } - return history -} - func chatToTextSlice(showSys bool) []string { resp := make([]string, len(chatBody.Messages)) for i, msg := range chatBody.Messages { @@ -353,10 +276,13 @@ func init() { if err := os.MkdirAll(historyDir, os.ModePerm); err != nil { panic(err) } + store = storage.NewProviderSQL("test.db") // defer file.Close() logger = slog.New(slog.NewTextHandler(file, nil)) logger.Info("test msg") // https://github.com/coreydaley/ggerganov-llama.cpp/blob/master/examples/server/README.md + // load all chats in memory + loadHistoryChats() lastChat := loadOldChatOrGetNew() logger.Info("loaded history", "chat", lastChat) chatBody = &models.ChatBody{ @@ -364,5 +290,4 @@ func init() { Stream: true, Messages: lastChat, } - store = storage.NewProviderSQL("test.db") } @@ -60,7 +60,7 @@ func main() { } } chatOpts := []string{"cancel", "new"} - fList, err := listHistoryFiles(historyDir) + fList, err := loadHistoryChats() if err != nil { panic(err) } @@ -74,18 +74,16 @@ func main() { // set chat body chatBody.Messages = defaultStarter textView.SetText(chatToText(showSystemMsgs)) - chatFileLoaded = path.Join(historyDir, fmt.Sprintf("%d_chat.json", time.Now().Unix())) + activeChatName = path.Join(historyDir, fmt.Sprintf("%d_chat.json", time.Now().Unix())) pages.RemovePage("history") return // set text case "cancel": pages.RemovePage("history") - // pages.ShowPage("main") return default: - // fn := path.Join(historyDir, buttonLabel) fn := buttonLabel - history, err := readHistoryChat(fn) + history, err := loadHistoryChat(fn) if err != nil { logger.Error("failed to read history file", "filename", fn) pages.RemovePage("history") @@ -93,7 +91,7 @@ func main() { } chatBody.Messages = history textView.SetText(chatToText(showSystemMsgs)) - chatFileLoaded = fn + activeChatName = fn pages.RemovePage("history") return } @@ -104,14 +102,11 @@ func main() { editArea.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { if event.Key() == tcell.KeyEscape && editMode { editedMsg := editArea.GetText() - // TODO: trim msg number and icon chatBody.Messages[selectedIndex].Content = editedMsg // change textarea textView.SetText(chatToText(showSystemMsgs)) pages.RemovePage("editArea") editMode = false - // panic("do we get here?") - // pages.ShowPage("main") return nil } return event @@ -151,7 +146,8 @@ func main() { textView.ScrollToEnd() app.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { if event.Key() == tcell.KeyF1 { - fList, err := listHistoryFiles(historyDir) + // fList, err := listHistoryFiles(historyDir) + fList, err := loadHistoryChats() if err != nil { panic(err) } diff --git a/models/db.go b/models/db.go index 24bef41..afd4b46 100644 --- a/models/db.go +++ b/models/db.go @@ -1,6 +1,9 @@ package models -import "time" +import ( + "encoding/json" + "time" +) type Chat struct { ID uint32 `db:"id" json:"id"` @@ -9,3 +12,11 @@ type Chat struct { CreatedAt time.Time `db:"created_at" json:"created_at"` UpdatedAt time.Time `db:"updated_at" json:"updated_at"` } + +func (c Chat) ToHistory() ([]MessagesStory, error) { + resp := []MessagesStory{} + if err := json.Unmarshal([]byte(c.Msgs), &resp); err != nil { + return nil, err + } + return resp, nil +} diff --git a/session.go b/session.go new file mode 100644 index 0000000..769fe90 --- /dev/null +++ b/session.go @@ -0,0 +1,92 @@ +package main + +import ( + "elefant/models" + "encoding/json" + "fmt" + "time" +) + +var ( + chatMap = make(map[string]*models.Chat) +) + +func historyToSJSON(msgs []models.MessagesStory) (string, error) { + data, err := json.Marshal(msgs) + if err != nil { + return "", err + } + if data == nil { + return "", fmt.Errorf("nil data") + } + return string(data), nil +} + +func updateStorageChat(name string, msgs []models.MessagesStory) error { + var err error + chat, ok := chatMap[name] + if !ok { + err = fmt.Errorf("failed to find active chat; map:%v; key:%s", chatMap, name) + logger.Error("failed to find active chat", "map", chatMap, "key", name) + return err + } + chat.Msgs, err = historyToSJSON(msgs) + if err != nil { + return err + } + chat.UpdatedAt = time.Now() + _, err = store.UpsertChat(chat) + return err +} + +func loadHistoryChats() ([]string, error) { + chats, err := store.ListChats() + if err != nil { + return nil, err + } + resp := []string{} + for _, chat := range chats { + if chat.Name == "" { + chat.Name = fmt.Sprintf("%d_%v", chat.ID, chat.CreatedAt.Unix()) + } + resp = append(resp, chat.Name) + chatMap[chat.Name] = &chat + } + return resp, nil +} + +func loadHistoryChat(chatName string) ([]models.MessagesStory, error) { + chat, ok := chatMap[chatName] + if !ok { + err := fmt.Errorf("failed to read chat") + logger.Error("failed to read chat", "name", chatName) + return nil, err + } + activeChatName = chatName + return chat.ToHistory() +} + +func loadOldChatOrGetNew() []models.MessagesStory { + // find last chat + chat, err := store.GetLastChat() + newChat := &models.Chat{ + ID: 0, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + newChat.Name = fmt.Sprintf("%d_%v", chat.ID, chat.CreatedAt.Unix()) + if err != nil { + logger.Warn("failed to load history chat", "error", err) + activeChatName = newChat.Name + chatMap[newChat.Name] = newChat + return defaultStarter + } + history, err := chat.ToHistory() + if err != nil { + logger.Warn("failed to load history chat", "error", err) + activeChatName = newChat.Name + chatMap[newChat.Name] = newChat + return defaultStarter + } + return history +} diff --git a/storage/storage.go b/storage/storage.go index 11cbb4a..43162c8 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -11,6 +11,7 @@ import ( type ChatHistory interface { ListChats() ([]models.Chat, error) GetChatByID(id uint32) (*models.Chat, error) + GetLastChat() (*models.Chat, error) UpsertChat(chat *models.Chat) (*models.Chat, error) RemoveChat(id uint32) error } @@ -31,6 +32,12 @@ func (p ProviderSQL) GetChatByID(id uint32) (*models.Chat, error) { return &resp, err } +func (p ProviderSQL) GetLastChat() (*models.Chat, error) { + resp := models.Chat{} + err := p.db.Get(&resp, "SELECT * FROM chat ORDER BY updated_at DESC LIMIT 1") + return &resp, err +} + func (p ProviderSQL) UpsertChat(chat *models.Chat) (*models.Chat, error) { // Prepare the SQL statement query := ` |