summaryrefslogtreecommitdiff
path: root/agent
diff options
context:
space:
mode:
Diffstat (limited to 'agent')
-rw-r--r--agent/pw_agent.go9
-rw-r--r--agent/pw_tools.go15
2 files changed, 3 insertions, 21 deletions
diff --git a/agent/pw_agent.go b/agent/pw_agent.go
index 2807331..787d411 100644
--- a/agent/pw_agent.go
+++ b/agent/pw_agent.go
@@ -57,15 +57,12 @@ func (a *PWAgent) setToolCallOnLastMessage(resp []byte, toolCallID string) {
if toolCallID == "" {
return
}
-
var genericResp map[string]interface{}
if err := json.Unmarshal(resp, &genericResp); err != nil {
return
}
-
var name string
var args map[string]string
-
if choices, ok := genericResp["choices"].([]interface{}); ok && len(choices) > 0 {
if firstChoice, ok := choices[0].(map[string]interface{}); ok {
if message, ok := firstChoice["message"].(map[string]interface{}); ok {
@@ -74,19 +71,17 @@ func (a *PWAgent) setToolCallOnLastMessage(resp []byte, toolCallID string) {
if fn, ok := tc["function"].(map[string]interface{}); ok {
name, _ = fn["name"].(string)
argsStr, _ := fn["arguments"].(string)
- json.Unmarshal([]byte(argsStr), &args)
+ _ = json.Unmarshal([]byte(argsStr), &args)
}
}
}
}
}
}
-
if name == "" {
content, _ := genericResp["content"].(string)
name = extractToolNameFromText(content)
}
-
lastIdx := len(a.chatBody.Messages) - 1
if lastIdx >= 0 {
a.chatBody.Messages[lastIdx].ToolCallID = toolCallID
@@ -110,14 +105,12 @@ func extractToolNameFromText(text string) string {
jsStr = strings.TrimPrefix(jsStr, "__tool_call__")
jsStr = strings.TrimSuffix(jsStr, "__tool_call__")
jsStr = strings.TrimSpace(jsStr)
-
start := strings.Index(jsStr, "{")
end := strings.LastIndex(jsStr, "}")
if start == -1 || end == -1 || end <= start {
return ""
}
jsStr = jsStr[start : end+1]
-
var fc models.FuncCall
if err := json.Unmarshal([]byte(jsStr), &fc); err != nil {
return ""
diff --git a/agent/pw_tools.go b/agent/pw_tools.go
index 19fd130..d72e0f3 100644
--- a/agent/pw_tools.go
+++ b/agent/pw_tools.go
@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"regexp"
+ "strconv"
"strings"
"gf-lt/models"
@@ -252,7 +253,6 @@ func findToolCall(resp []byte) (func() []byte, string, bool) {
if err := json.Unmarshal(resp, &genericResp); err != nil {
return findToolCallFromText(string(resp))
}
-
if choices, ok := genericResp["choices"].([]interface{}); ok && len(choices) > 0 {
if firstChoice, ok := choices[0].(map[string]interface{}); ok {
if message, ok := firstChoice["message"].(map[string]interface{}); ok {
@@ -268,11 +268,9 @@ func findToolCall(resp []byte) (func() []byte, string, bool) {
}
}
}
-
if content, ok := genericResp["content"].(string); ok {
return findToolCallFromText(content)
}
-
return findToolCallFromText(string(resp))
}
@@ -280,20 +278,17 @@ func parseOpenAIToolCall(toolCalls []interface{}) (func() []byte, string, bool)
if len(toolCalls) == 0 {
return nil, "", false
}
-
tc := toolCalls[0].(map[string]interface{})
id, _ := tc["id"].(string)
function, _ := tc["function"].(map[string]interface{})
name, _ := function["name"].(string)
argsStr, _ := function["arguments"].(string)
-
var args map[string]string
if err := json.Unmarshal([]byte(argsStr), &args); err != nil {
return func() []byte {
return []byte(fmt.Sprintf(`{"error": "failed to parse arguments: %v"}`, err))
}, id, true
}
-
return func() []byte {
fn, ok := pwToolMap[name]
if !ok {
@@ -308,12 +303,10 @@ func findToolCallFromText(text string) (func() []byte, string, bool) {
if jsStr == "" {
return nil, "", false
}
-
jsStr = strings.TrimSpace(jsStr)
jsStr = strings.TrimPrefix(jsStr, "__tool_call__")
jsStr = strings.TrimSuffix(jsStr, "__tool_call__")
jsStr = strings.TrimSpace(jsStr)
-
start := strings.Index(jsStr, "{")
end := strings.LastIndex(jsStr, "}")
if start == -1 || end == -1 || end <= start {
@@ -321,20 +314,16 @@ func findToolCallFromText(text string) (func() []byte, string, bool) {
return []byte(`{"error": "no valid JSON found in tool call"}`)
}, "", true
}
-
jsStr = jsStr[start : end+1]
-
var fc models.FuncCall
if err := json.Unmarshal([]byte(jsStr), &fc); err != nil {
return func() []byte {
return []byte(fmt.Sprintf(`{"error": "failed to parse tool call: %v}`, err))
}, "", true
}
-
if fc.ID == "" {
fc.ID = "call_" + generateToolCallID()
}
-
return func() []byte {
fn, ok := pwToolMap[fc.Name]
if !ok {
@@ -345,5 +334,5 @@ func findToolCallFromText(text string) (func() []byte, string, bool) {
}
func generateToolCallID() string {
- return fmt.Sprintf("%d", len(pwToolMap)%10000)
+ return strconv.Itoa(len(pwToolMap) % 10000)
}