Diffstat (limited to 'agent/request.go')
| -rw-r--r-- | agent/request.go | 145 |
1 file changed, 72 insertions, 73 deletions
diff --git a/agent/request.go b/agent/request.go
index f42b06e..754f16e 100644
--- a/agent/request.go
+++ b/agent/request.go
@@ -30,12 +30,16 @@ func detectAPI(api string) (isCompletion, isChat, isDeepSeek, isOpenRouter bool)
 }
 
 type AgentClient struct {
-	cfg      *config.Config
-	getToken func() string
-	log      slog.Logger
+	cfg            *config.Config
+	getToken       func() string
+	log            *slog.Logger
+	chatBody       *models.ChatBody
+	sysprompt      string
+	lastToolCallID string
+	tools          []models.Tool
 }
 
-func NewAgentClient(cfg *config.Config, log slog.Logger, gt func() string) *AgentClient {
+func NewAgentClient(cfg *config.Config, log *slog.Logger, gt func() string) *AgentClient {
 	return &AgentClient{
 		cfg:      cfg,
 		getToken: gt,
@@ -44,93 +48,99 @@ func NewAgentClient(cfg *config.Config, log slog.Logger, gt func() string) *Agen
 }
 
 func (ag *AgentClient) Log() *slog.Logger {
-	return &ag.log
+	return ag.log
 }
 
-func (ag *AgentClient) FormMsg(sysprompt, msg string) (io.Reader, error) {
-	b, err := ag.buildRequest(sysprompt, msg)
+func (ag *AgentClient) FormFirstMsg(sysprompt, msg string) (io.Reader, error) {
+	ag.sysprompt = sysprompt
+	ag.chatBody = &models.ChatBody{
+		Messages: []models.RoleMsg{
+			{Role: "system", Content: ag.sysprompt},
+			{Role: "user", Content: msg},
+		},
+		Stream: false,
+		Model:  ag.cfg.CurrentModel,
+	}
+	b, err := ag.buildRequest()
 	if err != nil {
 		return nil, err
 	}
 	return bytes.NewReader(b), nil
 }
 
-// buildRequest creates the appropriate LLM request based on the current API endpoint.
-func (ag *AgentClient) buildRequest(sysprompt, msg string) ([]byte, error) {
-	api := ag.cfg.CurrentAPI
-	model := ag.cfg.CurrentModel
-	messages := []models.RoleMsg{
-		{Role: "system", Content: sysprompt},
-		{Role: "user", Content: msg},
+func (ag *AgentClient) FormMsg(msg string) (io.Reader, error) {
+	m := models.RoleMsg{
+		Role: "tool", Content: msg,
 	}
+	ag.chatBody.Messages = append(ag.chatBody.Messages, m)
+	b, err := ag.buildRequest()
+	if err != nil {
+		return nil, err
+	}
+	return bytes.NewReader(b), nil
+}
 
-	// Determine API type
-	isCompletion, isChat, isDeepSeek, isOpenRouter := detectAPI(api)
-	ag.log.Debug("agent building request", "api", api, "isCompletion", isCompletion, "isChat", isChat, "isDeepSeek", isDeepSeek, "isOpenRouter", isOpenRouter)
+func (ag *AgentClient) FormMsgWithToolCallID(msg, toolCallID string) (io.Reader, error) {
+	m := models.RoleMsg{
+		Role:       "tool",
+		Content:    msg,
+		ToolCallID: toolCallID,
+	}
+	ag.chatBody.Messages = append(ag.chatBody.Messages, m)
+	b, err := ag.buildRequest()
+	if err != nil {
+		return nil, err
+	}
+	return bytes.NewReader(b), nil
+}
 
+// buildRequest creates the appropriate LLM request based on the current API endpoint.
+func (ag *AgentClient) buildRequest() ([]byte, error) {
+	isCompletion, isChat, isDeepSeek, isOpenRouter := detectAPI(ag.cfg.CurrentAPI)
+	ag.log.Debug("agent building request", "api", ag.cfg.CurrentAPI, "isCompletion", isCompletion, "isChat", isChat, "isDeepSeek", isDeepSeek, "isOpenRouter", isOpenRouter)
 	// Build prompt for completion endpoints
 	if isCompletion {
 		var sb strings.Builder
-		for i := range messages {
-			sb.WriteString(messages[i].ToPrompt())
+		for i := range ag.chatBody.Messages {
+			sb.WriteString(ag.chatBody.Messages[i].ToPrompt())
 			sb.WriteString("\n")
 		}
 		prompt := strings.TrimSpace(sb.String())
-
 		switch {
 		case isDeepSeek:
 			// DeepSeek completion
-			req := models.NewDSCompletionReq(prompt, model, defaultProps["temperature"], []string{})
+			req := models.NewDSCompletionReq(prompt, ag.chatBody.Model, defaultProps["temperature"], []string{})
 			req.Stream = false // Agents don't need streaming
 			return json.Marshal(req)
 		case isOpenRouter:
 			// OpenRouter completion
-			req := models.NewOpenRouterCompletionReq(model, prompt, defaultProps, []string{})
+			req := models.NewOpenRouterCompletionReq(ag.chatBody.Model, prompt, defaultProps, []string{})
 			req.Stream = false // Agents don't need streaming
 			return json.Marshal(req)
 		default:
 			// Assume llama.cpp completion
-			req := models.NewLCPReq(prompt, model, nil, defaultProps, []string{})
+			req := models.NewLCPReq(prompt, ag.chatBody.Model, nil, defaultProps, []string{})
 			req.Stream = false // Agents don't need streaming
 			return json.Marshal(req)
 		}
 	}
-
-	// Chat completions endpoints
-	if isChat || !isCompletion {
-		chatBody := &models.ChatBody{
-			Model:    model,
-			Stream:   false, // Agents don't need streaming
-			Messages: messages,
-		}
-
-		switch {
-		case isDeepSeek:
-			// DeepSeek chat
-			req := models.NewDSChatReq(*chatBody)
-			return json.Marshal(req)
-		case isOpenRouter:
-			// OpenRouter chat - agents don't use reasoning by default
-			req := models.NewOpenRouterChatReq(*chatBody, defaultProps, "")
-			return json.Marshal(req)
-		default:
-			// Assume llama.cpp chat (OpenAI format)
-			req := models.OpenAIReq{
-				ChatBody: chatBody,
-				Tools:    nil,
-			}
-			return json.Marshal(req)
+	switch {
+	case isDeepSeek:
+		// DeepSeek chat
+		req := models.NewDSChatReq(*ag.chatBody)
+		return json.Marshal(req)
+	case isOpenRouter:
+		// OpenRouter chat - agents don't use reasoning by default
+		req := models.NewOpenRouterChatReq(*ag.chatBody, defaultProps, ag.cfg.ReasoningEffort)
+		return json.Marshal(req)
+	default:
+		// Assume llama.cpp chat (OpenAI format)
+		req := models.OpenAIReq{
+			ChatBody: ag.chatBody,
+			Tools:    ag.tools,
 		}
+		return json.Marshal(req)
 	}
-
-	// Fallback (should not reach here)
-	ag.log.Warn("unknown API, using default chat completions format", "api", api)
-	chatBody := &models.ChatBody{
-		Model:    model,
-		Stream:   false, // Agents don't need streaming
-		Messages: messages,
-	}
-	return json.Marshal(chatBody)
 }
 
 func (ag *AgentClient) LLMRequest(body io.Reader) ([]byte, error) {
@@ -165,7 +175,6 @@ func (ag *AgentClient) LLMRequest(body io.Reader) ([]byte, error) {
 		ag.log.Error("agent LLM request failed", "status", resp.StatusCode, "response", string(responseBytes[:min(len(responseBytes), 1000)]))
 		return responseBytes, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(responseBytes[:min(len(responseBytes), 200)]))
 	}
-
 	// Parse response and extract text content
 	text, err := extractTextFromResponse(responseBytes)
 	if err != nil {
@@ -179,17 +188,16 @@ func (ag *AgentClient) LLMRequest(body io.Reader) ([]byte, error) {
 
 // extractTextFromResponse parses common LLM response formats and extracts the text content.
 func extractTextFromResponse(data []byte) (string, error) {
 	// Try to parse as generic JSON first
-	var genericResp map[string]interface{}
+	var genericResp map[string]any
 	if err := json.Unmarshal(data, &genericResp); err != nil {
 		// Not JSON, return as string
 		return string(data), nil
 	}
-
 	// Check for OpenAI chat completion format
-	if choices, ok := genericResp["choices"].([]interface{}); ok && len(choices) > 0 {
-		if firstChoice, ok := choices[0].(map[string]interface{}); ok {
+	if choices, ok := genericResp["choices"].([]any); ok && len(choices) > 0 {
+		if firstChoice, ok := choices[0].(map[string]any); ok {
 			// Chat completion: choices[0].message.content
-			if message, ok := firstChoice["message"].(map[string]interface{}); ok {
+			if message, ok := firstChoice["message"].(map[string]any); ok {
 				if content, ok := message["content"].(string); ok {
 					return content, nil
 				}
@@ -199,19 +207,17 @@ func extractTextFromResponse(data []byte) (string, error) {
 				return text, nil
 			}
 			// Delta format for streaming (should not happen with stream: false)
-			if delta, ok := firstChoice["delta"].(map[string]interface{}); ok {
+			if delta, ok := firstChoice["delta"].(map[string]any); ok {
 				if content, ok := delta["content"].(string); ok {
 					return content, nil
 				}
			}
		}
	}
-
 	// Check for llama.cpp completion format
 	if content, ok := genericResp["content"].(string); ok {
 		return content, nil
 	}
-
 	// Unknown format, return pretty-printed JSON
 	prettyJSON, err := json.MarshalIndent(genericResp, "", "  ")
 	if err != nil {
@@ -219,10 +225,3 @@ func extractTextFromResponse(data []byte) (string, error) {
 	}
 	return string(prettyJSON), nil
 }
-
-func min(a, b int) int {
-	if a < b {
-		return a
-	}
-	return b
-}
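
The change above moves the conversation state (chatBody, sysprompt, tools) into AgentClient, so a caller seeds the request once with FormFirstMsg and later turns only append messages before resending. A minimal sketch of that flow, assuming it lives in the same agent package: runOneToolTurn, its parameters, and the tool-call ID value are illustrative and not part of the diff, only the AgentClient methods are.

// runOneToolTurn is a hypothetical driver for the stateful client: the first
// turn builds chatBody from the system prompt and user message, the second
// turn appends a tool result to the same conversation and resends it.
func runOneToolTurn(ag *AgentClient, sysprompt, userMsg, toolResult, toolCallID string) ([]byte, error) {
	// First turn: system prompt + user message become the initial chatBody.
	body, err := ag.FormFirstMsg(sysprompt, userMsg)
	if err != nil {
		return nil, err
	}
	if _, err := ag.LLMRequest(body); err != nil {
		return nil, err
	}
	// Follow-up turn: append the tool output (with its call ID) and resend.
	body, err = ag.FormMsgWithToolCallID(toolResult, toolCallID)
	if err != nil {
		return nil, err
	}
	return ag.LLMRequest(body)
}

Similarly, a hypothetical check of the two response shapes extractTextFromResponse is written to handle (OpenAI-style choices and llama.cpp's bare content field); the JSON payloads are invented for illustration.

// extractorHandlesBothShapes exercises the extractor with made-up payloads
// for the two formats named in the code above; both should yield "hi".
func extractorHandlesBothShapes() bool {
	chat := []byte(`{"choices":[{"message":{"content":"hi"}}]}`) // OpenAI-style chat completion
	lcp := []byte(`{"content":"hi"}`)                            // llama.cpp completion format
	a, err1 := extractTextFromResponse(chat)
	b, err2 := extractTextFromResponse(lcp)
	return err1 == nil && err2 == nil && a == "hi" && b == "hi"
}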
