summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGrail Finder <wohilas@gmail.com>2025-08-08 13:03:37 +0300
committerGrail Finder <wohilas@gmail.com>2025-08-08 13:03:37 +0300
commitd7d432b8a1dbea9e18f78d835112fa074051f587 (patch)
tree517a6e057aaf3fb5fd418690bfff1941d55c45bb
parent589dfdda3fa89ecc984530ce3bfcc58ee2fd851d (diff)
Enha: /chat /completions tool calls to live in peace
-rw-r--r--bot.go44
-rw-r--r--llm.go6
-rw-r--r--models/models.go21
-rw-r--r--tools.go20
-rw-r--r--tui.go2
5 files changed, 59 insertions, 34 deletions
diff --git a/bot.go b/bot.go
index 72a0c44..b2de311 100644
--- a/bot.go
+++ b/bot.go
@@ -17,7 +17,6 @@ import (
"net/http"
"os"
"path"
- "strconv"
"strings"
"time"
@@ -44,6 +43,7 @@ var (
interruptResp = false
ragger *rag.RAG
chunkParser ChunkParser
+ lastToolCall *models.FuncCall
//nolint:unused // TTS_ENABLED conditionally uses this
orator extra.Orator
asr extra.STT
@@ -171,7 +171,7 @@ func sendMsgToLLM(body io.Reader) {
req.Header.Add("Accept", "application/json")
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Authorization", "Bearer "+chunkParser.GetToken())
- req.Header.Set("Content-Length", strconv.Itoa(len(bodyBytes)))
+ // req.Header.Set("Content-Length", strconv.Itoa(len(bodyBytes)))
req.Header.Set("Accept-Encoding", "gzip")
// nolint
// resp, err := httpClient.Post(cfg.CurrentAPI, "application/json", body)
@@ -253,6 +253,9 @@ func sendMsgToLLM(body io.Reader) {
answerText = strings.ReplaceAll(chunk.Chunk, "\n\n", "\n")
chunkChan <- answerText
openAIToolChan <- chunk.ToolChunk
+ if chunk.FuncName != "" {
+ lastToolCall.Name = chunk.FuncName
+ }
interrupt:
if interruptResp { // read bytes, so it would not get into beginning of the next req
interruptResp = false
@@ -409,22 +412,32 @@ out:
if err := updateStorageChat(activeChatName, chatBody.Messages); err != nil {
logger.Warn("failed to update storage", "error", err, "name", activeChatName)
}
- // INFO: for completion only; openai has it's own tool struct
findCall(respText.String(), toolResp.String(), tv)
}
func findCall(msg, toolCall string, tv *tview.TextView) {
- fc := models.FuncCall{}
- jsStr := toolCallRE.FindString(msg)
- if jsStr == "" {
- return
- }
- prefix := "__tool_call__\n"
- suffix := "\n__tool_call__"
- jsStr = strings.TrimSuffix(strings.TrimPrefix(jsStr, prefix), suffix)
- if err := json.Unmarshal([]byte(jsStr), &fc); err != nil {
- logger.Error("failed to unmarshal tool call", "error", err, "json_string", jsStr)
- return
+ fc := &models.FuncCall{}
+ if toolCall != "" {
+ openAIToolMap := make(map[string]string)
+ // respect tool call
+ if err := json.Unmarshal([]byte(toolCall), &openAIToolMap); err != nil {
+ logger.Error("failed to unmarshal openai tool call", "call", toolCall, "error", err)
+ return
+ }
+ lastToolCall.Args = openAIToolMap
+ fc = lastToolCall
+ } else {
+ jsStr := toolCallRE.FindString(msg)
+ if jsStr == "" {
+ return
+ }
+ prefix := "__tool_call__\n"
+ suffix := "\n__tool_call__"
+ jsStr = strings.TrimSuffix(strings.TrimPrefix(jsStr, prefix), suffix)
+ if err := json.Unmarshal([]byte(jsStr), &fc); err != nil {
+ logger.Error("failed to unmarshal tool call", "error", err, "json_string", jsStr)
+ return
+ }
}
// call a func
f, ok := fnMap[fc.Name]
@@ -433,7 +446,7 @@ func findCall(msg, toolCall string, tv *tview.TextView) {
chatRound(m, cfg.ToolRole, tv, false, false)
return
}
- resp := f(fc.Args...)
+ resp := f(fc.Args)
toolMsg := fmt.Sprintf("tool response: %+v", string(resp))
chatRound(toolMsg, cfg.ToolRole, tv, false, false)
}
@@ -550,6 +563,7 @@ func init() {
logger.Error("failed to load chat", "error", err)
return
}
+ lastToolCall = &models.FuncCall{}
lastChat := loadOldChatOrGetNew()
chatBody = &models.ChatBody{
Model: "modelname",
diff --git a/llm.go b/llm.go
index c5f10ea..3c3085f 100644
--- a/llm.go
+++ b/llm.go
@@ -142,8 +142,10 @@ func (op OpenAIer) ParseChunk(data []byte) (*models.TextChunk, error) {
return nil, err
}
resp := &models.TextChunk{
- Chunk: llmchunk.Choices[len(llmchunk.Choices)-1].Delta.Content,
- ToolChunk: llmchunk.Choices[len(llmchunk.Choices)-1].Delta.ToolCalls[0].Function.Arguments,
+ Chunk: llmchunk.Choices[len(llmchunk.Choices)-1].Delta.Content,
+ }
+ if len(llmchunk.Choices[len(llmchunk.Choices)-1].Delta.ToolCalls) > 0 {
+ resp.ToolChunk = llmchunk.Choices[len(llmchunk.Choices)-1].Delta.ToolCalls[0].Function.Arguments
}
if llmchunk.Choices[len(llmchunk.Choices)-1].FinishReason == "stop" {
if resp.Chunk != "" {
diff --git a/models/models.go b/models/models.go
index 9ca12d9..69d812b 100644
--- a/models/models.go
+++ b/models/models.go
@@ -5,9 +5,14 @@ import (
"strings"
)
+// type FuncCall struct {
+// Name string `json:"name"`
+// Args []string `json:"args"`
+// }
+
type FuncCall struct {
- Name string `json:"name"`
- Args []string `json:"args"`
+ Name string `json:"name"`
+ Args map[string]string `json:"args"`
}
type LLMResp struct {
@@ -30,11 +35,14 @@ type LLMResp struct {
ID string `json:"id"`
}
+type ToolDeltaFunc struct {
+ Name string `json:"name"`
+ Arguments string `json:"arguments"`
+}
+
type ToolDeltaResp struct {
- Index int `json:"index"`
- Function struct {
- Arguments string `json:"arguments"`
- } `json:"function"`
+ Index int `json:"index"`
+ Function ToolDeltaFunc `json:"function"`
}
// for streaming
@@ -63,6 +71,7 @@ type TextChunk struct {
ToolChunk string
Finished bool
ToolResp bool
+ FuncName string
}
type RoleMsg struct {
diff --git a/tools.go b/tools.go
index 3b5fbf6..dee0577 100644
--- a/tools.go
+++ b/tools.go
@@ -46,7 +46,7 @@ To make a function call return a json object within __tool_call__ tags;
__tool_call__
{
"name":"recall",
-"args": ["Adam's number"]
+"args": {"topic": "Adam's number"}
}
__tool_call__
</example_request>
@@ -84,7 +84,7 @@ also:
- some writing can be done without consideration of previous data;
- others do;
*/
-func memorise(args ...string) []byte {
+func memorise(args map[string]string) []byte {
agent := cfg.AssistantRole
if len(args) < 2 {
msg := "not enough args to call memorise tool; need topic and data to remember"
@@ -93,35 +93,35 @@ func memorise(args ...string) []byte {
}
memory := &models.Memory{
Agent: agent,
- Topic: args[0],
- Mind: args[1],
+ Topic: args["topic"],
+ Mind: args["data"],
UpdatedAt: time.Now(),
}
if _, err := store.Memorise(memory); err != nil {
logger.Error("failed to save memory", "err", err, "memoory", memory)
return []byte("failed to save info")
}
- msg := "info saved under the topic:" + args[0]
+ msg := "info saved under the topic:" + args["topic"]
return []byte(msg)
}
-func recall(args ...string) []byte {
+func recall(args map[string]string) []byte {
agent := cfg.AssistantRole
if len(args) < 1 {
logger.Warn("not enough args to call recall tool")
return nil
}
- mind, err := store.Recall(agent, args[0])
+ mind, err := store.Recall(agent, args["topic"])
if err != nil {
msg := fmt.Sprintf("failed to recall; error: %v; args: %v", err, args)
logger.Error(msg)
return []byte(msg)
}
- answer := fmt.Sprintf("under the topic: %s is stored:\n%s", args[0], mind)
+ answer := fmt.Sprintf("under the topic: %s is stored:\n%s", args["topic"], mind)
return []byte(answer)
}
-func recallTopics(args ...string) []byte {
+func recallTopics(args map[string]string) []byte {
agent := cfg.AssistantRole
topics, err := store.RecallTopics(agent)
if err != nil {
@@ -134,7 +134,7 @@ func recallTopics(args ...string) []byte {
// func fullMemoryLoad() {}
-type fnSig func(...string) []byte
+type fnSig func(map[string]string) []byte
var fnMap = map[string]fnSig{
"recall": recall,
diff --git a/tui.go b/tui.go
index ee0e5e6..b91058c 100644
--- a/tui.go
+++ b/tui.go
@@ -558,7 +558,7 @@ func init() {
if event.Key() == tcell.KeyF5 {
// switch cfg.ShowSys
cfg.ShowSys = !cfg.ShowSys
- textView.SetText(chatToText(cfg.ShowSys))
+ textView.SetText(chatToText(cfg.ShowSys)) // TODO: fix removing all new names
colorText()
}
if event.Key() == tcell.KeyF6 {