From 589dfdda3fa89ecc984530ce3bfcc58ee2fd851d Mon Sep 17 00:00:00 2001 From: Grail Finder Date: Fri, 8 Aug 2025 10:51:14 +0300 Subject: Feat: tool chunk channel for openai tool calls --- bot.go | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) (limited to 'bot.go') diff --git a/bot.go b/bot.go index 423a7cb..72a0c44 100644 --- a/bot.go +++ b/bot.go @@ -34,6 +34,7 @@ var ( logLevel = new(slog.LevelVar) activeChatName string chunkChan = make(chan string, 10) + openAIToolChan = make(chan string, 10) streamDone = make(chan bool, 1) chatBody *models.ChatBody store storage.FullRepo @@ -189,8 +190,8 @@ func sendMsgToLLM(body io.Reader) { for { var ( answerText string - content string stop bool + chunk *models.TextChunk ) counter++ // to stop from spiriling in infinity read of bad bytes that happens with poor connection @@ -225,7 +226,7 @@ func sendMsgToLLM(body io.Reader) { if bytes.Equal(line, []byte("ROUTER PROCESSING\n")) { continue } - content, stop, err = chunkParser.ParseChunk(line) + chunk, err = chunkParser.ParseChunk(line) if err != nil { logger.Error("error parsing response body", "error", err, "line", string(line), "url", cfg.CurrentAPI) @@ -239,18 +240,19 @@ func sendMsgToLLM(body io.Reader) { break } if stop { - if content != "" { - logger.Warn("text inside of finish llmchunk", "chunk", content, "counter", counter) + if chunk.Chunk != "" { + logger.Warn("text inside of finish llmchunk", "chunk", chunk, "counter", counter) } streamDone <- true break } if counter == 0 { - content = strings.TrimPrefix(content, " ") + chunk.Chunk = strings.TrimPrefix(chunk.Chunk, " ") } // bot sends way too many \n - answerText = strings.ReplaceAll(content, "\n\n", "\n") + answerText = strings.ReplaceAll(chunk.Chunk, "\n\n", "\n") chunkChan <- answerText + openAIToolChan <- chunk.ToolChunk interrupt: if interruptResp { // read bytes, so it would not get into beginning of the next req interruptResp = false @@ -362,6 +364,7 @@ func chatRound(userMsg, role string, tv *tview.TextView, regen, resume bool) { } } respText := strings.Builder{} + toolResp := strings.Builder{} out: for { select { @@ -374,6 +377,10 @@ out: // audioStream.TextChan <- chunk extra.TTSTextChan <- chunk } + case toolChunk := <-openAIToolChan: + fmt.Fprint(tv, toolChunk) + toolResp.WriteString(toolChunk) + tv.ScrollToEnd() case <-streamDone: botRespMode = false if cfg.TTS_ENABLED { @@ -402,10 +409,11 @@ out: if err := updateStorageChat(activeChatName, chatBody.Messages); err != nil { logger.Warn("failed to update storage", "error", err, "name", activeChatName) } - findCall(respText.String(), tv) + // INFO: for completion only; openai has it's own tool struct + findCall(respText.String(), toolResp.String(), tv) } -func findCall(msg string, tv *tview.TextView) { +func findCall(msg, toolCall string, tv *tview.TextView) { fc := models.FuncCall{} jsStr := toolCallRE.FindString(msg) if jsStr == "" { -- cgit v1.2.3