-rw-r--r--  agent/agent.go    |  10
-rw-r--r--  agent/request.go  | 192
-rw-r--r--  bot.go            |  72
-rw-r--r--  bot_test.go       | 134
-rw-r--r--  tools.go          |  52
5 files changed, 430 insertions(+), 30 deletions(-)
diff --git a/agent/agent.go b/agent/agent.go
index 5ad1ef1..8824ecb 100644
--- a/agent/agent.go
+++ b/agent/agent.go
@@ -33,3 +33,13 @@ func RegisterB(toolName string, a AgenterB) {
func RegisterA(toolNames []string, a AgenterA) {
RegistryA[a] = toolNames
}
+
+// Get returns the agent registered for the given tool name, or nil if none.
+func Get(toolName string) AgenterB {
+ return RegistryB[toolName]
+}
+
+// Register is a convenience wrapper for RegisterB.
+func Register(toolName string, a AgenterB) {
+ RegisterB(toolName, a)
+}
diff --git a/agent/request.go b/agent/request.go
index 2d557ac..bb4a80d 100644
--- a/agent/request.go
+++ b/agent/request.go
@@ -3,15 +3,32 @@ package agent
import (
"bytes"
"encoding/json"
+ "fmt"
"gf-lt/config"
"gf-lt/models"
"io"
"log/slog"
"net/http"
+ "strings"
)
var httpClient = &http.Client{}
+var defaultProps = map[string]float32{
+ "temperature": 0.8,
+ "dry_multiplier": 0.0,
+ "min_p": 0.05,
+ "n_predict": -1.0,
+}
+
+func detectAPI(api string) (isCompletion, isChat, isDeepSeek, isOpenRouter bool) {
+ isCompletion = strings.Contains(api, "/completion") && !strings.Contains(api, "/chat/completions")
+ isChat = strings.Contains(api, "/chat/completions")
+ isDeepSeek = strings.Contains(api, "deepseek.com")
+ isOpenRouter = strings.Contains(api, "openrouter.ai")
+ return
+}
+
type AgentClient struct {
cfg *config.Config
getToken func() string
@@ -31,38 +48,185 @@ func (ag *AgentClient) Log() *slog.Logger {
}
func (ag *AgentClient) FormMsg(sysprompt, msg string) (io.Reader, error) {
- agentConvo := []models.RoleMsg{
+ b, err := ag.buildRequest(sysprompt, msg)
+ if err != nil {
+ return nil, err
+ }
+ return bytes.NewReader(b), nil
+}
+
+// buildRequest creates the appropriate LLM request based on the current API endpoint.
+func (ag *AgentClient) buildRequest(sysprompt, msg string) ([]byte, error) {
+ api := ag.cfg.CurrentAPI
+ model := ag.cfg.CurrentModel
+ messages := []models.RoleMsg{
{Role: "system", Content: sysprompt},
{Role: "user", Content: msg},
}
- agentChat := &models.ChatBody{
- Model: ag.cfg.CurrentModel,
- Stream: true,
- Messages: agentConvo,
+
+ // Determine API type
+ isCompletion, isChat, isDeepSeek, isOpenRouter := detectAPI(api)
+ ag.log.Debug("agent building request", "api", api, "isCompletion", isCompletion, "isChat", isChat, "isDeepSeek", isDeepSeek, "isOpenRouter", isOpenRouter)
+
+ // Build prompt for completion endpoints
+ if isCompletion {
+ var sb strings.Builder
+ for _, m := range messages {
+ sb.WriteString(m.ToPrompt())
+ sb.WriteString("\n")
+ }
+ prompt := strings.TrimSpace(sb.String())
+
+ if isDeepSeek {
+ // DeepSeek completion
+ req := models.NewDSCompletionReq(prompt, model, defaultProps["temperature"], []string{})
+ req.Stream = false // Agents don't need streaming
+ return json.Marshal(req)
+ } else if isOpenRouter {
+ // OpenRouter completion
+ req := models.NewOpenRouterCompletionReq(model, prompt, defaultProps, []string{})
+ req.Stream = false // Agents don't need streaming
+ return json.Marshal(req)
+ } else {
+ // Assume llama.cpp completion
+ req := models.NewLCPReq(prompt, model, nil, defaultProps, []string{})
+ req.Stream = false // Agents don't need streaming
+ return json.Marshal(req)
+ }
}
- b, err := json.Marshal(agentChat)
- if err != nil {
- ag.log.Error("failed to form agent msg", "error", err)
- return nil, err
+
+ // Chat completions endpoints
+ if isChat || !isCompletion {
+ chatBody := &models.ChatBody{
+ Model: model,
+ Stream: false, // Agents don't need streaming
+ Messages: messages,
+ }
+
+ if isDeepSeek {
+ // DeepSeek chat
+ req := models.NewDSChatReq(*chatBody)
+ return json.Marshal(req)
+ } else if isOpenRouter {
+ // OpenRouter chat
+ req := models.NewOpenRouterChatReq(*chatBody, defaultProps)
+ return json.Marshal(req)
+ } else {
+ // Assume llama.cpp chat (OpenAI format)
+ req := models.OpenAIReq{
+ ChatBody: chatBody,
+ Tools: nil,
+ }
+ return json.Marshal(req)
+ }
}
- return bytes.NewReader(b), nil
+
+ // Fallback (should not reach here)
+ ag.log.Warn("unknown API, using default chat completions format", "api", api)
+ chatBody := &models.ChatBody{
+ Model: model,
+ Stream: false, // Agents don't need streaming
+ Messages: messages,
+ }
+ return json.Marshal(chatBody)
}
func (ag *AgentClient) LLMRequest(body io.Reader) ([]byte, error) {
- req, err := http.NewRequest("POST", ag.cfg.CurrentAPI, body)
+ // Read the entire body up front so it can be both logged and re-sent (an io.Reader can only be consumed once)
+ bodyBytes, err := io.ReadAll(body)
+ if err != nil {
+ ag.log.Error("failed to read request body", "error", err)
+ return nil, err
+ }
+
+ req, err := http.NewRequest("POST", ag.cfg.CurrentAPI, bytes.NewReader(bodyBytes))
if err != nil {
- ag.log.Error("llamacpp api", "error", err)
+ ag.log.Error("failed to create request", "error", err)
return nil, err
}
req.Header.Add("Accept", "application/json")
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Authorization", "Bearer "+ag.getToken())
req.Header.Set("Accept-Encoding", "gzip")
+
+ ag.log.Debug("agent LLM request", "url", ag.cfg.CurrentAPI, "body_preview", string(bodyBytes[:min(len(bodyBytes), 500)]))
+
resp, err := httpClient.Do(req)
if err != nil {
- ag.log.Error("llamacpp api", "error", err)
+ ag.log.Error("llamacpp api request failed", "error", err, "url", ag.cfg.CurrentAPI)
return nil, err
}
defer resp.Body.Close()
- return io.ReadAll(resp.Body)
+
+ responseBytes, err := io.ReadAll(resp.Body)
+ if err != nil {
+ ag.log.Error("failed to read response", "error", err)
+ return nil, err
+ }
+
+ if resp.StatusCode >= 400 {
+ ag.log.Error("agent LLM request failed", "status", resp.StatusCode, "response", string(responseBytes[:min(len(responseBytes), 1000)]))
+ return responseBytes, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(responseBytes[:min(len(responseBytes), 200)]))
+ }
+
+ // Parse response and extract text content
+ text, err := extractTextFromResponse(responseBytes)
+ if err != nil {
+ ag.log.Error("failed to extract text from response", "error", err, "response_preview", string(responseBytes[:min(len(responseBytes), 500)]))
+ // Return raw response as fallback
+ return responseBytes, nil
+ }
+
+ return []byte(text), nil
+}
+
+// extractTextFromResponse parses common LLM response formats and extracts the text content.
+func extractTextFromResponse(data []byte) (string, error) {
+ // Try to parse as generic JSON first
+ var genericResp map[string]interface{}
+ if err := json.Unmarshal(data, &genericResp); err != nil {
+ // Not JSON, return as string
+ return string(data), nil
+ }
+
+ // Check for OpenAI chat completion format
+ if choices, ok := genericResp["choices"].([]interface{}); ok && len(choices) > 0 {
+ if firstChoice, ok := choices[0].(map[string]interface{}); ok {
+ // Chat completion: choices[0].message.content
+ if message, ok := firstChoice["message"].(map[string]interface{}); ok {
+ if content, ok := message["content"].(string); ok {
+ return content, nil
+ }
+ }
+ // Completion: choices[0].text
+ if text, ok := firstChoice["text"].(string); ok {
+ return text, nil
+ }
+ // Delta format for streaming (should not happen with stream: false)
+ if delta, ok := firstChoice["delta"].(map[string]interface{}); ok {
+ if content, ok := delta["content"].(string); ok {
+ return content, nil
+ }
+ }
+ }
+ }
+
+ // Check for llama.cpp completion format
+ if content, ok := genericResp["content"].(string); ok {
+ return content, nil
+ }
+
+ // Unknown format, return pretty-printed JSON
+ prettyJSON, err := json.MarshalIndent(genericResp, "", " ")
+ if err != nil {
+ return string(data), nil
+ }
+ return string(prettyJSON), nil
+}
+
+func min(a, b int) int {
+ if a < b {
+ return a
+ }
+ return b
}
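
As a rough usage sketch (standalone, with illustrative endpoint URLs that are not taken from any real config), the classification that buildRequest relies on behaves like this:

package main

import (
	"fmt"
	"strings"
)

// Same string checks as detectAPI above, repeated so the sketch compiles on its own.
func detectAPI(api string) (isCompletion, isChat, isDeepSeek, isOpenRouter bool) {
	isCompletion = strings.Contains(api, "/completion") && !strings.Contains(api, "/chat/completions")
	isChat = strings.Contains(api, "/chat/completions")
	isDeepSeek = strings.Contains(api, "deepseek.com")
	isOpenRouter = strings.Contains(api, "openrouter.ai")
	return
}

func main() {
	// llama.cpp native completion endpoint: completion branch.
	fmt.Println(detectAPI("http://localhost:8080/completion")) // true false false false
	// OpenAI-compatible chat endpoint: chat branch.
	fmt.Println(detectAPI("http://localhost:8080/v1/chat/completions")) // false true false false
	// DeepSeek chat endpoint: chat branch with the DeepSeek request types.
	fmt.Println(detectAPI("https://api.deepseek.com/chat/completions")) // false true true false
}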
diff --git a/bot.go b/bot.go
index 8206c63..779278e 100644
--- a/bot.go
+++ b/bot.go
@@ -6,6 +6,7 @@ import (
"context"
"encoding/json"
"fmt"
+ "strconv"
"gf-lt/config"
"gf-lt/extra"
"gf-lt/models"
@@ -659,14 +660,75 @@ func cleanChatBody() {
}
}
+// convertJSONToMapStringString unmarshals JSON into map[string]interface{} and converts all values to strings.
+func convertJSONToMapStringString(jsonStr string) (map[string]string, error) {
+ var raw map[string]interface{}
+ if err := json.Unmarshal([]byte(jsonStr), &raw); err != nil {
+ return nil, err
+ }
+ result := make(map[string]string, len(raw))
+ for k, v := range raw {
+ switch val := v.(type) {
+ case string:
+ result[k] = val
+ case float64:
+ result[k] = strconv.FormatFloat(val, 'f', -1, 64)
+ case int, int64, int32:
+ // json.Unmarshal converts numbers to float64, but handle other integer types if they appear
+ result[k] = fmt.Sprintf("%v", val)
+ case bool:
+ result[k] = strconv.FormatBool(val)
+ case nil:
+ result[k] = ""
+ default:
+ result[k] = fmt.Sprintf("%v", val)
+ }
+ }
+ return result, nil
+}
+
+// unmarshalFuncCall unmarshals a JSON tool call, converting numeric, boolean, and null argument values to strings.
+func unmarshalFuncCall(jsonStr string) (*models.FuncCall, error) {
+ type tempFuncCall struct {
+ ID string `json:"id,omitempty"`
+ Name string `json:"name"`
+ Args map[string]interface{} `json:"args"`
+ }
+ var temp tempFuncCall
+ if err := json.Unmarshal([]byte(jsonStr), &temp); err != nil {
+ return nil, err
+ }
+ fc := &models.FuncCall{
+ ID: temp.ID,
+ Name: temp.Name,
+ Args: make(map[string]string, len(temp.Args)),
+ }
+ for k, v := range temp.Args {
+ switch val := v.(type) {
+ case string:
+ fc.Args[k] = val
+ case float64:
+ fc.Args[k] = strconv.FormatFloat(val, 'f', -1, 64)
+ case int, int64, int32:
+ fc.Args[k] = fmt.Sprintf("%v", val)
+ case bool:
+ fc.Args[k] = strconv.FormatBool(val)
+ case nil:
+ fc.Args[k] = ""
+ default:
+ fc.Args[k] = fmt.Sprintf("%v", val)
+ }
+ }
+ return fc, nil
+}
+
func findCall(msg, toolCall string, tv *tview.TextView) {
fc := &models.FuncCall{}
if toolCall != "" {
// HTML-decode the tool call string to handle encoded characters like &lt; -> <=
decodedToolCall := html.UnescapeString(toolCall)
- openAIToolMap := make(map[string]string)
- // respect tool call
- if err := json.Unmarshal([]byte(decodedToolCall), &openAIToolMap); err != nil {
+ openAIToolMap, err := convertJSONToMapStringString(decodedToolCall)
+ if err != nil {
logger.Error("failed to unmarshal openai tool call", "call", decodedToolCall, "error", err)
// Send error response to LLM so it can retry or handle the error
toolResponseMsg := models.RoleMsg{
@@ -700,7 +762,9 @@ func findCall(msg, toolCall string, tv *tview.TextView) {
jsStr = strings.TrimSuffix(strings.TrimPrefix(jsStr, prefix), suffix)
// HTML-decode the JSON string to handle encoded characters like &lt; -> <=
decodedJsStr := html.UnescapeString(jsStr)
- if err := json.Unmarshal([]byte(decodedJsStr), &fc); err != nil {
+ var err error
+ fc, err = unmarshalFuncCall(decodedJsStr)
+ if err != nil {
logger.Error("failed to unmarshal tool call", "error", err, "json_string", decodedJsStr)
// Send error response to LLM so it can retry or handle the error
toolResponseMsg := models.RoleMsg{
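
A side note on the argument coercion used in the two helpers above (and exercised by the tests that follow): encoding/json decodes every JSON number into float64, and strconv.FormatFloat with precision -1 picks the shortest representation, which is why a tool call like {"limit": 3} comes out as the string "3" rather than "3.000000". A minimal standalone sketch:

package main

import (
	"encoding/json"
	"fmt"
	"strconv"
)

func main() {
	var raw map[string]interface{}
	// JSON numbers decode as float64 even when written without a fractional part.
	if err := json.Unmarshal([]byte(`{"limit": 3, "ratio": 0.5}`), &raw); err != nil {
		panic(err)
	}
	for k, v := range raw {
		f := v.(float64)
		// Precision -1: shortest string that round-trips, so 3 -> "3" and 0.5 -> "0.5".
		fmt.Println(k, "=>", strconv.FormatFloat(f, 'f', -1, 64))
	}
}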
diff --git a/bot_test.go b/bot_test.go
index 2d59c3c..d2956a9 100644
--- a/bot_test.go
+++ b/bot_test.go
@@ -152,4 +152,138 @@ func TestConsolidateConsecutiveAssistantMessages(t *testing.T) {
}
})
}
+}
+
+func TestUnmarshalFuncCall(t *testing.T) {
+ tests := []struct {
+ name string
+ jsonStr string
+ want *models.FuncCall
+ wantErr bool
+ }{
+ {
+ name: "simple websearch with numeric limit",
+ jsonStr: `{"name": "websearch", "args": {"query": "current weather in London", "limit": 3}}`,
+ want: &models.FuncCall{
+ Name: "websearch",
+ Args: map[string]string{"query": "current weather in London", "limit": "3"},
+ },
+ wantErr: false,
+ },
+ {
+ name: "string limit",
+ jsonStr: `{"name": "websearch", "args": {"query": "test", "limit": "5"}}`,
+ want: &models.FuncCall{
+ Name: "websearch",
+ Args: map[string]string{"query": "test", "limit": "5"},
+ },
+ wantErr: false,
+ },
+ {
+ name: "boolean arg",
+ jsonStr: `{"name": "test", "args": {"flag": true}}`,
+ want: &models.FuncCall{
+ Name: "test",
+ Args: map[string]string{"flag": "true"},
+ },
+ wantErr: false,
+ },
+ {
+ name: "null arg",
+ jsonStr: `{"name": "test", "args": {"opt": null}}`,
+ want: &models.FuncCall{
+ Name: "test",
+ Args: map[string]string{"opt": ""},
+ },
+ wantErr: false,
+ },
+ {
+ name: "float arg",
+ jsonStr: `{"name": "test", "args": {"ratio": 0.5}}`,
+ want: &models.FuncCall{
+ Name: "test",
+ Args: map[string]string{"ratio": "0.5"},
+ },
+ wantErr: false,
+ },
+ {
+ name: "invalid JSON",
+ jsonStr: `{invalid}`,
+ want: nil,
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := unmarshalFuncCall(tt.jsonStr)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("unmarshalFuncCall() error = %v, wantErr %v", err, tt.wantErr)
+ return
+ }
+ if tt.wantErr {
+ return
+ }
+ if got.Name != tt.want.Name {
+ t.Errorf("unmarshalFuncCall() name = %v, want %v", got.Name, tt.want.Name)
+ }
+ if len(got.Args) != len(tt.want.Args) {
+ t.Errorf("unmarshalFuncCall() args length = %v, want %v", len(got.Args), len(tt.want.Args))
+ }
+ for k, v := range tt.want.Args {
+ if got.Args[k] != v {
+ t.Errorf("unmarshalFuncCall() args[%v] = %v, want %v", k, got.Args[k], v)
+ }
+ }
+ })
+ }
+}
+
+func TestConvertJSONToMapStringString(t *testing.T) {
+ tests := []struct {
+ name string
+ jsonStr string
+ want map[string]string
+ wantErr bool
+ }{
+ {
+ name: "simple map",
+ jsonStr: `{"query": "weather", "limit": 5}`,
+ want: map[string]string{"query": "weather", "limit": "5"},
+ wantErr: false,
+ },
+ {
+ name: "boolean and null",
+ jsonStr: `{"flag": true, "opt": null}`,
+ want: map[string]string{"flag": "true", "opt": ""},
+ wantErr: false,
+ },
+ {
+ name: "invalid JSON",
+ jsonStr: `{invalid`,
+ want: nil,
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := convertJSONToMapStringString(tt.jsonStr)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("convertJSONToMapStringString() error = %v, wantErr %v", err, tt.wantErr)
+ return
+ }
+ if tt.wantErr {
+ return
+ }
+ if len(got) != len(tt.want) {
+ t.Errorf("convertJSONToMapStringString() length = %v, want %v", len(got), len(tt.want))
+ }
+ for k, v := range tt.want {
+ if got[k] != v {
+ t.Errorf("convertJSONToMapStringString()[%v] = %v, want %v", k, got[k], v)
+ }
+ }
+ })
+ }
+}
\ No newline at end of file
diff --git a/tools.go b/tools.go
index e4af7ad..49d8192 100644
--- a/tools.go
+++ b/tools.go
@@ -13,6 +13,7 @@ import (
"regexp"
"strconv"
"strings"
+ "sync"
"time"
)
@@ -126,7 +127,9 @@ under the topic: Adam's number is stored:
</example_response>
After that you are free to respond to the user.
`
- basicCard = &models.CharCard{
+ webSearchSysPrompt = `Summarize the web search results, extracting key information and presenting a concise answer. Provide sources and URLs where relevant.`
+ readURLSysPrompt = `Extract and summarize the content from the webpage. Provide key information, main points, and any relevant details.`
+ basicCard = &models.CharCard{
SysPrompt: basicSysMsg,
FirstMsg: defaultFirstMsg,
Role: "",
@@ -141,8 +144,43 @@ After that you are free to respond to the user.
// sysMap = map[string]string{"basic_sys": basicSysMsg, "tool_sys": toolSysMsg}
sysMap = map[string]*models.CharCard{"basic_sys": basicCard}
sysLabels = []string{"basic_sys"}
+
+ webAgentClient *agent.AgentClient
+ webAgentClientOnce sync.Once
+ webAgentsOnce sync.Once
)
+// getWebAgentClient returns a singleton AgentClient for web agents.
+func getWebAgentClient() *agent.AgentClient {
+ webAgentClientOnce.Do(func() {
+ if cfg == nil {
+ panic("cfg not initialized")
+ }
+ if logger == nil {
+ panic("logger not initialized")
+ }
+ getToken := func() string {
+ if chunkParser == nil {
+ return ""
+ }
+ return chunkParser.GetToken()
+ }
+ webAgentClient = agent.NewAgentClient(cfg, *logger, getToken)
+ })
+ return webAgentClient
+}
+
+// registerWebAgents registers WebAgentB instances for websearch and read_url tools.
+func registerWebAgents() {
+ webAgentsOnce.Do(func() {
+ client := getWebAgentClient()
+ // Register websearch agent
+ agent.Register("websearch", agent.NewWebAgentB(client, webSearchSysPrompt))
+ // Register read_url agent
+ agent.Register("read_url", agent.NewWebAgentB(client, readURLSysPrompt))
+ })
+}
+
// web search (depends on extra server)
func websearch(args map[string]string) []byte {
// make http request return bytes
@@ -597,7 +635,6 @@ var globalTodoList = TodoList{
Items: []TodoItem{},
}
-
// Todo Management Tools
func todoCreate(args map[string]string) []byte {
task, ok := args["task"]
@@ -851,6 +888,7 @@ var fnMap = map[string]fnSig{
// callToolWithAgent calls the tool and applies any registered agent.
func callToolWithAgent(name string, args map[string]string) []byte {
+ registerWebAgents()
f, ok := fnMap[name]
if !ok {
return []byte(fmt.Sprintf("tool %s not found", name))
@@ -862,16 +900,6 @@ func callToolWithAgent(name string, args map[string]string) []byte {
return raw
}
-// registerDefaultAgents registers default agents for formatting.
-func registerDefaultAgents() {
- agent.Register("websearch", agent.DefaultFormatter("websearch"))
- agent.Register("read_url", agent.DefaultFormatter("read_url"))
-}
-
-func init() {
- registerDefaultAgents()
-}
-
// openai style def
var baseTools = []models.Tool{
// websearch