summaryrefslogtreecommitdiff
path: root/agent
diff options
context:
space:
mode:
Diffstat (limited to 'agent')
-rw-r--r--agent/agent.go10
-rw-r--r--agent/request.go192
2 files changed, 188 insertions, 14 deletions
diff --git a/agent/agent.go b/agent/agent.go
index 5ad1ef1..8824ecb 100644
--- a/agent/agent.go
+++ b/agent/agent.go
@@ -33,3 +33,13 @@ func RegisterB(toolName string, a AgenterB) {
func RegisterA(toolNames []string, a AgenterA) {
RegistryA[a] = toolNames
}
+
+// Get returns the agent registered for the given tool name, or nil if none.
+func Get(toolName string) AgenterB {
+ return RegistryB[toolName]
+}
+
+// Register is a convenience wrapper for RegisterB.
+func Register(toolName string, a AgenterB) {
+ RegisterB(toolName, a)
+}
diff --git a/agent/request.go b/agent/request.go
index 2d557ac..bb4a80d 100644
--- a/agent/request.go
+++ b/agent/request.go
@@ -3,15 +3,32 @@ package agent
import (
"bytes"
"encoding/json"
+ "fmt"
"gf-lt/config"
"gf-lt/models"
"io"
"log/slog"
"net/http"
+ "strings"
)
var httpClient = &http.Client{}
+var defaultProps = map[string]float32{
+ "temperature": 0.8,
+ "dry_multiplier": 0.0,
+ "min_p": 0.05,
+ "n_predict": -1.0,
+}
+
+func detectAPI(api string) (isCompletion, isChat, isDeepSeek, isOpenRouter bool) {
+ isCompletion = strings.Contains(api, "/completion") && !strings.Contains(api, "/chat/completions")
+ isChat = strings.Contains(api, "/chat/completions")
+ isDeepSeek = strings.Contains(api, "deepseek.com")
+ isOpenRouter = strings.Contains(api, "openrouter.ai")
+ return
+}
+
type AgentClient struct {
cfg *config.Config
getToken func() string
@@ -31,38 +48,185 @@ func (ag *AgentClient) Log() *slog.Logger {
}
func (ag *AgentClient) FormMsg(sysprompt, msg string) (io.Reader, error) {
- agentConvo := []models.RoleMsg{
+ b, err := ag.buildRequest(sysprompt, msg)
+ if err != nil {
+ return nil, err
+ }
+ return bytes.NewReader(b), nil
+}
+
+// buildRequest creates the appropriate LLM request based on the current API endpoint.
+func (ag *AgentClient) buildRequest(sysprompt, msg string) ([]byte, error) {
+ api := ag.cfg.CurrentAPI
+ model := ag.cfg.CurrentModel
+ messages := []models.RoleMsg{
{Role: "system", Content: sysprompt},
{Role: "user", Content: msg},
}
- agentChat := &models.ChatBody{
- Model: ag.cfg.CurrentModel,
- Stream: true,
- Messages: agentConvo,
+
+ // Determine API type
+ isCompletion, isChat, isDeepSeek, isOpenRouter := detectAPI(api)
+ ag.log.Debug("agent building request", "api", api, "isCompletion", isCompletion, "isChat", isChat, "isDeepSeek", isDeepSeek, "isOpenRouter", isOpenRouter)
+
+ // Build prompt for completion endpoints
+ if isCompletion {
+ var sb strings.Builder
+ for _, m := range messages {
+ sb.WriteString(m.ToPrompt())
+ sb.WriteString("\n")
+ }
+ prompt := strings.TrimSpace(sb.String())
+
+ if isDeepSeek {
+ // DeepSeek completion
+ req := models.NewDSCompletionReq(prompt, model, defaultProps["temperature"], []string{})
+ req.Stream = false // Agents don't need streaming
+ return json.Marshal(req)
+ } else if isOpenRouter {
+ // OpenRouter completion
+ req := models.NewOpenRouterCompletionReq(model, prompt, defaultProps, []string{})
+ req.Stream = false // Agents don't need streaming
+ return json.Marshal(req)
+ } else {
+ // Assume llama.cpp completion
+ req := models.NewLCPReq(prompt, model, nil, defaultProps, []string{})
+ req.Stream = false // Agents don't need streaming
+ return json.Marshal(req)
+ }
}
- b, err := json.Marshal(agentChat)
- if err != nil {
- ag.log.Error("failed to form agent msg", "error", err)
- return nil, err
+
+ // Chat completions endpoints
+ if isChat || !isCompletion {
+ chatBody := &models.ChatBody{
+ Model: model,
+ Stream: false, // Agents don't need streaming
+ Messages: messages,
+ }
+
+ if isDeepSeek {
+ // DeepSeek chat
+ req := models.NewDSChatReq(*chatBody)
+ return json.Marshal(req)
+ } else if isOpenRouter {
+ // OpenRouter chat
+ req := models.NewOpenRouterChatReq(*chatBody, defaultProps)
+ return json.Marshal(req)
+ } else {
+ // Assume llama.cpp chat (OpenAI format)
+ req := models.OpenAIReq{
+ ChatBody: chatBody,
+ Tools: nil,
+ }
+ return json.Marshal(req)
+ }
}
- return bytes.NewReader(b), nil
+
+ // Fallback (should not reach here)
+ ag.log.Warn("unknown API, using default chat completions format", "api", api)
+ chatBody := &models.ChatBody{
+ Model: model,
+ Stream: false, // Agents don't need streaming
+ Messages: messages,
+ }
+ return json.Marshal(chatBody)
}
func (ag *AgentClient) LLMRequest(body io.Reader) ([]byte, error) {
- req, err := http.NewRequest("POST", ag.cfg.CurrentAPI, body)
+ // Read the body for debugging (but we need to recreate it for the request)
+ bodyBytes, err := io.ReadAll(body)
+ if err != nil {
+ ag.log.Error("failed to read request body", "error", err)
+ return nil, err
+ }
+
+ req, err := http.NewRequest("POST", ag.cfg.CurrentAPI, bytes.NewReader(bodyBytes))
if err != nil {
- ag.log.Error("llamacpp api", "error", err)
+ ag.log.Error("failed to create request", "error", err)
return nil, err
}
req.Header.Add("Accept", "application/json")
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Authorization", "Bearer "+ag.getToken())
req.Header.Set("Accept-Encoding", "gzip")
+
+ ag.log.Debug("agent LLM request", "url", ag.cfg.CurrentAPI, "body_preview", string(bodyBytes[:min(len(bodyBytes), 500)]))
+
resp, err := httpClient.Do(req)
if err != nil {
- ag.log.Error("llamacpp api", "error", err)
+ ag.log.Error("llamacpp api request failed", "error", err, "url", ag.cfg.CurrentAPI)
return nil, err
}
defer resp.Body.Close()
- return io.ReadAll(resp.Body)
+
+ responseBytes, err := io.ReadAll(resp.Body)
+ if err != nil {
+ ag.log.Error("failed to read response", "error", err)
+ return nil, err
+ }
+
+ if resp.StatusCode >= 400 {
+ ag.log.Error("agent LLM request failed", "status", resp.StatusCode, "response", string(responseBytes[:min(len(responseBytes), 1000)]))
+ return responseBytes, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(responseBytes[:min(len(responseBytes), 200)]))
+ }
+
+ // Parse response and extract text content
+ text, err := extractTextFromResponse(responseBytes)
+ if err != nil {
+ ag.log.Error("failed to extract text from response", "error", err, "response_preview", string(responseBytes[:min(len(responseBytes), 500)]))
+ // Return raw response as fallback
+ return responseBytes, nil
+ }
+
+ return []byte(text), nil
+}
+
+// extractTextFromResponse parses common LLM response formats and extracts the text content.
+func extractTextFromResponse(data []byte) (string, error) {
+ // Try to parse as generic JSON first
+ var genericResp map[string]interface{}
+ if err := json.Unmarshal(data, &genericResp); err != nil {
+ // Not JSON, return as string
+ return string(data), nil
+ }
+
+ // Check for OpenAI chat completion format
+ if choices, ok := genericResp["choices"].([]interface{}); ok && len(choices) > 0 {
+ if firstChoice, ok := choices[0].(map[string]interface{}); ok {
+ // Chat completion: choices[0].message.content
+ if message, ok := firstChoice["message"].(map[string]interface{}); ok {
+ if content, ok := message["content"].(string); ok {
+ return content, nil
+ }
+ }
+ // Completion: choices[0].text
+ if text, ok := firstChoice["text"].(string); ok {
+ return text, nil
+ }
+ // Delta format for streaming (should not happen with stream: false)
+ if delta, ok := firstChoice["delta"].(map[string]interface{}); ok {
+ if content, ok := delta["content"].(string); ok {
+ return content, nil
+ }
+ }
+ }
+ }
+
+ // Check for llama.cpp completion format
+ if content, ok := genericResp["content"].(string); ok {
+ return content, nil
+ }
+
+ // Unknown format, return pretty-printed JSON
+ prettyJSON, err := json.MarshalIndent(genericResp, "", " ")
+ if err != nil {
+ return string(data), nil
+ }
+ return string(prettyJSON), nil
+}
+
+func min(a, b int) int {
+ if a < b {
+ return a
+ }
+ return b
}