diff options
-rw-r--r-- | bot.go | 67 | ||||
-rw-r--r-- | config.example.toml | 9 | ||||
-rw-r--r-- | config/config.go | 37 | ||||
-rw-r--r-- | go.mod | 1 | ||||
-rw-r--r-- | go.sum | 2 | ||||
-rw-r--r-- | tools.go | 6 | ||||
-rw-r--r-- | tui.go | 24 |
7 files changed, 95 insertions, 51 deletions
@@ -3,6 +3,7 @@ package main import ( "bufio" "bytes" + "elefant/config" "elefant/models" "elefant/storage" "encoding/json" @@ -22,27 +23,16 @@ var httpClient = http.Client{ } var ( - logger *slog.Logger - userRole = "user" - assistantRole = "assistant" - toolRole = "tool" - assistantIcon = "<🤖>: " - userIcon = "<user>: " - // TODO: pass as an cli arg or have config - APIURL = "http://localhost:8080/v1/chat/completions" - logFileName = "log.txt" - showSystemMsgs = true - chunkLimit = 1000 - activeChatName string - chunkChan = make(chan string, 10) - streamDone = make(chan bool, 1) - chatBody *models.ChatBody - store storage.FullRepo - defaultFirstMsg = "Hello! What can I do for you?" - defaultStarter = []models.MessagesStory{ - {Role: "system", Content: systemMsg}, - {Role: assistantRole, Content: defaultFirstMsg}, - } + cfg *config.Config + logger *slog.Logger + chunkLimit = 1000 + activeChatName string + chunkChan = make(chan string, 10) + streamDone = make(chan bool, 1) + chatBody *models.ChatBody + store storage.FullRepo + defaultFirstMsg = "Hello! What can I do for you?" + defaultStarter = []models.MessagesStory{} defaultStarterBytes = []byte{} interruptResp = false ) @@ -64,14 +54,14 @@ func formMsg(chatBody *models.ChatBody, newMsg, role string) io.Reader { // func sendMsgToLLM(body io.Reader) (*models.LLMRespChunk, error) { func sendMsgToLLM(body io.Reader) (any, error) { - resp, err := httpClient.Post(APIURL, "application/json", body) + resp, err := httpClient.Post(cfg.APIURL, "application/json", body) if err != nil { logger.Error("llamacpp api", "error", err) return nil, err } defer resp.Body.Close() llmResp := []models.LLMRespChunk{} - // chunkChan <- assistantIcon + // chunkChan <- cfg.AssistantIcon reader := bufio.NewReader(resp.Body) counter := 0 for { @@ -128,7 +118,7 @@ func chatRound(userMsg, role string, tv *tview.TextView) { go sendMsgToLLM(reader) if userMsg != "" { // no need to write assistant icon since we continue old message fmt.Fprintf(tv, fmt.Sprintf("(%d) ", len(chatBody.Messages))) - fmt.Fprintf(tv, assistantIcon) + fmt.Fprintf(tv, cfg.AssistantIcon) } respText := strings.Builder{} out: @@ -145,7 +135,7 @@ out: } botRespMode = false chatBody.Messages = append(chatBody.Messages, models.MessagesStory{ - Role: assistantRole, Content: respText.String(), + Role: cfg.AssistantRole, Content: respText.String(), }) // bot msg is done; // now check it for func call @@ -174,18 +164,18 @@ func findCall(msg string, tv *tview.TextView) { f, ok := fnMap[fc.Name] if !ok { m := fmt.Sprintf("%s is not implemented", fc.Name) - chatRound(m, toolRole, tv) + chatRound(m, cfg.ToolRole, tv) return } resp := f(fc.Args...) toolMsg := fmt.Sprintf("tool response: %+v", string(resp)) - chatRound(toolMsg, toolRole, tv) + chatRound(toolMsg, cfg.ToolRole, tv) } func chatToTextSlice(showSys bool) []string { resp := make([]string, len(chatBody.Messages)) for i, msg := range chatBody.Messages { - if !showSys && (msg.Role != assistantRole && msg.Role != userRole) { + if !showSys && (msg.Role != cfg.AssistantRole && msg.Role != cfg.UserRole) { continue } resp[i] = msg.ToText(i) @@ -201,14 +191,14 @@ func chatToText(showSys bool) string { func textToMsg(rawMsg string) models.MessagesStory { msg := models.MessagesStory{} // system and tool? - if strings.HasPrefix(rawMsg, assistantIcon) { - msg.Role = assistantRole - msg.Content = strings.TrimPrefix(rawMsg, assistantIcon) + if strings.HasPrefix(rawMsg, cfg.AssistantIcon) { + msg.Role = cfg.AssistantRole + msg.Content = strings.TrimPrefix(rawMsg, cfg.AssistantIcon) return msg } - if strings.HasPrefix(rawMsg, userIcon) { - msg.Role = userRole - msg.Content = strings.TrimPrefix(rawMsg, userIcon) + if strings.HasPrefix(rawMsg, cfg.UserIcon) { + msg.Role = cfg.UserRole + msg.Content = strings.TrimPrefix(rawMsg, cfg.UserIcon) return msg } return msg @@ -224,9 +214,14 @@ func textSliceToChat(chat []string) []models.MessagesStory { } func init() { - file, err := os.OpenFile(logFileName, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + cfg = config.LoadConfigOrDefault("config.example.toml") + defaultStarter = []models.MessagesStory{ + {Role: "system", Content: systemMsg}, + {Role: cfg.AssistantRole, Content: defaultFirstMsg}, + } + file, err := os.OpenFile(cfg.LogFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { - logger.Error("failed to open log file", "error", err, "filename", logFileName) + logger.Error("failed to open log file", "error", err, "filename", cfg.LogFile) return } defaultStarterBytes, err = json.Marshal(defaultStarter) diff --git a/config.example.toml b/config.example.toml new file mode 100644 index 0000000..d1388e5 --- /dev/null +++ b/config.example.toml @@ -0,0 +1,9 @@ +APIURL = "http://localhost:8080/v1/chat/completions" +ShowSys = true +LogFile = "log.txt" +UserRole = "user" +ToolRole = "tool" +AssistantRole = "assistant" +AssistantIcon = "<🤖>: " +UserIcon = "<user>: " +ToolIcon = "<ï‚>>: " diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..27f8c66 --- /dev/null +++ b/config/config.go @@ -0,0 +1,37 @@ +package config + +import ( + "fmt" + + "github.com/BurntSushi/toml" +) + +type Config struct { + APIURL string `toml:"APIURL"` + ShowSys bool `toml:"ShowSys"` + LogFile string `toml:"LogFile"` + UserRole string `toml:"UserRole"` + ToolRole string `toml:"ToolRole"` + AssistantRole string `toml:"AssistantRole"` + AssistantIcon string `toml:"AssistantIcon"` + UserIcon string `toml:"UserIcon"` + ToolIcon string `toml:"ToolIcon"` +} + +func LoadConfigOrDefault(fn string) *Config { + if fn == "" { + fn = "config.toml" + } + config := &Config{} + _, err := toml.DecodeFile(fn, &config) + if err != nil { + fmt.Println("failed to read config from file, loading default") + config.APIURL = "http://localhost:8080/v1/chat/completions" + config.ShowSys = true + config.LogFile = "log.txt" + config.UserRole = "user" + config.ToolRole = "tool" + config.AssistantRole = "assistant" + } + return config +} @@ -3,6 +3,7 @@ module elefant go 1.23.2 require ( + github.com/BurntSushi/toml v1.4.0 github.com/gdamore/tcell/v2 v2.7.4 github.com/glebarez/go-sqlite v1.22.0 github.com/jmoiron/sqlx v1.4.0 @@ -1,5 +1,7 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/BurntSushi/toml v1.4.0 h1:kuoIxZQy2WRRk1pttg9asf+WVv6tWQuBNVmK8+nqPr0= +github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/gdamore/encoding v1.0.0 h1:+7OoQ1Bc6eTm5niUzBa0Ctsh6JbMW6Ra+YNuAtDBdko= @@ -61,7 +61,7 @@ also: - others do; */ func memorise(args ...string) []byte { - agent := assistantRole + agent := cfg.AssistantRole if len(args) < 2 { msg := "not enough args to call memorise tool; need topic and data to remember" logger.Error(msg) @@ -79,7 +79,7 @@ func memorise(args ...string) []byte { } func recall(args ...string) []byte { - agent := assistantRole + agent := cfg.AssistantRole if len(args) < 1 { logger.Warn("not enough args to call recall tool") return nil @@ -94,7 +94,7 @@ func recall(args ...string) []byte { } func recallTopics(args ...string) []byte { - agent := assistantRole + agent := cfg.AssistantRole topics, err := store.RecallTopics(agent) if err != nil { logger.Error("failed to use tool", "error", err, "args", args) @@ -82,7 +82,7 @@ func init() { } // set chat body chatBody.Messages = defaultStarter - textView.SetText(chatToText(showSystemMsgs)) + textView.SetText(chatToText(cfg.ShowSys)) newChat := &models.Chat{ ID: id + 1, Name: fmt.Sprintf("%v_%v", "new", time.Now().Unix()), @@ -111,7 +111,7 @@ func init() { return } chatBody.Messages = history - textView.SetText(chatToText(showSystemMsgs)) + textView.SetText(chatToText(cfg.ShowSys)) activeChatName = fn pages.RemovePage("history") return @@ -134,7 +134,7 @@ func init() { } chatBody.Messages[0].Content = sysMsg // replace textview - textView.SetText(chatToText(showSystemMsgs)) + textView.SetText(chatToText(cfg.ShowSys)) pages.RemovePage("sys") } }) @@ -152,7 +152,7 @@ func init() { } chatBody.Messages[selectedIndex].Content = editedMsg // change textarea - textView.SetText(chatToText(showSystemMsgs)) + textView.SetText(chatToText(cfg.ShowSys)) pages.RemovePage("editArea") editMode = false return nil @@ -233,7 +233,7 @@ func init() { // textArea.SetMovedFunc(updateStatusLine) updateStatusLine() - textView.SetText(chatToText(showSystemMsgs)) + textView.SetText(chatToText(cfg.ShowSys)) textView.ScrollToEnd() app.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { if event.Key() == tcell.KeyF1 { @@ -251,14 +251,14 @@ func init() { if event.Key() == tcell.KeyF2 { // regen last msg chatBody.Messages = chatBody.Messages[:len(chatBody.Messages)-1] - textView.SetText(chatToText(showSystemMsgs)) - go chatRound("", userRole, textView) + textView.SetText(chatToText(cfg.ShowSys)) + go chatRound("", cfg.UserRole, textView) return nil } if event.Key() == tcell.KeyF3 && !botRespMode { // delete last msg chatBody.Messages = chatBody.Messages[:len(chatBody.Messages)-1] - textView.SetText(chatToText(showSystemMsgs)) + textView.SetText(chatToText(cfg.ShowSys)) return nil } if event.Key() == tcell.KeyF4 { @@ -268,9 +268,9 @@ func init() { return nil } if event.Key() == tcell.KeyF5 { - // switch showSystemMsgs - showSystemMsgs = !showSystemMsgs - textView.SetText(chatToText(showSystemMsgs)) + // switch cfg.ShowSys + cfg.ShowSys = !cfg.ShowSys + textView.SetText(chatToText(cfg.ShowSys)) } if event.Key() == tcell.KeyF6 { interruptResp = true @@ -317,7 +317,7 @@ func init() { textView.ScrollToEnd() } // update statue line - go chatRound(msgText, userRole, textView) + go chatRound(msgText, cfg.UserRole, textView) return nil } if event.Key() == tcell.KeyPgUp || event.Key() == tcell.KeyPgDn { |