From c2c90f6d2b766bbba30c8ea8087f799a6c21f525 Mon Sep 17 00:00:00 2001 From: Grail Finder Date: Mon, 9 Mar 2026 08:50:33 +0300 Subject: Enha: pw agent --- agent/pw_agent.go | 102 ++++++++++++++-- agent/pw_tools.go | 349 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ agent/request.go | 14 +++ agent/webagent.go | 4 +- 4 files changed, 459 insertions(+), 10 deletions(-) create mode 100644 agent/pw_tools.go (limited to 'agent') diff --git a/agent/pw_agent.go b/agent/pw_agent.go index 8c1c2bf..2807331 100644 --- a/agent/pw_agent.go +++ b/agent/pw_agent.go @@ -1,5 +1,11 @@ package agent +import ( + "encoding/json" + "gf-lt/models" + "strings" +) + // PWAgent: is AgenterA type agent (enclosed with tool chaining) // sysprompt explain tools and how to plan for execution type PWAgent struct { @@ -7,11 +13,16 @@ type PWAgent struct { sysprompt string } -// NewWebAgentB creates a WebAgentB that uses the given formatting function +// NewPWAgent creates a PWAgent with the given client and system prompt func NewPWAgent(client *AgentClient, sysprompt string) *PWAgent { return &PWAgent{AgentClient: client, sysprompt: sysprompt} } +// SetTools sets the tools available to the agent +func (a *PWAgent) SetTools(tools []models.Tool) { + a.tools = tools +} + func (a *PWAgent) ProcessTask(task string) []byte { req, err := a.FormFirstMsg(a.sysprompt, task) if err != nil { @@ -25,16 +36,91 @@ func (a *PWAgent) ProcessTask(task string) []byte { a.Log().Error("failed to process the request", "error", err) return []byte("failed to process the request; err: " + err.Error()) } - toolCall, hasToolCall := findToolCall(resp) + execTool, toolCallID, hasToolCall := findToolCall(resp) if !hasToolCall { return resp } - // check resp for tool calls - // make tool call - // add tool call resp to body - // send new request too lmm - tooResp := toolCall(resp) - req, err = a.FormMsg(toolResp) + + a.setToolCallOnLastMessage(resp, toolCallID) + + toolResp := string(execTool()) + req, err = 
a.FormMsgWithToolCallID(toolResp, toolCallID) + if err != nil { + a.Log().Error("failed to form next message", "error", err) + return []byte("failed to form next message; err: " + err.Error()) + } } return nil } + +func (a *PWAgent) setToolCallOnLastMessage(resp []byte, toolCallID string) { + if toolCallID == "" { + return + } + + var genericResp map[string]interface{} + if err := json.Unmarshal(resp, &genericResp); err != nil { + return + } + + var name string + var args map[string]string + + if choices, ok := genericResp["choices"].([]interface{}); ok && len(choices) > 0 { + if firstChoice, ok := choices[0].(map[string]interface{}); ok { + if message, ok := firstChoice["message"].(map[string]interface{}); ok { + if toolCalls, ok := message["tool_calls"].([]interface{}); ok && len(toolCalls) > 0 { + if tc, ok := toolCalls[0].(map[string]interface{}); ok { + if fn, ok := tc["function"].(map[string]interface{}); ok { + name, _ = fn["name"].(string) + argsStr, _ := fn["arguments"].(string) + json.Unmarshal([]byte(argsStr), &args) + } + } + } + } + } + } + + if name == "" { + content, _ := genericResp["content"].(string) + name = extractToolNameFromText(content) + } + + lastIdx := len(a.chatBody.Messages) - 1 + if lastIdx >= 0 { + a.chatBody.Messages[lastIdx].ToolCallID = toolCallID + if name != "" { + argsJSON, _ := json.Marshal(args) + a.chatBody.Messages[lastIdx].ToolCall = &models.ToolCall{ + ID: toolCallID, + Name: name, + Args: string(argsJSON), + } + } + } +} + +func extractToolNameFromText(text string) string { + jsStr := toolCallRE.FindString(text) + if jsStr == "" { + return "" + } + jsStr = strings.TrimSpace(jsStr) + jsStr = strings.TrimPrefix(jsStr, "__tool_call__") + jsStr = strings.TrimSuffix(jsStr, "__tool_call__") + jsStr = strings.TrimSpace(jsStr) + + start := strings.Index(jsStr, "{") + end := strings.LastIndex(jsStr, "}") + if start == -1 || end == -1 || end <= start { + return "" + } + jsStr = jsStr[start : end+1] + + var fc models.FuncCall + if 
err := json.Unmarshal([]byte(jsStr), &fc); err != nil { + return "" + } + return fc.Name +} diff --git a/agent/pw_tools.go b/agent/pw_tools.go new file mode 100644 index 0000000..19fd130 --- /dev/null +++ b/agent/pw_tools.go @@ -0,0 +1,349 @@ +package agent + +import ( + "encoding/json" + "fmt" + "regexp" + "strings" + + "gf-lt/models" +) + +type ToolFunc func(map[string]string) []byte + +var pwToolMap = make(map[string]ToolFunc) + +func RegisterPWTool(name string, fn ToolFunc) { + pwToolMap[name] = fn +} + +func GetPWTools() []models.Tool { + return pwTools +} + +var pwTools = []models.Tool{ + { + Type: "function", + Function: models.ToolFunc{ + Name: "pw_start", + Description: "Start a Playwright browser instance. Must be called first before any other browser automation. Uses headless mode by default.", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{}, + Properties: map[string]models.ToolArgProps{}, + }, + }, + }, + { + Type: "function", + Function: models.ToolFunc{ + Name: "pw_stop", + Description: "Stop the Playwright browser instance. 
Call when done with browser automation.", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{}, + Properties: map[string]models.ToolArgProps{}, + }, + }, + }, + { + Type: "function", + Function: models.ToolFunc{ + Name: "pw_is_running", + Description: "Check if Playwright browser is currently running.", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{}, + Properties: map[string]models.ToolArgProps{}, + }, + }, + }, + { + Type: "function", + Function: models.ToolFunc{ + Name: "pw_navigate", + Description: "Navigate to a URL in the browser.", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{"url"}, + Properties: map[string]models.ToolArgProps{ + "url": {Type: "string", Description: "URL to navigate to"}, + }, + }, + }, + }, + { + Type: "function", + Function: models.ToolFunc{ + Name: "pw_click", + Description: "Click on an element on the current webpage. Use 'index' for multiple matches (default 0).", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{"selector"}, + Properties: map[string]models.ToolArgProps{ + "selector": {Type: "string", Description: "CSS selector for the element"}, + "index": {Type: "integer", Description: "Index for multiple matches (default 0)"}, + }, + }, + }, + }, + { + Type: "function", + Function: models.ToolFunc{ + Name: "pw_fill", + Description: "Type text into an input field. 
Use 'index' for multiple matches (default 0).", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{"selector", "text"}, + Properties: map[string]models.ToolArgProps{ + "selector": {Type: "string", Description: "CSS selector for the input element"}, + "text": {Type: "string", Description: "Text to type into the field"}, + "index": {Type: "integer", Description: "Index for multiple matches (default 0)"}, + }, + }, + }, + }, + { + Type: "function", + Function: models.ToolFunc{ + Name: "pw_extract_text", + Description: "Extract text content from the page or specific elements. Use selector 'body' for all page text.", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{}, + Properties: map[string]models.ToolArgProps{ + "selector": {Type: "string", Description: "CSS selector (default 'body' for all page text)"}, + }, + }, + }, + }, + { + Type: "function", + Function: models.ToolFunc{ + Name: "pw_screenshot", + Description: "Take a screenshot of the page or a specific element. Returns a file path to the image.", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{}, + Properties: map[string]models.ToolArgProps{ + "selector": {Type: "string", Description: "CSS selector for element to screenshot"}, + "full_page": {Type: "boolean", Description: "Capture full page (default false)"}, + }, + }, + }, + }, + { + Type: "function", + Function: models.ToolFunc{ + Name: "pw_screenshot_and_view", + Description: "Take a screenshot and return the image for viewing. 
Use to visually verify page state.", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{}, + Properties: map[string]models.ToolArgProps{ + "selector": {Type: "string", Description: "CSS selector for element to screenshot"}, + "full_page": {Type: "boolean", Description: "Capture full page (default false)"}, + }, + }, + }, + }, + { + Type: "function", + Function: models.ToolFunc{ + Name: "pw_wait_for_selector", + Description: "Wait for an element to appear on the page before proceeding.", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{"selector"}, + Properties: map[string]models.ToolArgProps{ + "selector": {Type: "string", Description: "CSS selector to wait for"}, + "timeout": {Type: "integer", Description: "Timeout in milliseconds (default 30000)"}, + }, + }, + }, + }, + { + Type: "function", + Function: models.ToolFunc{ + Name: "pw_drag", + Description: "Drag the mouse from point (x1,y1) to (x2,y2).", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{"x1", "y1", "x2", "y2"}, + Properties: map[string]models.ToolArgProps{ + "x1": {Type: "number", Description: "Starting X coordinate"}, + "y1": {Type: "number", Description: "Starting Y coordinate"}, + "x2": {Type: "number", Description: "Ending X coordinate"}, + "y2": {Type: "number", Description: "Ending Y coordinate"}, + }, + }, + }, + }, + { + Type: "function", + Function: models.ToolFunc{ + Name: "pw_click_at", + Description: "Click at specific X,Y coordinates on the page.", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{"x", "y"}, + Properties: map[string]models.ToolArgProps{ + "x": {Type: "number", Description: "X coordinate"}, + "y": {Type: "number", Description: "Y coordinate"}, + }, + }, + }, + }, + { + Type: "function", + Function: models.ToolFunc{ + Name: "pw_get_html", + Description: "Get the HTML content of the page or a specific element.", + Parameters: models.ToolFuncParams{ + Type: "object", + 
Required: []string{}, + Properties: map[string]models.ToolArgProps{ + "selector": {Type: "string", Description: "CSS selector (default 'body')"}, + }, + }, + }, + }, + { + Type: "function", + Function: models.ToolFunc{ + Name: "pw_get_dom", + Description: "Get a structured DOM representation with tag, attributes, text, and children.", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{}, + Properties: map[string]models.ToolArgProps{ + "selector": {Type: "string", Description: "CSS selector (default 'body')"}, + }, + }, + }, + }, + { + Type: "function", + Function: models.ToolFunc{ + Name: "pw_search_elements", + Description: "Search for elements by text content or CSS selector.", + Parameters: models.ToolFuncParams{ + Type: "object", + Required: []string{}, + Properties: map[string]models.ToolArgProps{ + "text": {Type: "string", Description: "Text content to search for"}, + "selector": {Type: "string", Description: "CSS selector to search for"}, + }, + }, + }, + }, +} + +var toolCallRE = regexp.MustCompile(`__tool_call__(.+?)__tool_call__`) + +type ParsedToolCall struct { + ID string + Name string + Args map[string]string +} + +func findToolCall(resp []byte) (func() []byte, string, bool) { + var genericResp map[string]interface{} + if err := json.Unmarshal(resp, &genericResp); err != nil { + return findToolCallFromText(string(resp)) + } + + if choices, ok := genericResp["choices"].([]interface{}); ok && len(choices) > 0 { + if firstChoice, ok := choices[0].(map[string]interface{}); ok { + if message, ok := firstChoice["message"].(map[string]interface{}); ok { + if toolCalls, ok := message["tool_calls"].([]interface{}); ok && len(toolCalls) > 0 { + return parseOpenAIToolCall(toolCalls) + } + if content, ok := message["content"].(string); ok { + return findToolCallFromText(content) + } + } + if text, ok := firstChoice["text"].(string); ok { + return findToolCallFromText(text) + } + } + } + + if content, ok := genericResp["content"].(string); 
ok { + return findToolCallFromText(content) + } + + return findToolCallFromText(string(resp)) +} + +func parseOpenAIToolCall(toolCalls []interface{}) (func() []byte, string, bool) { + if len(toolCalls) == 0 { + return nil, "", false + } + + tc, _ := toolCalls[0].(map[string]interface{}) // comma-ok form: a malformed tool_calls entry must not panic; lookups on the nil map below are safe + id, _ := tc["id"].(string) + function, _ := tc["function"].(map[string]interface{}) + name, _ := function["name"].(string) + argsStr, _ := function["arguments"].(string) + + var args map[string]string + if err := json.Unmarshal([]byte(argsStr), &args); err != nil { + return func() []byte { + return []byte(fmt.Sprintf(`{"error": "failed to parse arguments: %v"}`, err)) + }, id, true + } + + return func() []byte { + fn, ok := pwToolMap[name] + if !ok { + return []byte(fmt.Sprintf(`{"error": "tool %s not found"}`, name)) + } + return fn(args) + }, id, true +} + +func findToolCallFromText(text string) (func() []byte, string, bool) { + jsStr := toolCallRE.FindString(text) + if jsStr == "" { + return nil, "", false + } + + jsStr = strings.TrimSpace(jsStr) + jsStr = strings.TrimPrefix(jsStr, "__tool_call__") + jsStr = strings.TrimSuffix(jsStr, "__tool_call__") + jsStr = strings.TrimSpace(jsStr) + + start := strings.Index(jsStr, "{") + end := strings.LastIndex(jsStr, "}") + if start == -1 || end == -1 || end <= start { + return func() []byte { + return []byte(`{"error": "no valid JSON found in tool call"}`) + }, "", true + } + + jsStr = jsStr[start : end+1] + + var fc models.FuncCall + if err := json.Unmarshal([]byte(jsStr), &fc); err != nil { + return func() []byte { + return []byte(fmt.Sprintf(`{"error": "failed to parse tool call: %v"}`, err)) // closing quote restored so the payload has the same shape as the other error payloads + }, "", true + } + + if fc.ID == "" { + fc.ID = "call_" + generateToolCallID() + } + + return func() []byte { + fn, ok := pwToolMap[fc.Name] + if !ok { + return 
[]byte(fmt.Sprintf(`{"error": "tool %s not found"}`, fc.Name)) + } + return fn(fc.Args) + }, fc.ID, true +} + +func generateToolCallID() string { + return fmt.Sprintf("%d", 
len(pwToolMap)%10000) +} diff --git a/agent/request.go b/agent/request.go index 4ca619d..754f16e 100644 --- a/agent/request.go +++ b/agent/request.go @@ -80,6 +80,20 @@ func (ag *AgentClient) FormMsg(msg string) (io.Reader, error) { return bytes.NewReader(b), nil } +func (ag *AgentClient) FormMsgWithToolCallID(msg, toolCallID string) (io.Reader, error) { + m := models.RoleMsg{ + Role: "tool", + Content: msg, + ToolCallID: toolCallID, + } + ag.chatBody.Messages = append(ag.chatBody.Messages, m) + b, err := ag.buildRequest() + if err != nil { + return nil, err + } + return bytes.NewReader(b), nil +} + // buildRequest creates the appropriate LLM request based on the current API endpoint. func (ag *AgentClient) buildRequest() ([]byte, error) { isCompletion, isChat, isDeepSeek, isOpenRouter := detectAPI(ag.cfg.CurrentAPI) diff --git a/agent/webagent.go b/agent/webagent.go index ff6cd86..11e9014 100644 --- a/agent/webagent.go +++ b/agent/webagent.go @@ -17,8 +17,8 @@ func NewWebAgentB(client *AgentClient, sysprompt string) *WebAgentB { // Process applies the formatting function to raw output func (a *WebAgentB) Process(args map[string]string, rawOutput []byte) []byte { - msg, err := a.FormMsg(a.sysprompt, - fmt.Sprintf("request:\n%+v\ntool response:\n%v", args, string(rawOutput))) + msg, err := a.FormMsg( + fmt.Sprintf("%s\n\nrequest:\n%+v\ntool response:\n%v", a.sysprompt, args, string(rawOutput))) if err != nil { a.Log().Error("failed to process the request", "error", err) return []byte("failed to process the request; err: " + err.Error()) -- cgit v1.2.3