diff options
Diffstat (limited to 'agent')
| -rw-r--r-- | agent/pw_agent.go | 9 | ||||
| -rw-r--r-- | agent/pw_tools.go | 15 |
2 files changed, 3 insertions, 21 deletions
diff --git a/agent/pw_agent.go b/agent/pw_agent.go index 2807331..787d411 100644 --- a/agent/pw_agent.go +++ b/agent/pw_agent.go @@ -57,15 +57,12 @@ func (a *PWAgent) setToolCallOnLastMessage(resp []byte, toolCallID string) { if toolCallID == "" { return } - var genericResp map[string]interface{} if err := json.Unmarshal(resp, &genericResp); err != nil { return } - var name string var args map[string]string - if choices, ok := genericResp["choices"].([]interface{}); ok && len(choices) > 0 { if firstChoice, ok := choices[0].(map[string]interface{}); ok { if message, ok := firstChoice["message"].(map[string]interface{}); ok { @@ -74,19 +71,17 @@ func (a *PWAgent) setToolCallOnLastMessage(resp []byte, toolCallID string) { if fn, ok := tc["function"].(map[string]interface{}); ok { name, _ = fn["name"].(string) argsStr, _ := fn["arguments"].(string) - json.Unmarshal([]byte(argsStr), &args) + _ = json.Unmarshal([]byte(argsStr), &args) } } } } } } - if name == "" { content, _ := genericResp["content"].(string) name = extractToolNameFromText(content) } - lastIdx := len(a.chatBody.Messages) - 1 if lastIdx >= 0 { a.chatBody.Messages[lastIdx].ToolCallID = toolCallID @@ -110,14 +105,12 @@ func extractToolNameFromText(text string) string { jsStr = strings.TrimPrefix(jsStr, "__tool_call__") jsStr = strings.TrimSuffix(jsStr, "__tool_call__") jsStr = strings.TrimSpace(jsStr) - start := strings.Index(jsStr, "{") end := strings.LastIndex(jsStr, "}") if start == -1 || end == -1 || end <= start { return "" } jsStr = jsStr[start : end+1] - var fc models.FuncCall if err := json.Unmarshal([]byte(jsStr), &fc); err != nil { return "" diff --git a/agent/pw_tools.go b/agent/pw_tools.go index 19fd130..d72e0f3 100644 --- a/agent/pw_tools.go +++ b/agent/pw_tools.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "regexp" + "strconv" "strings" "gf-lt/models" @@ -252,7 +253,6 @@ func findToolCall(resp []byte) (func() []byte, string, bool) { if err := json.Unmarshal(resp, &genericResp); err != nil { return findToolCallFromText(string(resp)) } - if choices, ok := genericResp["choices"].([]interface{}); ok && len(choices) > 0 { if firstChoice, ok := choices[0].(map[string]interface{}); ok { if message, ok := firstChoice["message"].(map[string]interface{}); ok { @@ -268,11 +268,9 @@ func findToolCall(resp []byte) (func() []byte, string, bool) { } } } - if content, ok := genericResp["content"].(string); ok { return findToolCallFromText(content) } - return findToolCallFromText(string(resp)) } @@ -280,20 +278,17 @@ func parseOpenAIToolCall(toolCalls []interface{}) (func() []byte, string, bool) if len(toolCalls) == 0 { return nil, "", false } - tc := toolCalls[0].(map[string]interface{}) id, _ := tc["id"].(string) function, _ := tc["function"].(map[string]interface{}) name, _ := function["name"].(string) argsStr, _ := function["arguments"].(string) - var args map[string]string if err := json.Unmarshal([]byte(argsStr), &args); err != nil { return func() []byte { return []byte(fmt.Sprintf(`{"error": "failed to parse arguments: %v"}`, err)) }, id, true } - return func() []byte { fn, ok := pwToolMap[name] if !ok { @@ -308,12 +303,10 @@ func findToolCallFromText(text string) (func() []byte, string, bool) { if jsStr == "" { return nil, "", false } - jsStr = strings.TrimSpace(jsStr) jsStr = strings.TrimPrefix(jsStr, "__tool_call__") jsStr = strings.TrimSuffix(jsStr, "__tool_call__") jsStr = strings.TrimSpace(jsStr) - start := strings.Index(jsStr, "{") end := strings.LastIndex(jsStr, "}") if start == -1 || end == -1 || end <= start { @@ -321,20 +314,16 @@ func findToolCallFromText(text string) (func() []byte, string, bool) { return []byte(`{"error": "no valid JSON found in tool call"}`) }, "", true } - jsStr = jsStr[start : end+1] - var fc models.FuncCall if err := json.Unmarshal([]byte(jsStr), &fc); err != nil { return func() []byte { return []byte(fmt.Sprintf(`{"error": "failed to parse tool call: %v}`, err)) }, "", true } - if fc.ID == "" { fc.ID = "call_" + generateToolCallID() } - return func() []byte { fn, ok := pwToolMap[fc.Name] if !ok { @@ -345,5 +334,5 @@ func findToolCallFromText(text string) (func() []byte, string, bool) { } func generateToolCallID() string { - return fmt.Sprintf("%d", len(pwToolMap)%10000) + return strconv.Itoa(len(pwToolMap) % 10000) } |
