Diffstat (limited to 'agent/request.go')
-rw-r--r--  agent/request.go  145
1 file changed, 72 insertions(+), 73 deletions(-)
diff --git a/agent/request.go b/agent/request.go
index f42b06e..754f16e 100644
--- a/agent/request.go
+++ b/agent/request.go
@@ -30,12 +30,16 @@ func detectAPI(api string) (isCompletion, isChat, isDeepSeek, isOpenRouter bool)
}
type AgentClient struct {
- cfg *config.Config
- getToken func() string
- log slog.Logger
+ cfg *config.Config
+ getToken func() string
+ log *slog.Logger
+ chatBody *models.ChatBody
+ sysprompt string
+ lastToolCallID string
+ tools []models.Tool
}
-func NewAgentClient(cfg *config.Config, log slog.Logger, gt func() string) *AgentClient {
+func NewAgentClient(cfg *config.Config, log *slog.Logger, gt func() string) *AgentClient {
return &AgentClient{
cfg: cfg,
getToken: gt,
@@ -44,93 +48,99 @@ func NewAgentClient(cfg *config.Config, log slog.Logger, gt func() string) *Agen
}
func (ag *AgentClient) Log() *slog.Logger {
- return &ag.log
+ return ag.log
}
-func (ag *AgentClient) FormMsg(sysprompt, msg string) (io.Reader, error) {
- b, err := ag.buildRequest(sysprompt, msg)
+func (ag *AgentClient) FormFirstMsg(sysprompt, msg string) (io.Reader, error) {
+ ag.sysprompt = sysprompt
+ ag.chatBody = &models.ChatBody{
+ Messages: []models.RoleMsg{
+ {Role: "system", Content: ag.sysprompt},
+ {Role: "user", Content: msg},
+ },
+ Stream: false,
+ Model: ag.cfg.CurrentModel,
+ }
+ b, err := ag.buildRequest()
if err != nil {
return nil, err
}
return bytes.NewReader(b), nil
}
-// buildRequest creates the appropriate LLM request based on the current API endpoint.
-func (ag *AgentClient) buildRequest(sysprompt, msg string) ([]byte, error) {
- api := ag.cfg.CurrentAPI
- model := ag.cfg.CurrentModel
- messages := []models.RoleMsg{
- {Role: "system", Content: sysprompt},
- {Role: "user", Content: msg},
+func (ag *AgentClient) FormMsg(msg string) (io.Reader, error) {
+ m := models.RoleMsg{
+ Role: "tool", Content: msg,
}
+ ag.chatBody.Messages = append(ag.chatBody.Messages, m)
+ b, err := ag.buildRequest()
+ if err != nil {
+ return nil, err
+ }
+ return bytes.NewReader(b), nil
+}
- // Determine API type
- isCompletion, isChat, isDeepSeek, isOpenRouter := detectAPI(api)
- ag.log.Debug("agent building request", "api", api, "isCompletion", isCompletion, "isChat", isChat, "isDeepSeek", isDeepSeek, "isOpenRouter", isOpenRouter)
+func (ag *AgentClient) FormMsgWithToolCallID(msg, toolCallID string) (io.Reader, error) {
+ m := models.RoleMsg{
+ Role: "tool",
+ Content: msg,
+ ToolCallID: toolCallID,
+ }
+ ag.chatBody.Messages = append(ag.chatBody.Messages, m)
+ b, err := ag.buildRequest()
+ if err != nil {
+ return nil, err
+ }
+ return bytes.NewReader(b), nil
+}
+// buildRequest creates the appropriate LLM request based on the current API endpoint.
+func (ag *AgentClient) buildRequest() ([]byte, error) {
+ isCompletion, isChat, isDeepSeek, isOpenRouter := detectAPI(ag.cfg.CurrentAPI)
+ ag.log.Debug("agent building request", "api", ag.cfg.CurrentAPI, "isCompletion", isCompletion, "isChat", isChat, "isDeepSeek", isDeepSeek, "isOpenRouter", isOpenRouter)
// Build prompt for completion endpoints
if isCompletion {
var sb strings.Builder
- for i := range messages {
- sb.WriteString(messages[i].ToPrompt())
+ for i := range ag.chatBody.Messages {
+ sb.WriteString(ag.chatBody.Messages[i].ToPrompt())
sb.WriteString("\n")
}
prompt := strings.TrimSpace(sb.String())
-
switch {
case isDeepSeek:
// DeepSeek completion
- req := models.NewDSCompletionReq(prompt, model, defaultProps["temperature"], []string{})
+ req := models.NewDSCompletionReq(prompt, ag.chatBody.Model, defaultProps["temperature"], []string{})
req.Stream = false // Agents don't need streaming
return json.Marshal(req)
case isOpenRouter:
// OpenRouter completion
- req := models.NewOpenRouterCompletionReq(model, prompt, defaultProps, []string{})
+ req := models.NewOpenRouterCompletionReq(ag.chatBody.Model, prompt, defaultProps, []string{})
req.Stream = false // Agents don't need streaming
return json.Marshal(req)
default:
// Assume llama.cpp completion
- req := models.NewLCPReq(prompt, model, nil, defaultProps, []string{})
+ req := models.NewLCPReq(prompt, ag.chatBody.Model, nil, defaultProps, []string{})
req.Stream = false // Agents don't need streaming
return json.Marshal(req)
}
}
-
- // Chat completions endpoints
- if isChat || !isCompletion {
- chatBody := &models.ChatBody{
- Model: model,
- Stream: false, // Agents don't need streaming
- Messages: messages,
- }
-
- switch {
- case isDeepSeek:
- // DeepSeek chat
- req := models.NewDSChatReq(*chatBody)
- return json.Marshal(req)
- case isOpenRouter:
- // OpenRouter chat - agents don't use reasoning by default
- req := models.NewOpenRouterChatReq(*chatBody, defaultProps, "")
- return json.Marshal(req)
- default:
- // Assume llama.cpp chat (OpenAI format)
- req := models.OpenAIReq{
- ChatBody: chatBody,
- Tools: nil,
- }
- return json.Marshal(req)
+ switch {
+ case isDeepSeek:
+ // DeepSeek chat
+ req := models.NewDSChatReq(*ag.chatBody)
+ return json.Marshal(req)
+ case isOpenRouter:
+ // OpenRouter chat - reasoning effort now comes from config
+ req := models.NewOpenRouterChatReq(*ag.chatBody, defaultProps, ag.cfg.ReasoningEffort)
+ return json.Marshal(req)
+ default:
+ // Assume llama.cpp chat (OpenAI format)
+ req := models.OpenAIReq{
+ ChatBody: ag.chatBody,
+ Tools: ag.tools,
}
+ return json.Marshal(req)
}
-
- // Fallback (should not reach here)
- ag.log.Warn("unknown API, using default chat completions format", "api", api)
- chatBody := &models.ChatBody{
- Model: model,
- Stream: false, // Agents don't need streaming
- Messages: messages,
- }
- return json.Marshal(chatBody)
}
func (ag *AgentClient) LLMRequest(body io.Reader) ([]byte, error) {
@@ -165,7 +175,6 @@ func (ag *AgentClient) LLMRequest(body io.Reader) ([]byte, error) {
ag.log.Error("agent LLM request failed", "status", resp.StatusCode, "response", string(responseBytes[:min(len(responseBytes), 1000)]))
return responseBytes, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(responseBytes[:min(len(responseBytes), 200)]))
}
-
// Parse response and extract text content
text, err := extractTextFromResponse(responseBytes)
if err != nil {
@@ -179,17 +188,16 @@ func (ag *AgentClient) LLMRequest(body io.Reader) ([]byte, error) {
// extractTextFromResponse parses common LLM response formats and extracts the text content.
func extractTextFromResponse(data []byte) (string, error) {
// Try to parse as generic JSON first
- var genericResp map[string]interface{}
+ var genericResp map[string]any
if err := json.Unmarshal(data, &genericResp); err != nil {
// Not JSON, return as string
return string(data), nil
}
-
// Check for OpenAI chat completion format
- if choices, ok := genericResp["choices"].([]interface{}); ok && len(choices) > 0 {
- if firstChoice, ok := choices[0].(map[string]interface{}); ok {
+ if choices, ok := genericResp["choices"].([]any); ok && len(choices) > 0 {
+ if firstChoice, ok := choices[0].(map[string]any); ok {
// Chat completion: choices[0].message.content
- if message, ok := firstChoice["message"].(map[string]interface{}); ok {
+ if message, ok := firstChoice["message"].(map[string]any); ok {
if content, ok := message["content"].(string); ok {
return content, nil
}
@@ -199,19 +207,17 @@ func extractTextFromResponse(data []byte) (string, error) {
return text, nil
}
// Delta format for streaming (should not happen with stream: false)
- if delta, ok := firstChoice["delta"].(map[string]interface{}); ok {
+ if delta, ok := firstChoice["delta"].(map[string]any); ok {
if content, ok := delta["content"].(string); ok {
return content, nil
}
}
}
}
-
// Check for llama.cpp completion format
if content, ok := genericResp["content"].(string); ok {
return content, nil
}
-
// Unknown format, return pretty-printed JSON
prettyJSON, err := json.MarshalIndent(genericResp, "", " ")
if err != nil {
@@ -219,10 +225,3 @@ func extractTextFromResponse(data []byte) (string, error) {
}
return string(prettyJSON), nil
}
-
-func min(a, b int) int {
- if a < b {
- return a
- }
- return b
-}
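
Below is a minimal sketch of how an agent loop might drive the reworked stateful API after this change: FormFirstMsg seeds the client-owned chat body with the system prompt, and FormMsgWithToolCallID appends tool results to the accumulated history before each follow-up request. The import paths, the config literal values, the token source, and the tool call ID are assumptions for illustration; only agent/request.go is visible in this diff.

package main

import (
	"fmt"
	"log/slog"
	"os"

	// Hypothetical module paths; the repository layout outside
	// agent/request.go is not shown in this diff.
	"example.com/gf-lt/agent"
	"example.com/gf-lt/config"
)

func main() {
	// CurrentAPI and CurrentModel are the config fields referenced by
	// buildRequest; the values here are made up for the sketch.
	cfg := &config.Config{
		CurrentAPI:   "http://localhost:8080/v1/chat/completions",
		CurrentModel: "local-model",
	}
	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
	ag := agent.NewAgentClient(cfg, logger, func() string {
		return os.Getenv("LLM_TOKEN") // assumed token source
	})

	// First turn: seeds chatBody with the system prompt and user message.
	body, err := ag.FormFirstMsg("You are a coding agent.", "List the repo files.")
	if err != nil {
		logger.Error("form first msg", "error", err)
		return
	}
	resp, err := ag.LLMRequest(body)
	if err != nil {
		logger.Error("llm request", "error", err)
		return
	}
	fmt.Println(string(resp))

	// Follow-up turn: append a tool result to the accumulated history.
	// "call_123" stands in for a tool call ID parsed from the response.
	body, err = ag.FormMsgWithToolCallID("main.go\nagent/request.go", "call_123")
	if err != nil {
		logger.Error("form tool msg", "error", err)
		return
	}
	if resp, err = ag.LLMRequest(body); err == nil {
		fmt.Println(string(resp))
	}
}

Because the client now owns chatBody, every call to buildRequest serializes the full conversation, so multi-turn tool calling works without the caller re-sending the system prompt or prior messages on each turn.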