From d7d432b8a1dbea9e18f78d835112fa074051f587 Mon Sep 17 00:00:00 2001 From: Grail Finder Date: Fri, 8 Aug 2025 13:03:37 +0300 Subject: Enha: /chat /completions tool calls to live in peace --- bot.go | 44 +++++++++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 15 deletions(-) (limited to 'bot.go') diff --git a/bot.go b/bot.go index 72a0c44..b2de311 100644 --- a/bot.go +++ b/bot.go @@ -17,7 +17,6 @@ import ( "net/http" "os" "path" - "strconv" "strings" "time" @@ -44,6 +43,7 @@ var ( interruptResp = false ragger *rag.RAG chunkParser ChunkParser + lastToolCall *models.FuncCall //nolint:unused // TTS_ENABLED conditionally uses this orator extra.Orator asr extra.STT @@ -171,7 +171,7 @@ func sendMsgToLLM(body io.Reader) { req.Header.Add("Accept", "application/json") req.Header.Add("Content-Type", "application/json") req.Header.Add("Authorization", "Bearer "+chunkParser.GetToken()) - req.Header.Set("Content-Length", strconv.Itoa(len(bodyBytes))) + // req.Header.Set("Content-Length", strconv.Itoa(len(bodyBytes))) req.Header.Set("Accept-Encoding", "gzip") // nolint // resp, err := httpClient.Post(cfg.CurrentAPI, "application/json", body) @@ -253,6 +253,9 @@ func sendMsgToLLM(body io.Reader) { answerText = strings.ReplaceAll(chunk.Chunk, "\n\n", "\n") chunkChan <- answerText openAIToolChan <- chunk.ToolChunk + if chunk.FuncName != "" { + lastToolCall.Name = chunk.FuncName + } interrupt: if interruptResp { // read bytes, so it would not get into beginning of the next req interruptResp = false @@ -409,22 +412,32 @@ out: if err := updateStorageChat(activeChatName, chatBody.Messages); err != nil { logger.Warn("failed to update storage", "error", err, "name", activeChatName) } - // INFO: for completion only; openai has it's own tool struct findCall(respText.String(), toolResp.String(), tv) } func findCall(msg, toolCall string, tv *tview.TextView) { - fc := models.FuncCall{} - jsStr := toolCallRE.FindString(msg) - if jsStr == "" { - return - } - prefix := "__tool_call__\n" - suffix := "\n__tool_call__" - jsStr = strings.TrimSuffix(strings.TrimPrefix(jsStr, prefix), suffix) - if err := json.Unmarshal([]byte(jsStr), &fc); err != nil { - logger.Error("failed to unmarshal tool call", "error", err, "json_string", jsStr) - return + fc := &models.FuncCall{} + if toolCall != "" { + openAIToolMap := make(map[string]string) + // respect tool call + if err := json.Unmarshal([]byte(toolCall), &openAIToolMap); err != nil { + logger.Error("failed to unmarshal openai tool call", "call", toolCall, "error", err) + return + } + lastToolCall.Args = openAIToolMap + fc = lastToolCall + } else { + jsStr := toolCallRE.FindString(msg) + if jsStr == "" { + return + } + prefix := "__tool_call__\n" + suffix := "\n__tool_call__" + jsStr = strings.TrimSuffix(strings.TrimPrefix(jsStr, prefix), suffix) + if err := json.Unmarshal([]byte(jsStr), &fc); err != nil { + logger.Error("failed to unmarshal tool call", "error", err, "json_string", jsStr) + return + } } // call a func f, ok := fnMap[fc.Name] @@ -433,7 +446,7 @@ func findCall(msg, toolCall string, tv *tview.TextView) { chatRound(m, cfg.ToolRole, tv, false, false) return } - resp := f(fc.Args...) + resp := f(fc.Args) toolMsg := fmt.Sprintf("tool response: %+v", string(resp)) chatRound(toolMsg, cfg.ToolRole, tv, false, false) } @@ -550,6 +563,7 @@ func init() { logger.Error("failed to load chat", "error", err) return } + lastToolCall = &models.FuncCall{} lastChat := loadOldChatOrGetNew() chatBody = &models.ChatBody{ Model: "modelname", -- cgit v1.2.3