diff options
| author | Grail Finder <wohilas@gmail.com> | 2025-12-19 15:39:55 +0300 |
|---|---|---|
| committer | Grail Finder <wohilas@gmail.com> | 2025-12-19 15:39:55 +0300 |
| commit | f779f039745f97f08f25967214d07716ce213326 (patch) | |
| tree | 9868dfe5e77845e27d4fe8b36f993a5fade2ea7b | |
| parent | a875abcf198dd2f85c518f8bf2c599db66d3e69f (diff) | |
Enha: agent request builder
| -rw-r--r-- | agent/agent.go | 10 | ||||
| -rw-r--r-- | agent/request.go | 192 | ||||
| -rw-r--r-- | bot.go | 72 | ||||
| -rw-r--r-- | bot_test.go | 134 | ||||
| -rw-r--r-- | tools.go | 52 |
5 files changed, 430 insertions, 30 deletions
diff --git a/agent/agent.go b/agent/agent.go index 5ad1ef1..8824ecb 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -33,3 +33,13 @@ func RegisterB(toolName string, a AgenterB) { func RegisterA(toolNames []string, a AgenterA) { RegistryA[a] = toolNames } + +// Get returns the agent registered for the given tool name, or nil if none. +func Get(toolName string) AgenterB { + return RegistryB[toolName] +} + +// Register is a convenience wrapper for RegisterB. +func Register(toolName string, a AgenterB) { + RegisterB(toolName, a) +} diff --git a/agent/request.go b/agent/request.go index 2d557ac..bb4a80d 100644 --- a/agent/request.go +++ b/agent/request.go @@ -3,15 +3,32 @@ package agent import ( "bytes" "encoding/json" + "fmt" "gf-lt/config" "gf-lt/models" "io" "log/slog" "net/http" + "strings" ) var httpClient = &http.Client{} +var defaultProps = map[string]float32{ + "temperature": 0.8, + "dry_multiplier": 0.0, + "min_p": 0.05, + "n_predict": -1.0, +} + +func detectAPI(api string) (isCompletion, isChat, isDeepSeek, isOpenRouter bool) { + isCompletion = strings.Contains(api, "/completion") && !strings.Contains(api, "/chat/completions") + isChat = strings.Contains(api, "/chat/completions") + isDeepSeek = strings.Contains(api, "deepseek.com") + isOpenRouter = strings.Contains(api, "openrouter.ai") + return +} + type AgentClient struct { cfg *config.Config getToken func() string @@ -31,38 +48,185 @@ func (ag *AgentClient) Log() *slog.Logger { } func (ag *AgentClient) FormMsg(sysprompt, msg string) (io.Reader, error) { - agentConvo := []models.RoleMsg{ + b, err := ag.buildRequest(sysprompt, msg) + if err != nil { + return nil, err + } + return bytes.NewReader(b), nil +} + +// buildRequest creates the appropriate LLM request based on the current API endpoint. +func (ag *AgentClient) buildRequest(sysprompt, msg string) ([]byte, error) { + api := ag.cfg.CurrentAPI + model := ag.cfg.CurrentModel + messages := []models.RoleMsg{ {Role: "system", Content: sysprompt}, {Role: "user", Content: msg}, } - agentChat := &models.ChatBody{ - Model: ag.cfg.CurrentModel, - Stream: true, - Messages: agentConvo, + + // Determine API type + isCompletion, isChat, isDeepSeek, isOpenRouter := detectAPI(api) + ag.log.Debug("agent building request", "api", api, "isCompletion", isCompletion, "isChat", isChat, "isDeepSeek", isDeepSeek, "isOpenRouter", isOpenRouter) + + // Build prompt for completion endpoints + if isCompletion { + var sb strings.Builder + for _, m := range messages { + sb.WriteString(m.ToPrompt()) + sb.WriteString("\n") + } + prompt := strings.TrimSpace(sb.String()) + + if isDeepSeek { + // DeepSeek completion + req := models.NewDSCompletionReq(prompt, model, defaultProps["temperature"], []string{}) + req.Stream = false // Agents don't need streaming + return json.Marshal(req) + } else if isOpenRouter { + // OpenRouter completion + req := models.NewOpenRouterCompletionReq(model, prompt, defaultProps, []string{}) + req.Stream = false // Agents don't need streaming + return json.Marshal(req) + } else { + // Assume llama.cpp completion + req := models.NewLCPReq(prompt, model, nil, defaultProps, []string{}) + req.Stream = false // Agents don't need streaming + return json.Marshal(req) + } } - b, err := json.Marshal(agentChat) - if err != nil { - ag.log.Error("failed to form agent msg", "error", err) - return nil, err + + // Chat completions endpoints + if isChat || !isCompletion { + chatBody := &models.ChatBody{ + Model: model, + Stream: false, // Agents don't need streaming + Messages: messages, + } + + if isDeepSeek { + // DeepSeek chat + req := models.NewDSChatReq(*chatBody) + return json.Marshal(req) + } else if isOpenRouter { + // OpenRouter chat + req := models.NewOpenRouterChatReq(*chatBody, defaultProps) + return json.Marshal(req) + } else { + // Assume llama.cpp chat (OpenAI format) + req := models.OpenAIReq{ + ChatBody: chatBody, + Tools: nil, + } + return json.Marshal(req) + } } - return bytes.NewReader(b), nil + + // Fallback (should not reach here) + ag.log.Warn("unknown API, using default chat completions format", "api", api) + chatBody := &models.ChatBody{ + Model: model, + Stream: false, // Agents don't need streaming + Messages: messages, + } + return json.Marshal(chatBody) } func (ag *AgentClient) LLMRequest(body io.Reader) ([]byte, error) { - req, err := http.NewRequest("POST", ag.cfg.CurrentAPI, body) + // Read the body for debugging (but we need to recreate it for the request) + bodyBytes, err := io.ReadAll(body) + if err != nil { + ag.log.Error("failed to read request body", "error", err) + return nil, err + } + + req, err := http.NewRequest("POST", ag.cfg.CurrentAPI, bytes.NewReader(bodyBytes)) if err != nil { - ag.log.Error("llamacpp api", "error", err) + ag.log.Error("failed to create request", "error", err) return nil, err } req.Header.Add("Accept", "application/json") req.Header.Add("Content-Type", "application/json") req.Header.Add("Authorization", "Bearer "+ag.getToken()) req.Header.Set("Accept-Encoding", "gzip") + + ag.log.Debug("agent LLM request", "url", ag.cfg.CurrentAPI, "body_preview", string(bodyBytes[:min(len(bodyBytes), 500)])) + resp, err := httpClient.Do(req) if err != nil { - ag.log.Error("llamacpp api", "error", err) + ag.log.Error("llamacpp api request failed", "error", err, "url", ag.cfg.CurrentAPI) return nil, err } defer resp.Body.Close() - return io.ReadAll(resp.Body) + + responseBytes, err := io.ReadAll(resp.Body) + if err != nil { + ag.log.Error("failed to read response", "error", err) + return nil, err + } + + if resp.StatusCode >= 400 { + ag.log.Error("agent LLM request failed", "status", resp.StatusCode, "response", string(responseBytes[:min(len(responseBytes), 1000)])) + return responseBytes, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(responseBytes[:min(len(responseBytes), 200)])) + } + + // Parse response and extract text content + text, err := extractTextFromResponse(responseBytes) + if err != nil { + ag.log.Error("failed to extract text from response", "error", err, "response_preview", string(responseBytes[:min(len(responseBytes), 500)])) + // Return raw response as fallback + return responseBytes, nil + } + + return []byte(text), nil +} + +// extractTextFromResponse parses common LLM response formats and extracts the text content. +func extractTextFromResponse(data []byte) (string, error) { + // Try to parse as generic JSON first + var genericResp map[string]interface{} + if err := json.Unmarshal(data, &genericResp); err != nil { + // Not JSON, return as string + return string(data), nil + } + + // Check for OpenAI chat completion format + if choices, ok := genericResp["choices"].([]interface{}); ok && len(choices) > 0 { + if firstChoice, ok := choices[0].(map[string]interface{}); ok { + // Chat completion: choices[0].message.content + if message, ok := firstChoice["message"].(map[string]interface{}); ok { + if content, ok := message["content"].(string); ok { + return content, nil + } + } + // Completion: choices[0].text + if text, ok := firstChoice["text"].(string); ok { + return text, nil + } + // Delta format for streaming (should not happen with stream: false) + if delta, ok := firstChoice["delta"].(map[string]interface{}); ok { + if content, ok := delta["content"].(string); ok { + return content, nil + } + } + } + } + + // Check for llama.cpp completion format + if content, ok := genericResp["content"].(string); ok { + return content, nil + } + + // Unknown format, return pretty-printed JSON + prettyJSON, err := json.MarshalIndent(genericResp, "", " ") + if err != nil { + return string(data), nil + } + return string(prettyJSON), nil +} + +func min(a, b int) int { + if a < b { + return a + } + return b } @@ -6,6 +6,7 @@ import ( "context" "encoding/json" "fmt" + "strconv" "gf-lt/config" "gf-lt/extra" "gf-lt/models" @@ -659,14 +660,75 @@ func cleanChatBody() { } } +// convertJSONToMapStringString unmarshals JSON into map[string]interface{} and converts all values to strings. +func convertJSONToMapStringString(jsonStr string) (map[string]string, error) { + var raw map[string]interface{} + if err := json.Unmarshal([]byte(jsonStr), &raw); err != nil { + return nil, err + } + result := make(map[string]string, len(raw)) + for k, v := range raw { + switch val := v.(type) { + case string: + result[k] = val + case float64: + result[k] = strconv.FormatFloat(val, 'f', -1, 64) + case int, int64, int32: + // json.Unmarshal converts numbers to float64, but handle other integer types if they appear + result[k] = fmt.Sprintf("%v", val) + case bool: + result[k] = strconv.FormatBool(val) + case nil: + result[k] = "" + default: + result[k] = fmt.Sprintf("%v", val) + } + } + return result, nil +} + +// unmarshalFuncCall unmarshals a JSON tool call, converting numeric arguments to strings. +func unmarshalFuncCall(jsonStr string) (*models.FuncCall, error) { + type tempFuncCall struct { + ID string `json:"id,omitempty"` + Name string `json:"name"` + Args map[string]interface{} `json:"args"` + } + var temp tempFuncCall + if err := json.Unmarshal([]byte(jsonStr), &temp); err != nil { + return nil, err + } + fc := &models.FuncCall{ + ID: temp.ID, + Name: temp.Name, + Args: make(map[string]string, len(temp.Args)), + } + for k, v := range temp.Args { + switch val := v.(type) { + case string: + fc.Args[k] = val + case float64: + fc.Args[k] = strconv.FormatFloat(val, 'f', -1, 64) + case int, int64, int32: + fc.Args[k] = fmt.Sprintf("%v", val) + case bool: + fc.Args[k] = strconv.FormatBool(val) + case nil: + fc.Args[k] = "" + default: + fc.Args[k] = fmt.Sprintf("%v", val) + } + } + return fc, nil +} + func findCall(msg, toolCall string, tv *tview.TextView) { fc := &models.FuncCall{} if toolCall != "" { // HTML-decode the tool call string to handle encoded characters like < -> <= decodedToolCall := html.UnescapeString(toolCall) - openAIToolMap := make(map[string]string) - // respect tool call - if err := json.Unmarshal([]byte(decodedToolCall), &openAIToolMap); err != nil { + openAIToolMap, err := convertJSONToMapStringString(decodedToolCall) + if err != nil { logger.Error("failed to unmarshal openai tool call", "call", decodedToolCall, "error", err) // Send error response to LLM so it can retry or handle the error toolResponseMsg := models.RoleMsg{ @@ -700,7 +762,9 @@ func findCall(msg, toolCall string, tv *tview.TextView) { jsStr = strings.TrimSuffix(strings.TrimPrefix(jsStr, prefix), suffix) // HTML-decode the JSON string to handle encoded characters like < -> <= decodedJsStr := html.UnescapeString(jsStr) - if err := json.Unmarshal([]byte(decodedJsStr), &fc); err != nil { + var err error + fc, err = unmarshalFuncCall(decodedJsStr) + if err != nil { logger.Error("failed to unmarshal tool call", "error", err, "json_string", decodedJsStr) // Send error response to LLM so it can retry or handle the error toolResponseMsg := models.RoleMsg{ diff --git a/bot_test.go b/bot_test.go index 2d59c3c..d2956a9 100644 --- a/bot_test.go +++ b/bot_test.go @@ -152,4 +152,138 @@ func TestConsolidateConsecutiveAssistantMessages(t *testing.T) { } }) } +} + +func TestUnmarshalFuncCall(t *testing.T) { + tests := []struct { + name string + jsonStr string + want *models.FuncCall + wantErr bool + }{ + { + name: "simple websearch with numeric limit", + jsonStr: `{"name": "websearch", "args": {"query": "current weather in London", "limit": 3}}`, + want: &models.FuncCall{ + Name: "websearch", + Args: map[string]string{"query": "current weather in London", "limit": "3"}, + }, + wantErr: false, + }, + { + name: "string limit", + jsonStr: `{"name": "websearch", "args": {"query": "test", "limit": "5"}}`, + want: &models.FuncCall{ + Name: "websearch", + Args: map[string]string{"query": "test", "limit": "5"}, + }, + wantErr: false, + }, + { + name: "boolean arg", + jsonStr: `{"name": "test", "args": {"flag": true}}`, + want: &models.FuncCall{ + Name: "test", + Args: map[string]string{"flag": "true"}, + }, + wantErr: false, + }, + { + name: "null arg", + jsonStr: `{"name": "test", "args": {"opt": null}}`, + want: &models.FuncCall{ + Name: "test", + Args: map[string]string{"opt": ""}, + }, + wantErr: false, + }, + { + name: "float arg", + jsonStr: `{"name": "test", "args": {"ratio": 0.5}}`, + want: &models.FuncCall{ + Name: "test", + Args: map[string]string{"ratio": "0.5"}, + }, + wantErr: false, + }, + { + name: "invalid JSON", + jsonStr: `{invalid}`, + want: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := unmarshalFuncCall(tt.jsonStr) + if (err != nil) != tt.wantErr { + t.Errorf("unmarshalFuncCall() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + return + } + if got.Name != tt.want.Name { + t.Errorf("unmarshalFuncCall() name = %v, want %v", got.Name, tt.want.Name) + } + if len(got.Args) != len(tt.want.Args) { + t.Errorf("unmarshalFuncCall() args length = %v, want %v", len(got.Args), len(tt.want.Args)) + } + for k, v := range tt.want.Args { + if got.Args[k] != v { + t.Errorf("unmarshalFuncCall() args[%v] = %v, want %v", k, got.Args[k], v) + } + } + }) + } +} + +func TestConvertJSONToMapStringString(t *testing.T) { + tests := []struct { + name string + jsonStr string + want map[string]string + wantErr bool + }{ + { + name: "simple map", + jsonStr: `{"query": "weather", "limit": 5}`, + want: map[string]string{"query": "weather", "limit": "5"}, + wantErr: false, + }, + { + name: "boolean and null", + jsonStr: `{"flag": true, "opt": null}`, + want: map[string]string{"flag": "true", "opt": ""}, + wantErr: false, + }, + { + name: "invalid JSON", + jsonStr: `{invalid`, + want: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := convertJSONToMapStringString(tt.jsonStr) + if (err != nil) != tt.wantErr { + t.Errorf("convertJSONToMapStringString() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + return + } + if len(got) != len(tt.want) { + t.Errorf("convertJSONToMapStringString() length = %v, want %v", len(got), len(tt.want)) + } + for k, v := range tt.want { + if got[k] != v { + t.Errorf("convertJSONToMapStringString()[%v] = %v, want %v", k, got[k], v) + } + } + }) + } }
\ No newline at end of file @@ -13,6 +13,7 @@ import ( "regexp" "strconv" "strings" + "sync" "time" ) @@ -126,7 +127,9 @@ under the topic: Adam's number is stored: </example_response> After that you are free to respond to the user. ` - basicCard = &models.CharCard{ + webSearchSysPrompt = `Summarize the web search results, extracting key information and presenting a concise answer. Provide sources and URLs where relevant.` + readURLSysPrompt = `Extract and summarize the content from the webpage. Provide key information, main points, and any relevant details.` + basicCard = &models.CharCard{ SysPrompt: basicSysMsg, FirstMsg: defaultFirstMsg, Role: "", @@ -141,8 +144,43 @@ After that you are free to respond to the user. // sysMap = map[string]string{"basic_sys": basicSysMsg, "tool_sys": toolSysMsg} sysMap = map[string]*models.CharCard{"basic_sys": basicCard} sysLabels = []string{"basic_sys"} + + webAgentClient *agent.AgentClient + webAgentClientOnce sync.Once + webAgentsOnce sync.Once ) +// getWebAgentClient returns a singleton AgentClient for web agents. +func getWebAgentClient() *agent.AgentClient { + webAgentClientOnce.Do(func() { + if cfg == nil { + panic("cfg not initialized") + } + if logger == nil { + panic("logger not initialized") + } + getToken := func() string { + if chunkParser == nil { + return "" + } + return chunkParser.GetToken() + } + webAgentClient = agent.NewAgentClient(cfg, *logger, getToken) + }) + return webAgentClient +} + +// registerWebAgents registers WebAgentB instances for websearch and read_url tools. +func registerWebAgents() { + webAgentsOnce.Do(func() { + client := getWebAgentClient() + // Register websearch agent + agent.Register("websearch", agent.NewWebAgentB(client, webSearchSysPrompt)) + // Register read_url agent + agent.Register("read_url", agent.NewWebAgentB(client, readURLSysPrompt)) + }) +} + // web search (depends on extra server) func websearch(args map[string]string) []byte { // make http request return bytes @@ -597,7 +635,6 @@ var globalTodoList = TodoList{ Items: []TodoItem{}, } - // Todo Management Tools func todoCreate(args map[string]string) []byte { task, ok := args["task"] @@ -851,6 +888,7 @@ var fnMap = map[string]fnSig{ // callToolWithAgent calls the tool and applies any registered agent. func callToolWithAgent(name string, args map[string]string) []byte { + registerWebAgents() f, ok := fnMap[name] if !ok { return []byte(fmt.Sprintf("tool %s not found", name)) @@ -862,16 +900,6 @@ func callToolWithAgent(name string, args map[string]string) []byte { return raw } -// registerDefaultAgents registers default agents for formatting. -func registerDefaultAgents() { - agent.Register("websearch", agent.DefaultFormatter("websearch")) - agent.Register("read_url", agent.DefaultFormatter("read_url")) -} - -func init() { - registerDefaultAgents() -} - // openai style def var baseTools = []models.Tool{ // websearch |
