summaryrefslogtreecommitdiff
path: root/bot.go
diff options
context:
space:
mode:
authorGrail Finder <wohilas@gmail.com>2026-01-31 12:57:53 +0300
committerGrail Finder <wohilas@gmail.com>2026-01-31 12:57:53 +0300
commit3a11210f52a850f84771e1642cafcc3027b85075 (patch)
treeaa4ec3f49b4ed8221a045fc221b09b26bee2c15d /bot.go
parentfa192a262410eb98b42ff8fb9e0f4e1111240514 (diff)
Enha: avoid recursion in llm calls
Diffstat (limited to 'bot.go')
-rw-r--r--bot.go145
1 files changed, 94 insertions, 51 deletions
diff --git a/bot.go b/bot.go
index 1a2cebb..6e7d094 100644
--- a/bot.go
+++ b/bot.go
@@ -25,17 +25,16 @@ import (
"time"
"github.com/neurosnap/sentences/english"
- "github.com/rivo/tview"
)
var (
- httpClient = &http.Client{}
- cfg *config.Config
- logger *slog.Logger
- logLevel = new(slog.LevelVar)
-)
-var (
+ httpClient = &http.Client{}
+ cfg *config.Config
+ logger *slog.Logger
+ logLevel = new(slog.LevelVar)
+ ctx, cancel = context.WithCancel(context.Background())
activeChatName string
+ chatRoundChan = make(chan *models.ChatRoundReq, 1)
chunkChan = make(chan string, 10)
openAIToolChan = make(chan string, 10)
streamDone = make(chan bool, 1)
@@ -699,7 +698,23 @@ func roleToIcon(role string) string {
return "<" + role + ">: "
}
-func chatRound(userMsg, role string, tv *tview.TextView, regen, resume bool) {
+func chatWatcher(ctx context.Context) {
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case chatRoundReq := <-chatRoundChan:
+ if err := chatRound(chatRoundReq); err != nil {
+ logger.Error("failed to chatRound", "err", err)
+ }
+ }
+ }
+}
+
+func chatRound(r *models.ChatRoundReq) error {
+ // chunkChan := make(chan string, 10)
+ // openAIToolChan := make(chan string, 10)
+ // streamDone := make(chan bool, 1)
botRespMode = true
botPersona := cfg.AssistantRole
if cfg.WriteNextMsgAsCompletionAgent != "" {
@@ -707,32 +722,23 @@ func chatRound(userMsg, role string, tv *tview.TextView, regen, resume bool) {
}
defer func() { botRespMode = false }()
// check that there is a model set to use if is not local
- if cfg.CurrentAPI == cfg.DeepSeekChatAPI || cfg.CurrentAPI == cfg.DeepSeekCompletionAPI {
- if chatBody.Model != "deepseek-chat" && chatBody.Model != "deepseek-reasoner" {
- if err := notifyUser("bad request", "wrong deepseek model name"); err != nil {
- logger.Warn("failed ot notify user", "error", err)
- return
- }
- return
- }
- }
choseChunkParser()
- reader, err := chunkParser.FormMsg(userMsg, role, resume)
+ reader, err := chunkParser.FormMsg(r.UserMsg, r.Role, r.Resume)
if reader == nil || err != nil {
- logger.Error("empty reader from msgs", "role", role, "error", err)
- return
+ logger.Error("empty reader from msgs", "role", r.Role, "error", err)
+ return err
}
if cfg.SkipLLMResp {
- return
+ return nil
}
go sendMsgToLLM(reader)
- logger.Debug("looking at vars in chatRound", "msg", userMsg, "regen", regen, "resume", resume)
- if !resume {
- fmt.Fprintf(tv, "\n[-:-:b](%d) ", len(chatBody.Messages))
- fmt.Fprint(tv, roleToIcon(botPersona))
- fmt.Fprint(tv, "[-:-:-]\n")
+ logger.Debug("looking at vars in chatRound", "msg", r.UserMsg, "regen", r.Regen, "resume", r.Resume)
+ if !r.Resume {
+ fmt.Fprintf(textView, "\n[-:-:b](%d) ", len(chatBody.Messages))
+ fmt.Fprint(textView, roleToIcon(botPersona))
+ fmt.Fprint(textView, "[-:-:-]\n")
if cfg.ThinkUse && !strings.Contains(cfg.CurrentAPI, "v1") {
- // fmt.Fprint(tv, "<think>")
+ // fmt.Fprint(textView, "<think>")
chunkChan <- "<think>"
}
}
@@ -742,29 +748,29 @@ out:
for {
select {
case chunk := <-chunkChan:
- fmt.Fprint(tv, chunk)
+ fmt.Fprint(textView, chunk)
respText.WriteString(chunk)
if scrollToEndEnabled {
- tv.ScrollToEnd()
+ textView.ScrollToEnd()
}
// Send chunk to audio stream handler
if cfg.TTS_ENABLED {
TTSTextChan <- chunk
}
case toolChunk := <-openAIToolChan:
- fmt.Fprint(tv, toolChunk)
+ fmt.Fprint(textView, toolChunk)
toolResp.WriteString(toolChunk)
if scrollToEndEnabled {
- tv.ScrollToEnd()
+ textView.ScrollToEnd()
}
case <-streamDone:
// drain any remaining chunks from chunkChan before exiting
for len(chunkChan) > 0 {
chunk := <-chunkChan
- fmt.Fprint(tv, chunk)
+ fmt.Fprint(textView, chunk)
respText.WriteString(chunk)
if scrollToEndEnabled {
- tv.ScrollToEnd()
+ textView.ScrollToEnd()
}
if cfg.TTS_ENABLED {
// Send chunk to audio stream handler
@@ -780,7 +786,7 @@ out:
}
botRespMode = false
// numbers in chatbody and displayed must be the same
- if resume {
+ if r.Resume {
chatBody.Messages[len(chatBody.Messages)-1].Content += respText.String()
// lastM.Content = lastM.Content + respText.String()
// Process the updated message to check for known_to tags in resumed response
@@ -797,7 +803,9 @@ out:
}
logger.Debug("chatRound: before cleanChatBody", "messages_before_clean", len(chatBody.Messages))
for i, msg := range chatBody.Messages {
- logger.Debug("chatRound: before cleaning", "index", i, "role", msg.Role, "content_len", len(msg.Content), "has_content", msg.HasContent(), "tool_call_id", msg.ToolCallID)
+ logger.Debug("chatRound: before cleaning", "index", i,
+ "role", msg.Role, "content_len", len(msg.Content),
+ "has_content", msg.HasContent(), "tool_call_id", msg.ToolCallID)
}
// // Clean null/empty messages to prevent API issues with endpoints like llama.cpp jinja template
cleanChatBody()
@@ -813,8 +821,9 @@ out:
if err := updateStorageChat(activeChatName, chatBody.Messages); err != nil {
logger.Warn("failed to update storage", "error", err, "name", activeChatName)
}
- // FIXME: recursive calls
- findCall(respText.String(), toolResp.String(), tv)
+ if findCall(respText.String(), toolResp.String()) {
+ return nil
+ }
// TODO: have a config attr
// Check if this message was sent privately to specific characters
// If so, trigger those characters to respond if that char is not controlled by user
@@ -822,9 +831,10 @@ out:
if cfg.AutoTurn {
lastMsg := chatBody.Messages[len(chatBody.Messages)-1]
if len(lastMsg.KnownTo) > 0 {
- triggerPrivateMessageResponses(lastMsg, tv)
+ triggerPrivateMessageResponses(lastMsg)
}
}
+ return nil
}
// cleanChatBody removes messages with null or empty content to prevent API issues
@@ -909,7 +919,8 @@ func unmarshalFuncCall(jsonStr string) (*models.FuncCall, error) {
return fc, nil
}
-func findCall(msg, toolCall string, tv *tview.TextView) {
+// findCall: adds chatRoundReq into the chatRoundChan and returns true if does
+func findCall(msg, toolCall string) bool {
fc := &models.FuncCall{}
if toolCall != "" {
// HTML-decode the tool call string to handle encoded characters like &lt; -> <=
@@ -927,8 +938,13 @@ func findCall(msg, toolCall string, tv *tview.TextView) {
chatBody.Messages = append(chatBody.Messages, toolResponseMsg)
// Clear the stored tool call ID after using it (no longer needed)
// Trigger the assistant to continue processing with the error message
- chatRound("", cfg.AssistantRole, tv, false, false)
- return
+ crr := &models.ChatRoundReq{
+ Role: cfg.AssistantRole,
+ }
+ // provoke next llm msg after failed tool call
+ chatRoundChan <- crr
+ // chatRound("", cfg.AssistantRole, tv, false, false)
+ return true
}
lastToolCall.Args = openAIToolMap
fc = lastToolCall
@@ -940,8 +956,8 @@ func findCall(msg, toolCall string, tv *tview.TextView) {
}
} else {
jsStr := toolCallRE.FindString(msg)
- if jsStr == "" {
- return
+ if jsStr == "" { // no tool call case
+ return false
}
prefix := "__tool_call__\n"
suffix := "\n__tool_call__"
@@ -960,8 +976,13 @@ func findCall(msg, toolCall string, tv *tview.TextView) {
chatBody.Messages = append(chatBody.Messages, toolResponseMsg)
logger.Debug("findCall: added tool error response", "role", toolResponseMsg.Role, "content_len", len(toolResponseMsg.Content), "message_count_after_add", len(chatBody.Messages))
// Trigger the assistant to continue processing with the error message
- chatRound("", cfg.AssistantRole, tv, false, false)
- return
+ // chatRound("", cfg.AssistantRole, tv, false, false)
+ crr := &models.ChatRoundReq{
+ Role: cfg.AssistantRole,
+ }
+ // provoke next llm msg after failed tool call
+ chatRoundChan <- crr
+ return true
}
// Update lastToolCall with parsed function call
lastToolCall.ID = fc.ID
@@ -994,13 +1015,17 @@ func findCall(msg, toolCall string, tv *tview.TextView) {
lastToolCall.ID = ""
// Trigger the assistant to continue processing with the new tool response
// by calling chatRound with empty content to continue the assistant's response
- chatRound("", cfg.AssistantRole, tv, false, false)
- return
+ crr := &models.ChatRoundReq{
+ Role: cfg.AssistantRole,
+ }
+ // failed to find tool
+ chatRoundChan <- crr
+ return true
}
resp := callToolWithAgent(fc.Name, fc.Args)
toolMsg := string(resp) // Remove the "tool response: " prefix and %+v formatting
logger.Info("llm used tool call", "tool_resp", toolMsg, "tool_attrs", fc)
- fmt.Fprintf(tv, "%s[-:-:b](%d) <%s>: [-:-:-]\n%s\n",
+ fmt.Fprintf(textView, "%s[-:-:b](%d) <%s>: [-:-:-]\n%s\n",
"\n\n", len(chatBody.Messages), cfg.ToolRole, toolMsg)
// Create tool response message with the proper tool_call_id
toolResponseMsg := models.RoleMsg{
@@ -1014,7 +1039,11 @@ func findCall(msg, toolCall string, tv *tview.TextView) {
lastToolCall.ID = ""
// Trigger the assistant to continue processing with the new tool response
// by calling chatRound with empty content to continue the assistant's response
- chatRound("", cfg.AssistantRole, tv, false, false)
+ crr := &models.ChatRoundReq{
+ Role: cfg.AssistantRole,
+ }
+ chatRoundChan <- crr
+ return true
}
func chatToTextSlice(messages []models.RoleMsg, showSys bool) []string {
@@ -1163,10 +1192,12 @@ func summarizeAndStartNewChat() {
}
func init() {
+ // ctx, cancel := context.WithCancel(context.Background())
var err error
cfg, err = config.LoadConfig("config.toml")
if err != nil {
fmt.Println("failed to load config.toml")
+ cancel()
os.Exit(1)
return
}
@@ -1178,6 +1209,8 @@ func init() {
os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
slog.Error("failed to open log file", "error", err, "filename", cfg.LogFile)
+ cancel()
+ os.Exit(1)
return
}
// load cards
@@ -1188,13 +1221,17 @@ func init() {
logger = slog.New(slog.NewTextHandler(logfile, &slog.HandlerOptions{Level: logLevel}))
store = storage.NewProviderSQL(cfg.DBPATH, logger)
if store == nil {
+ cancel()
os.Exit(1)
+ return
}
ragger = rag.New(logger, store, cfg)
// https://github.com/coreydaley/ggerganov-llama.cpp/blob/master/examples/server/README.md
// load all chats in memory
if _, err := loadHistoryChats(); err != nil {
logger.Error("failed to load chat", "error", err)
+ cancel()
+ os.Exit(1)
return
}
lastToolCall = &models.FuncCall{}
@@ -1215,11 +1252,12 @@ func init() {
// Initialize scrollToEndEnabled based on config
scrollToEndEnabled = cfg.AutoScrollEnabled
go updateModelLists()
+ go chatWatcher(ctx)
}
// triggerPrivateMessageResponses checks if a message was sent privately to specific characters
// and triggers those non-user characters to respond
-func triggerPrivateMessageResponses(msg models.RoleMsg, tv *tview.TextView) {
+func triggerPrivateMessageResponses(msg models.RoleMsg) {
if cfg == nil || !cfg.CharSpecificContextEnabled {
return
}
@@ -1237,6 +1275,11 @@ func triggerPrivateMessageResponses(msg models.RoleMsg, tv *tview.TextView) {
// that indicates it's their turn
triggerMsg := recipient + ":\n"
// Call chatRound with the trigger message to make the recipient respond
- chatRound(triggerMsg, recipient, tv, false, false)
+ // chatRound(triggerMsg, recipient, tv, false, false)
+ crr := &models.ChatRoundReq{
+ UserMsg: triggerMsg,
+ Role: recipient,
+ }
+ chatRoundChan <- crr
}
}