From d7d432b8a1dbea9e18f78d835112fa074051f587 Mon Sep 17 00:00:00 2001 From: Grail Finder Date: Fri, 8 Aug 2025 13:03:37 +0300 Subject: Enha: /chat /completions tool calls to live in peace --- bot.go | 44 +++++++++++++++++++++++++++++--------------- llm.go | 6 ++++-- models/models.go | 21 +++++++++++++++------ tools.go | 20 ++++++++++---------- tui.go | 2 +- 5 files changed, 59 insertions(+), 34 deletions(-) diff --git a/bot.go b/bot.go index 72a0c44..b2de311 100644 --- a/bot.go +++ b/bot.go @@ -17,7 +17,6 @@ import ( "net/http" "os" "path" - "strconv" "strings" "time" @@ -44,6 +43,7 @@ var ( interruptResp = false ragger *rag.RAG chunkParser ChunkParser + lastToolCall *models.FuncCall //nolint:unused // TTS_ENABLED conditionally uses this orator extra.Orator asr extra.STT @@ -171,7 +171,7 @@ func sendMsgToLLM(body io.Reader) { req.Header.Add("Accept", "application/json") req.Header.Add("Content-Type", "application/json") req.Header.Add("Authorization", "Bearer "+chunkParser.GetToken()) - req.Header.Set("Content-Length", strconv.Itoa(len(bodyBytes))) + // req.Header.Set("Content-Length", strconv.Itoa(len(bodyBytes))) req.Header.Set("Accept-Encoding", "gzip") // nolint // resp, err := httpClient.Post(cfg.CurrentAPI, "application/json", body) @@ -253,6 +253,9 @@ func sendMsgToLLM(body io.Reader) { answerText = strings.ReplaceAll(chunk.Chunk, "\n\n", "\n") chunkChan <- answerText openAIToolChan <- chunk.ToolChunk + if chunk.FuncName != "" { + lastToolCall.Name = chunk.FuncName + } interrupt: if interruptResp { // read bytes, so it would not get into beginning of the next req interruptResp = false @@ -409,22 +412,32 @@ out: if err := updateStorageChat(activeChatName, chatBody.Messages); err != nil { logger.Warn("failed to update storage", "error", err, "name", activeChatName) } - // INFO: for completion only; openai has it's own tool struct findCall(respText.String(), toolResp.String(), tv) } func findCall(msg, toolCall string, tv *tview.TextView) { - fc := models.FuncCall{} - jsStr := toolCallRE.FindString(msg) - if jsStr == "" { - return - } - prefix := "__tool_call__\n" - suffix := "\n__tool_call__" - jsStr = strings.TrimSuffix(strings.TrimPrefix(jsStr, prefix), suffix) - if err := json.Unmarshal([]byte(jsStr), &fc); err != nil { - logger.Error("failed to unmarshal tool call", "error", err, "json_string", jsStr) - return + fc := &models.FuncCall{} + if toolCall != "" { + openAIToolMap := make(map[string]string) + // respect tool call + if err := json.Unmarshal([]byte(toolCall), &openAIToolMap); err != nil { + logger.Error("failed to unmarshal openai tool call", "call", toolCall, "error", err) + return + } + lastToolCall.Args = openAIToolMap + fc = lastToolCall + } else { + jsStr := toolCallRE.FindString(msg) + if jsStr == "" { + return + } + prefix := "__tool_call__\n" + suffix := "\n__tool_call__" + jsStr = strings.TrimSuffix(strings.TrimPrefix(jsStr, prefix), suffix) + if err := json.Unmarshal([]byte(jsStr), &fc); err != nil { + logger.Error("failed to unmarshal tool call", "error", err, "json_string", jsStr) + return + } } // call a func f, ok := fnMap[fc.Name] @@ -433,7 +446,7 @@ func findCall(msg, toolCall string, tv *tview.TextView) { chatRound(m, cfg.ToolRole, tv, false, false) return } - resp := f(fc.Args...) + resp := f(fc.Args) toolMsg := fmt.Sprintf("tool response: %+v", string(resp)) chatRound(toolMsg, cfg.ToolRole, tv, false, false) } @@ -550,6 +563,7 @@ func init() { logger.Error("failed to load chat", "error", err) return } + lastToolCall = &models.FuncCall{} lastChat := loadOldChatOrGetNew() chatBody = &models.ChatBody{ Model: "modelname", diff --git a/llm.go b/llm.go index c5f10ea..3c3085f 100644 --- a/llm.go +++ b/llm.go @@ -142,8 +142,10 @@ func (op OpenAIer) ParseChunk(data []byte) (*models.TextChunk, error) { return nil, err } resp := &models.TextChunk{ - Chunk: llmchunk.Choices[len(llmchunk.Choices)-1].Delta.Content, - ToolChunk: llmchunk.Choices[len(llmchunk.Choices)-1].Delta.ToolCalls[0].Function.Arguments, + Chunk: llmchunk.Choices[len(llmchunk.Choices)-1].Delta.Content, + } + if len(llmchunk.Choices[len(llmchunk.Choices)-1].Delta.ToolCalls) > 0 { + resp.ToolChunk = llmchunk.Choices[len(llmchunk.Choices)-1].Delta.ToolCalls[0].Function.Arguments } if llmchunk.Choices[len(llmchunk.Choices)-1].FinishReason == "stop" { if resp.Chunk != "" { diff --git a/models/models.go b/models/models.go index 9ca12d9..69d812b 100644 --- a/models/models.go +++ b/models/models.go @@ -5,9 +5,14 @@ import ( "strings" ) +// type FuncCall struct { +// Name string `json:"name"` +// Args []string `json:"args"` +// } + type FuncCall struct { - Name string `json:"name"` - Args []string `json:"args"` + Name string `json:"name"` + Args map[string]string `json:"args"` } type LLMResp struct { @@ -30,11 +35,14 @@ type LLMResp struct { ID string `json:"id"` } +type ToolDeltaFunc struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + type ToolDeltaResp struct { - Index int `json:"index"` - Function struct { - Arguments string `json:"arguments"` - } `json:"function"` + Index int `json:"index"` + Function ToolDeltaFunc `json:"function"` } // for streaming @@ -63,6 +71,7 @@ type TextChunk struct { ToolChunk string Finished bool ToolResp bool + FuncName string } type RoleMsg struct { diff --git a/tools.go b/tools.go index 3b5fbf6..dee0577 100644 --- a/tools.go +++ b/tools.go @@ -46,7 +46,7 @@ To make a function call return a json object within __tool_call__ tags; __tool_call__ { "name":"recall", -"args": ["Adam's number"] +"args": {"topic": "Adam's number"} } __tool_call__ @@ -84,7 +84,7 @@ also: - some writing can be done without consideration of previous data; - others do; */ -func memorise(args ...string) []byte { +func memorise(args map[string]string) []byte { agent := cfg.AssistantRole if len(args) < 2 { msg := "not enough args to call memorise tool; need topic and data to remember" @@ -93,35 +93,35 @@ func memorise(args ...string) []byte { } memory := &models.Memory{ Agent: agent, - Topic: args[0], - Mind: args[1], + Topic: args["topic"], + Mind: args["data"], UpdatedAt: time.Now(), } if _, err := store.Memorise(memory); err != nil { logger.Error("failed to save memory", "err", err, "memoory", memory) return []byte("failed to save info") } - msg := "info saved under the topic:" + args[0] + msg := "info saved under the topic:" + args["topic"] return []byte(msg) } -func recall(args ...string) []byte { +func recall(args map[string]string) []byte { agent := cfg.AssistantRole if len(args) < 1 { logger.Warn("not enough args to call recall tool") return nil } - mind, err := store.Recall(agent, args[0]) + mind, err := store.Recall(agent, args["topic"]) if err != nil { msg := fmt.Sprintf("failed to recall; error: %v; args: %v", err, args) logger.Error(msg) return []byte(msg) } - answer := fmt.Sprintf("under the topic: %s is stored:\n%s", args[0], mind) + answer := fmt.Sprintf("under the topic: %s is stored:\n%s", args["topic"], mind) return []byte(answer) } -func recallTopics(args ...string) []byte { +func recallTopics(args map[string]string) []byte { agent := cfg.AssistantRole topics, err := store.RecallTopics(agent) if err != nil { @@ -134,7 +134,7 @@ func recallTopics(args ...string) []byte { // func fullMemoryLoad() {} -type fnSig func(...string) []byte +type fnSig func(map[string]string) []byte var fnMap = map[string]fnSig{ "recall": recall, diff --git a/tui.go b/tui.go index ee0e5e6..b91058c 100644 --- a/tui.go +++ b/tui.go @@ -558,7 +558,7 @@ func init() { if event.Key() == tcell.KeyF5 { // switch cfg.ShowSys cfg.ShowSys = !cfg.ShowSys - textView.SetText(chatToText(cfg.ShowSys)) + textView.SetText(chatToText(cfg.ShowSys)) // TODO: fix removing all new names colorText() } if event.Key() == tcell.KeyF6 { -- cgit v1.2.3