From 3a11210f52a850f84771e1642cafcc3027b85075 Mon Sep 17 00:00:00 2001
From: Grail Finder
Date: Sat, 31 Jan 2026 12:57:53 +0300
Subject: Enha: avoid recursion in llm calls

---
 bot.go           | 145 ++++++++++++++++++++++++++++++++++++-------------------
 helpfuncs.go     |  13 +++++
 llm.go           |   6 +++
 models/models.go |  19 +++++---
 tui.go           |   9 ++--
 5 files changed, 132 insertions(+), 60 deletions(-)

diff --git a/bot.go b/bot.go
index 1a2cebb..6e7d094 100644
--- a/bot.go
+++ b/bot.go
@@ -25,17 +25,16 @@ import (
 	"time"
 
 	"github.com/neurosnap/sentences/english"
-	"github.com/rivo/tview"
 )
 
 var (
-	httpClient = &http.Client{}
-	cfg        *config.Config
-	logger     *slog.Logger
-	logLevel   = new(slog.LevelVar)
-)
-var (
+	httpClient     = &http.Client{}
+	cfg            *config.Config
+	logger         *slog.Logger
+	logLevel       = new(slog.LevelVar)
+	ctx, cancel    = context.WithCancel(context.Background())
 	activeChatName string
+	chatRoundChan  = make(chan *models.ChatRoundReq, 1)
 	chunkChan      = make(chan string, 10)
 	openAIToolChan = make(chan string, 10)
 	streamDone     = make(chan bool, 1)
@@ -699,7 +698,23 @@ func roleToIcon(role string) string {
 	return "<" + role + ">: "
 }
 
-func chatRound(userMsg, role string, tv *tview.TextView, regen, resume bool) {
+func chatWatcher(ctx context.Context) {
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case chatRoundReq := <-chatRoundChan:
+			if err := chatRound(chatRoundReq); err != nil {
+				logger.Error("failed to chatRound", "err", err)
+			}
+		}
+	}
+}
+
+func chatRound(r *models.ChatRoundReq) error {
+	// chunkChan := make(chan string, 10)
+	// openAIToolChan := make(chan string, 10)
+	// streamDone := make(chan bool, 1)
 	botRespMode = true
 	botPersona := cfg.AssistantRole
 	if cfg.WriteNextMsgAsCompletionAgent != "" {
@@ -707,32 +722,23 @@ func chatRound(userMsg, role string, tv *tview.TextView, regen, resume bool) {
 	}
 	defer func() { botRespMode = false }()
 	// check that there is a model set to use if is not local
-	if cfg.CurrentAPI == cfg.DeepSeekChatAPI || cfg.CurrentAPI == cfg.DeepSeekCompletionAPI {
-		if chatBody.Model != "deepseek-chat" && chatBody.Model != "deepseek-reasoner" {
-			if err := notifyUser("bad request", "wrong deepseek model name"); err != nil {
-				logger.Warn("failed ot notify user", "error", err)
-				return
-			}
-			return
-		}
-	}
 	choseChunkParser()
-	reader, err := chunkParser.FormMsg(userMsg, role, resume)
+	reader, err := chunkParser.FormMsg(r.UserMsg, r.Role, r.Resume)
 	if reader == nil || err != nil {
-		logger.Error("empty reader from msgs", "role", role, "error", err)
-		return
+		logger.Error("empty reader from msgs", "role", r.Role, "error", err)
+		return err
 	}
 	if cfg.SkipLLMResp {
-		return
+		return nil
 	}
 	go sendMsgToLLM(reader)
-	logger.Debug("looking at vars in chatRound", "msg", userMsg, "regen", regen, "resume", resume)
-	if !resume {
-		fmt.Fprintf(tv, "\n[-:-:b](%d) ", len(chatBody.Messages))
-		fmt.Fprint(tv, roleToIcon(botPersona))
-		fmt.Fprint(tv, "[-:-:-]\n")
+	logger.Debug("looking at vars in chatRound", "msg", r.UserMsg, "regen", r.Regen, "resume", r.Resume)
+	if !r.Resume {
+		fmt.Fprintf(textView, "\n[-:-:b](%d) ", len(chatBody.Messages))
+		fmt.Fprint(textView, roleToIcon(botPersona))
+		fmt.Fprint(textView, "[-:-:-]\n")
 		if cfg.ThinkUse && !strings.Contains(cfg.CurrentAPI, "v1") {
-			// fmt.Fprint(tv, "<think>")
+			// fmt.Fprint(textView, "<think>")
 			chunkChan <- "<think>"
 		}
 	}
@@ -742,29 +748,29 @@ out:
 	for {
 		select {
 		case chunk := <-chunkChan:
-			fmt.Fprint(tv, chunk)
+			fmt.Fprint(textView, chunk)
 			respText.WriteString(chunk)
 			if scrollToEndEnabled {
-				tv.ScrollToEnd()
+				textView.ScrollToEnd()
 			}
 			// Send chunk to audio stream handler
 			if cfg.TTS_ENABLED {
 				TTSTextChan <- chunk
 			}
 		case toolChunk := <-openAIToolChan:
-			fmt.Fprint(tv, toolChunk)
+			fmt.Fprint(textView, toolChunk)
 			toolResp.WriteString(toolChunk)
 			if scrollToEndEnabled {
-				tv.ScrollToEnd()
+				textView.ScrollToEnd()
 			}
 		case <-streamDone:
 			// drain any remaining chunks from chunkChan before exiting
 			for len(chunkChan) > 0 {
 				chunk := <-chunkChan
-				fmt.Fprint(tv, chunk)
+				fmt.Fprint(textView, chunk)
 				respText.WriteString(chunk)
 				if scrollToEndEnabled {
-					tv.ScrollToEnd()
+					textView.ScrollToEnd()
 				}
 				if cfg.TTS_ENABLED {
 					// Send chunk to audio stream handler
@@ -780,7 +786,7 @@ out:
 	}
 	botRespMode = false
 	// numbers in chatbody and displayed must be the same
-	if resume {
+	if r.Resume {
 		chatBody.Messages[len(chatBody.Messages)-1].Content += respText.String()
 		// lastM.Content = lastM.Content + respText.String()
 		// Process the updated message to check for known_to tags in resumed response
@@ -797,7 +803,9 @@ out:
 	}
 	logger.Debug("chatRound: before cleanChatBody", "messages_before_clean", len(chatBody.Messages))
 	for i, msg := range chatBody.Messages {
-		logger.Debug("chatRound: before cleaning", "index", i, "role", msg.Role, "content_len", len(msg.Content), "has_content", msg.HasContent(), "tool_call_id", msg.ToolCallID)
+		logger.Debug("chatRound: before cleaning", "index", i,
+			"role", msg.Role, "content_len", len(msg.Content),
+			"has_content", msg.HasContent(), "tool_call_id", msg.ToolCallID)
 	}
 	// // Clean null/empty messages to prevent API issues with endpoints like llama.cpp jinja template
 	cleanChatBody()
@@ -813,8 +821,9 @@ out:
 	if err := updateStorageChat(activeChatName, chatBody.Messages); err != nil {
 		logger.Warn("failed to update storage", "error", err, "name", activeChatName)
 	}
-	// FIXME: recursive calls
-	findCall(respText.String(), toolResp.String(), tv)
+	if findCall(respText.String(), toolResp.String()) {
+		return nil
+	}
 	// TODO: have a config attr
 	// Check if this message was sent privately to specific characters
 	// If so, trigger those characters to respond if that char is not controlled by user
@@ -822,9 +831,10 @@ out:
 	if cfg.AutoTurn {
 		lastMsg := chatBody.Messages[len(chatBody.Messages)-1]
 		if len(lastMsg.KnownTo) > 0 {
-			triggerPrivateMessageResponses(lastMsg, tv)
+			triggerPrivateMessageResponses(lastMsg)
 		}
 	}
+	return nil
 }
 
 // cleanChatBody removes messages with null or empty content to prevent API issues
@@ -909,7 +919,8 @@ func unmarshalFuncCall(jsonStr string) (*models.FuncCall, error) {
 	return fc, nil
 }
 
-func findCall(msg, toolCall string, tv *tview.TextView) {
+// findCall adds a ChatRoundReq to chatRoundChan and returns true if it did so
+func findCall(msg, toolCall string) bool {
 	fc := &models.FuncCall{}
 	if toolCall != "" {
 		// HTML-decode the tool call string to handle encoded characters like &lt; -> <
@@ -927,8 +938,13 @@ func findCall(msg, toolCall string, tv *tview.TextView) {
 			chatBody.Messages = append(chatBody.Messages, toolResponseMsg)
 			// Clear the stored tool call ID after using it (no longer needed)
 			// Trigger the assistant to continue processing with the error message
-			chatRound("", cfg.AssistantRole, tv, false, false)
-			return
+			crr := &models.ChatRoundReq{
+				Role: cfg.AssistantRole,
+			}
+			// provoke next llm msg after failed tool call
+			chatRoundChan <- crr
+			// chatRound("", cfg.AssistantRole, tv, false, false)
+			return true
 		}
 		lastToolCall.Args = openAIToolMap
 		fc = lastToolCall
@@ -940,8 +956,8 @@ func findCall(msg, toolCall string, tv *tview.TextView) {
 		}
 	} else {
 		jsStr := toolCallRE.FindString(msg)
-		if jsStr == "" {
-			return
+		if jsStr == "" { // no tool call case
+			return false
 		}
 		prefix := "__tool_call__\n"
 		suffix := "\n__tool_call__"
@@ -960,8 +976,13 @@ func findCall(msg, toolCall string, tv *tview.TextView) {
 			chatBody.Messages = append(chatBody.Messages, toolResponseMsg)
 			logger.Debug("findCall: added tool error response", "role", toolResponseMsg.Role, "content_len", len(toolResponseMsg.Content), "message_count_after_add", len(chatBody.Messages))
 			// Trigger the assistant to continue processing with the error message
-			chatRound("", cfg.AssistantRole, tv, false, false)
-			return
+			// chatRound("", cfg.AssistantRole, tv, false, false)
+			crr := &models.ChatRoundReq{
+				Role: cfg.AssistantRole,
+			}
+			// provoke next llm msg after failed tool call
+			chatRoundChan <- crr
+			return true
 		}
 		// Update lastToolCall with parsed function call
 		lastToolCall.ID = fc.ID
@@ -994,13 +1015,17 @@ func findCall(msg, toolCall string, tv *tview.TextView) {
 		lastToolCall.ID = ""
 		// Trigger the assistant to continue processing with the new tool response
 		// by calling chatRound with empty content to continue the assistant's response
-		chatRound("", cfg.AssistantRole, tv, false, false)
-		return
+		crr := &models.ChatRoundReq{
+			Role: cfg.AssistantRole,
+		}
+		// failed to find tool
+		chatRoundChan <- crr
+		return true
 	}
 	resp := callToolWithAgent(fc.Name, fc.Args)
 	toolMsg := string(resp) // Remove the "tool response: " prefix and %+v formatting
 	logger.Info("llm used tool call", "tool_resp", toolMsg, "tool_attrs", fc)
-	fmt.Fprintf(tv, "%s[-:-:b](%d) <%s>: [-:-:-]\n%s\n",
+	fmt.Fprintf(textView, "%s[-:-:b](%d) <%s>: [-:-:-]\n%s\n",
 		"\n\n", len(chatBody.Messages), cfg.ToolRole, toolMsg)
 	// Create tool response message with the proper tool_call_id
 	toolResponseMsg := models.RoleMsg{
@@ -1014,7 +1039,11 @@ func findCall(msg, toolCall string, tv *tview.TextView) {
 	lastToolCall.ID = ""
 	// Trigger the assistant to continue processing with the new tool response
 	// by calling chatRound with empty content to continue the assistant's response
-	chatRound("", cfg.AssistantRole, tv, false, false)
+	crr := &models.ChatRoundReq{
+		Role: cfg.AssistantRole,
+	}
+	chatRoundChan <- crr
+	return true
 }
 
 func chatToTextSlice(messages []models.RoleMsg, showSys bool) []string {
@@ -1163,10 +1192,12 @@ func summarizeAndStartNewChat() {
 }
 
 func init() {
+	// ctx, cancel := context.WithCancel(context.Background())
 	var err error
 	cfg, err = config.LoadConfig("config.toml")
 	if err != nil {
 		fmt.Println("failed to load config.toml")
+		cancel()
 		os.Exit(1)
 		return
 	}
@@ -1178,6 +1209,8 @@ func init() {
 		os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
 	if err != nil {
 		slog.Error("failed to open log file", "error", err, "filename", cfg.LogFile)
+		cancel()
+		os.Exit(1)
 		return
 	}
 	// load cards
@@ -1188,13 +1221,17 @@ func init() {
 	logger = slog.New(slog.NewTextHandler(logfile, &slog.HandlerOptions{Level: logLevel}))
 	store = storage.NewProviderSQL(cfg.DBPATH, logger)
 	if store == nil {
+		cancel()
 		os.Exit(1)
+		return
 	}
 	ragger = rag.New(logger, store, cfg)
 	// https://github.com/coreydaley/ggerganov-llama.cpp/blob/master/examples/server/README.md
 	// load all chats in memory
 	if _, err := loadHistoryChats(); err != nil {
 		logger.Error("failed to load chat", "error", err)
+		cancel()
+		os.Exit(1)
 		return
 	}
 	lastToolCall = &models.FuncCall{}
@@ -1215,11 +1252,12 @@ func init() {
 	// Initialize scrollToEndEnabled based on config
 	scrollToEndEnabled = cfg.AutoScrollEnabled
 	go updateModelLists()
+	go chatWatcher(ctx)
 }
 
 // triggerPrivateMessageResponses checks if a message was sent privately to specific characters
 // and triggers those non-user characters to respond
-func triggerPrivateMessageResponses(msg models.RoleMsg, tv *tview.TextView) {
+func triggerPrivateMessageResponses(msg models.RoleMsg) {
 	if cfg == nil || !cfg.CharSpecificContextEnabled {
 		return
 	}
@@ -1237,6 +1275,11 @@ func triggerPrivateMessageResponses(msg models.RoleMsg, tv *tview.TextView) {
 		// that indicates it's their turn
 		triggerMsg := recipient + ":\n"
 		// Call chatRound with the trigger message to make the recipient respond
-		chatRound(triggerMsg, recipient, tv, false, false)
+		// chatRound(triggerMsg, recipient, tv, false, false)
+		crr := &models.ChatRoundReq{
+			UserMsg: triggerMsg,
+			Role:    recipient,
+		}
+		chatRoundChan <- crr
 	}
 }
diff --git a/helpfuncs.go b/helpfuncs.go
index 28e7962..849b0a0 100644
--- a/helpfuncs.go
+++ b/helpfuncs.go
@@ -279,3 +279,16 @@ func listChatRoles() []string {
 	charset = append(charset, cbc...)
 	return charset
 }
+
+func deepseekModelValidator() error {
+	if cfg.CurrentAPI == cfg.DeepSeekChatAPI || cfg.CurrentAPI == cfg.DeepSeekCompletionAPI {
+		if chatBody.Model != "deepseek-chat" && chatBody.Model != "deepseek-reasoner" {
+			if err := notifyUser("bad request", "wrong deepseek model name"); err != nil {
+				logger.Warn("failed to notify user", "error", err)
+			}
+			// a wrong model name must stop the request even when the notification succeeds
+			return errors.New("wrong deepseek model name")
+		}
+	}
+	return nil
+}
diff --git a/llm.go b/llm.go
index cd5a3fe..5bd7554 100644
--- a/llm.go
+++ b/llm.go
@@ -363,6 +363,9 @@ func (ds DeepSeekerCompletion) GetToken() string {
 
 func (ds DeepSeekerCompletion) FormMsg(msg, role string, resume bool) (io.Reader, error) {
 	logger.Debug("formmsg deepseekercompletion", "link", cfg.CurrentAPI)
+	if err := deepseekModelValidator(); err != nil {
+		return nil, err
+	}
 	if msg != "" { // otherwise let the bot to continue
 		newMsg := models.RoleMsg{Role: role, Content: msg}
 		newMsg = processMessageTag(newMsg)
@@ -445,6 +448,9 @@ func (ds DeepSeekerChat) GetToken() string {
 
 func (ds DeepSeekerChat) FormMsg(msg, role string, resume bool) (io.Reader, error) {
 	logger.Debug("formmsg deepseekerchat", "link", cfg.CurrentAPI)
+	if err := deepseekModelValidator(); err != nil {
+		return nil, err
+	}
 	if msg != "" { // otherwise let the bot continue
 		newMsg := models.RoleMsg{Role: role, Content: msg}
 		newMsg = processMessageTag(newMsg)
diff --git a/models/models.go b/models/models.go
index 69bdf02..76ef183 100644
--- a/models/models.go
+++ b/models/models.go
@@ -116,9 +116,9 @@ func (m RoleMsg) MarshalJSON() ([]byte, error) {
 	} else {
 		// Use simple content format
 		aux := struct {
-			Role       string `json:"role"`
-			Content    string `json:"content"`
-			ToolCallID string `json:"tool_call_id,omitempty"`
+			Role       string   `json:"role"`
+			Content    string   `json:"content"`
+			ToolCallID string   `json:"tool_call_id,omitempty"`
 			KnownTo    []string `json:"known_to,omitempty"`
 		}{
 			Role:    m.Role,
@@ -150,9 +150,9 @@ func (m *RoleMsg) UnmarshalJSON(data []byte) error {
 
 	// Otherwise, unmarshal as simple content format
 	var simple struct {
-		Role       string `json:"role"`
-		Content    string `json:"content"`
-		ToolCallID string `json:"tool_call_id,omitempty"`
+		Role       string   `json:"role"`
+		Content    string   `json:"content"`
+		ToolCallID string   `json:"tool_call_id,omitempty"`
 		KnownTo    []string `json:"known_to,omitempty"`
 	}
 	if err := json.Unmarshal(data, &simple); err != nil {
@@ -540,3 +540,10 @@ func (lcp *LCPModels) ListModels() []string {
 	}
 	return resp
 }
+
+type ChatRoundReq struct {
+	UserMsg string
+	Role    string
+	Regen   bool
+	Resume  bool
+}
diff --git a/tui.go b/tui.go
index d222d15..e164423 100644
--- a/tui.go
+++ b/tui.go
@@ -873,7 +873,8 @@ func init() {
 			// there is no case where user msg is regenerated
 			// lastRole := chatBody.Messages[len(chatBody.Messages)-1].Role
 			textView.SetText(chatToText(chatBody.Messages, cfg.ShowSys))
-			go chatRound("", cfg.UserRole, textView, true, false)
+			// go chatRound("", cfg.UserRole, textView, true, false)
+			chatRoundChan <- &models.ChatRoundReq{Role: cfg.UserRole, Regen: true}
 			return nil
 		}
 		if event.Key() == tcell.KeyF3 && !botRespMode {
@@ -1176,7 +1177,8 @@ func init() {
 			// INFO: continue bot/text message
 			// without new role
 			lastRole := chatBody.Messages[len(chatBody.Messages)-1].Role
-			go chatRound("", lastRole, textView, false, true)
+			// go chatRound("", lastRole, textView, false, true)
+			chatRoundChan <- &models.ChatRoundReq{Role: lastRole, Resume: true}
 			return nil
 		}
 		if event.Key() == tcell.KeyCtrlQ {
@@ -1347,7 +1349,8 @@ func init() {
 			}
 			colorText()
 		}
-		go chatRound(msgText, persona, textView, false, false)
+		// go chatRound(msgText, persona, textView, false, false)
+		chatRoundChan <- &models.ChatRoundReq{Role: persona, UserMsg: msgText}
 		// Also clear any image attachment after sending the message
 		go func() {
 			// Wait a short moment for the message to be processed, then clear the image attachment
-- 
cgit v1.2.3
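
For readers unfamiliar with the pattern this commit adopts: instead of chatRound calling itself after a tool call (unbounded recursion), follow-up requests are queued on chatRoundChan and drained by the single chatWatcher goroutine started in init(). Below is a minimal, self-contained sketch of that queue-and-watcher pattern; the type and channel names mirror the patch, but the handler body, the demo roles, and the sleep-based shutdown are illustrative stand-ins, not code from this repository.

    // Sketch of the non-recursive chat loop: a single watcher goroutine
    // drains a buffered request channel; handlers enqueue follow-up work
    // instead of calling themselves. Names mirror the patch; the handler
    // body is a placeholder.
    package main

    import (
    	"context"
    	"fmt"
    	"time"
    )

    type ChatRoundReq struct {
    	UserMsg string
    	Role    string
    }

    // Capacity 1 lets a handler running inside the watcher enqueue one
    // follow-up request without blocking (the watcher is busy, not reading).
    var chatRoundChan = make(chan *ChatRoundReq, 1)

    // chatRound handles one request. Where the old code recursed after a
    // tool call, it now queues a follow-up and returns.
    func chatRound(r *ChatRoundReq) error {
    	fmt.Printf("round: role=%s msg=%q\n", r.Role, r.UserMsg)
    	if r.Role == "user" { // pretend this reply contained a tool call
    		chatRoundChan <- &ChatRoundReq{Role: "assistant"}
    	}
    	return nil
    }

    // chatWatcher is the only reader of chatRoundChan, so every follow-up
    // starts from a fresh stack frame instead of a nested call.
    func chatWatcher(ctx context.Context) {
    	for {
    		select {
    		case <-ctx.Done():
    			return
    		case r := <-chatRoundChan:
    			if err := chatRound(r); err != nil {
    				fmt.Println("chatRound failed:", err)
    			}
    		}
    	}
    }

    func main() {
    	ctx, cancel := context.WithCancel(context.Background())
    	defer cancel()
    	go chatWatcher(ctx)
    	chatRoundChan <- &ChatRoundReq{UserMsg: "hello", Role: "user"}
    	time.Sleep(100 * time.Millisecond) // crude wait; the real app runs a TUI loop
    }

The buffer size of one matters: when chatRound runs inside the watcher and enqueues a single follow-up, the watcher is not currently receiving, so the one-slot buffer is what lets that send complete without deadlocking — this is the piece that replaces the recursive call.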