Diffstat (limited to 'bot.go')
-rw-r--r-- | bot.go | 44
1 file changed, 29 insertions, 15 deletions
@@ -17,7 +17,6 @@ import (
 	"net/http"
 	"os"
 	"path"
-	"strconv"
 	"strings"
 	"time"
 
@@ -44,6 +43,7 @@ var (
 	interruptResp = false
 	ragger        *rag.RAG
 	chunkParser   ChunkParser
+	lastToolCall  *models.FuncCall
 	//nolint:unused // TTS_ENABLED conditionally uses this
 	orator extra.Orator
 	asr    extra.STT
@@ -171,7 +171,7 @@ func sendMsgToLLM(body io.Reader) {
 	req.Header.Add("Accept", "application/json")
 	req.Header.Add("Content-Type", "application/json")
 	req.Header.Add("Authorization", "Bearer "+chunkParser.GetToken())
-	req.Header.Set("Content-Length", strconv.Itoa(len(bodyBytes)))
+	// req.Header.Set("Content-Length", strconv.Itoa(len(bodyBytes)))
 	req.Header.Set("Accept-Encoding", "gzip")
 	// nolint
 	// resp, err := httpClient.Post(cfg.CurrentAPI, "application/json", body)
@@ -253,6 +253,9 @@ func sendMsgToLLM(body io.Reader) {
 		answerText = strings.ReplaceAll(chunk.Chunk, "\n\n", "\n")
 		chunkChan <- answerText
 		openAIToolChan <- chunk.ToolChunk
+		if chunk.FuncName != "" {
+			lastToolCall.Name = chunk.FuncName
+		}
 	interrupt:
 		if interruptResp { // read bytes, so it would not get into beginning of the next req
 			interruptResp = false
@@ -409,22 +412,32 @@ out:
 	if err := updateStorageChat(activeChatName, chatBody.Messages); err != nil {
 		logger.Warn("failed to update storage", "error", err, "name", activeChatName)
 	}
-	// INFO: for completion only; openai has it's own tool struct
 	findCall(respText.String(), toolResp.String(), tv)
 }
 
 func findCall(msg, toolCall string, tv *tview.TextView) {
-	fc := models.FuncCall{}
-	jsStr := toolCallRE.FindString(msg)
-	if jsStr == "" {
-		return
-	}
-	prefix := "__tool_call__\n"
-	suffix := "\n__tool_call__"
-	jsStr = strings.TrimSuffix(strings.TrimPrefix(jsStr, prefix), suffix)
-	if err := json.Unmarshal([]byte(jsStr), &fc); err != nil {
-		logger.Error("failed to unmarshal tool call", "error", err, "json_string", jsStr)
-		return
+	fc := &models.FuncCall{}
+	if toolCall != "" {
+		openAIToolMap := make(map[string]string)
+		// respect tool call
+		if err := json.Unmarshal([]byte(toolCall), &openAIToolMap); err != nil {
+			logger.Error("failed to unmarshal openai tool call", "call", toolCall, "error", err)
+			return
+		}
+		lastToolCall.Args = openAIToolMap
+		fc = lastToolCall
+	} else {
+		jsStr := toolCallRE.FindString(msg)
+		if jsStr == "" {
+			return
+		}
+		prefix := "__tool_call__\n"
+		suffix := "\n__tool_call__"
+		jsStr = strings.TrimSuffix(strings.TrimPrefix(jsStr, prefix), suffix)
+		if err := json.Unmarshal([]byte(jsStr), &fc); err != nil {
+			logger.Error("failed to unmarshal tool call", "error", err, "json_string", jsStr)
+			return
+		}
 	}
 	// call a func
 	f, ok := fnMap[fc.Name]
@@ -433,7 +446,7 @@ func findCall(msg, toolCall string, tv *tview.TextView) {
 		chatRound(m, cfg.ToolRole, tv, false, false)
 		return
 	}
-	resp := f(fc.Args...)
+	resp := f(fc.Args)
 	toolMsg := fmt.Sprintf("tool response: %+v", string(resp))
 	chatRound(toolMsg, cfg.ToolRole, tv, false, false)
 }
@@ -550,6 +563,7 @@ func init() {
 		logger.Error("failed to load chat", "error", err)
 		return
 	}
+	lastToolCall = &models.FuncCall{}
 	lastChat := loadOldChatOrGetNew()
 	chatBody = &models.ChatBody{
 		Model: "modelname",
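
The change teaches findCall to handle two tool-call formats: OpenAI-style streamed tool calls, where the function name arrives in a chunk and is stashed in lastToolCall and the arguments arrive later as a JSON object, and the older __tool_call__ markers parsed out of the completion text. That split is also why f(fc.Args...) becomes f(fc.Args). Below is a minimal, self-contained sketch of the new dispatch, assuming FuncCall holds a name plus a map[string]string of arguments and that tool functions take that map directly; the stand-in types, json tags, and the echo tool are illustrative, not the repository's actual definitions.

package main

import (
	"encoding/json"
	"fmt"
	"regexp"
	"strings"
)

// FuncCall mirrors the shape implied by the diff: a name plus string args.
// The real models.FuncCall may differ; this is an assumption for the sketch.
type FuncCall struct {
	Name string            `json:"name"`
	Args map[string]string `json:"args"`
}

// lastToolCall is filled incrementally: the name comes from a streamed chunk,
// the arguments arrive later as a JSON object.
var lastToolCall = &FuncCall{}

// fnMap holds the callable tools; after the change they take the args map
// directly instead of a variadic slice. The echo tool is hypothetical.
var fnMap = map[string]func(map[string]string) []byte{
	"echo": func(args map[string]string) []byte {
		return []byte("echo: " + args["text"])
	},
}

var toolCallRE = regexp.MustCompile(`(?s)__tool_call__\n.*?\n__tool_call__`)

// findCall prefers the OpenAI tool-call JSON when present and falls back to
// the legacy __tool_call__ markers embedded in the completion text.
func findCall(msg, toolCall string) {
	fc := &FuncCall{}
	if toolCall != "" {
		args := make(map[string]string)
		if err := json.Unmarshal([]byte(toolCall), &args); err != nil {
			fmt.Println("failed to unmarshal openai tool call:", err)
			return
		}
		lastToolCall.Args = args
		fc = lastToolCall
	} else {
		jsStr := toolCallRE.FindString(msg)
		if jsStr == "" {
			return
		}
		jsStr = strings.TrimSuffix(strings.TrimPrefix(jsStr, "__tool_call__\n"), "\n__tool_call__")
		if err := json.Unmarshal([]byte(jsStr), fc); err != nil {
			fmt.Println("failed to unmarshal tool call:", err)
			return
		}
	}
	f, ok := fnMap[fc.Name]
	if !ok {
		fmt.Println("unknown tool:", fc.Name)
		return
	}
	fmt.Printf("tool response: %s\n", f(fc.Args))
}

func main() {
	// OpenAI-style path: the name was streamed earlier, arguments arrive as JSON.
	lastToolCall.Name = "echo"
	findCall("", `{"text":"hi"}`)

	// Legacy path: the call is embedded in the completion text.
	findCall("__tool_call__\n{\"name\":\"echo\",\"args\":{\"text\":\"hello\"}}\n__tool_call__", "")
}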