| field | value | date |
|---|---|---|
| author | Grail Finder <wohilas@gmail.com> | 2025-12-19 15:39:55 +0300 |
| committer | Grail Finder <wohilas@gmail.com> | 2025-12-19 15:39:55 +0300 |
| commit | f779f039745f97f08f25967214d07716ce213326 (patch) | |
| tree | 9868dfe5e77845e27d4fe8b36f993a5fade2ea7b /agent | |
| parent | a875abcf198dd2f85c518f8bf2c599db66d3e69f (diff) | |
Enha: agent request builder
Diffstat (limited to 'agent')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | agent/agent.go | 10 |
| -rw-r--r-- | agent/request.go | 192 |
2 files changed, 188 insertions, 14 deletions
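
The change below replaces the hard-coded llama.cpp chat body with a `buildRequest` helper that inspects `cfg.CurrentAPI` and emits a matching request format. As a minimal, standalone sketch of that detection step, here is the `detectAPI` logic from `agent/request.go` exercised against a few endpoint URLs (the URLs in `main` are illustrative examples, not values from the repo):

```go
package main

import (
	"fmt"
	"strings"
)

// detectAPI mirrors the helper added in agent/request.go: it classifies an
// endpoint URL so the request body can be built in the matching format.
func detectAPI(api string) (isCompletion, isChat, isDeepSeek, isOpenRouter bool) {
	isCompletion = strings.Contains(api, "/completion") && !strings.Contains(api, "/chat/completions")
	isChat = strings.Contains(api, "/chat/completions")
	isDeepSeek = strings.Contains(api, "deepseek.com")
	isOpenRouter = strings.Contains(api, "openrouter.ai")
	return
}

func main() {
	// Illustrative endpoints only; the real value comes from cfg.CurrentAPI.
	for _, api := range []string{
		"http://localhost:8080/completion",              // llama.cpp completion
		"https://api.deepseek.com/chat/completions",     // DeepSeek chat
		"https://openrouter.ai/api/v1/chat/completions", // OpenRouter chat
	} {
		completion, chat, deepseek, openrouter := detectAPI(api)
		fmt.Printf("%-48s completion=%-5v chat=%-5v deepseek=%-5v openrouter=%v\n",
			api, completion, chat, deepseek, openrouter)
	}
}
```

Plain substring matching keeps the routing dependency-free; the trade-off is that a URL containing `/chat/completions` is always treated as chat, since `isCompletion` explicitly excludes it.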
diff --git a/agent/agent.go b/agent/agent.go
index 5ad1ef1..8824ecb 100644
--- a/agent/agent.go
+++ b/agent/agent.go
@@ -33,3 +33,13 @@ func RegisterB(toolName string, a AgenterB) {
 func RegisterA(toolNames []string, a AgenterA) {
 	RegistryA[a] = toolNames
 }
+
+// Get returns the agent registered for the given tool name, or nil if none.
+func Get(toolName string) AgenterB {
+	return RegistryB[toolName]
+}
+
+// Register is a convenience wrapper for RegisterB.
+func Register(toolName string, a AgenterB) {
+	RegisterB(toolName, a)
+}
diff --git a/agent/request.go b/agent/request.go
index 2d557ac..bb4a80d 100644
--- a/agent/request.go
+++ b/agent/request.go
@@ -3,15 +3,32 @@ package agent
 import (
 	"bytes"
 	"encoding/json"
+	"fmt"
 	"gf-lt/config"
 	"gf-lt/models"
 	"io"
 	"log/slog"
 	"net/http"
+	"strings"
 )
 
 var httpClient = &http.Client{}
 
+var defaultProps = map[string]float32{
+	"temperature":    0.8,
+	"dry_multiplier": 0.0,
+	"min_p":          0.05,
+	"n_predict":      -1.0,
+}
+
+func detectAPI(api string) (isCompletion, isChat, isDeepSeek, isOpenRouter bool) {
+	isCompletion = strings.Contains(api, "/completion") && !strings.Contains(api, "/chat/completions")
+	isChat = strings.Contains(api, "/chat/completions")
+	isDeepSeek = strings.Contains(api, "deepseek.com")
+	isOpenRouter = strings.Contains(api, "openrouter.ai")
+	return
+}
+
 type AgentClient struct {
 	cfg      *config.Config
 	getToken func() string
@@ -31,38 +48,185 @@ func (ag *AgentClient) Log() *slog.Logger {
 }
 
 func (ag *AgentClient) FormMsg(sysprompt, msg string) (io.Reader, error) {
-	agentConvo := []models.RoleMsg{
+	b, err := ag.buildRequest(sysprompt, msg)
+	if err != nil {
+		return nil, err
+	}
+	return bytes.NewReader(b), nil
+}
+
+// buildRequest creates the appropriate LLM request based on the current API endpoint.
+func (ag *AgentClient) buildRequest(sysprompt, msg string) ([]byte, error) {
+	api := ag.cfg.CurrentAPI
+	model := ag.cfg.CurrentModel
+	messages := []models.RoleMsg{
 		{Role: "system", Content: sysprompt},
 		{Role: "user", Content: msg},
 	}
-	agentChat := &models.ChatBody{
-		Model:    ag.cfg.CurrentModel,
-		Stream:   true,
-		Messages: agentConvo,
+
+	// Determine API type
+	isCompletion, isChat, isDeepSeek, isOpenRouter := detectAPI(api)
+	ag.log.Debug("agent building request", "api", api, "isCompletion", isCompletion, "isChat", isChat, "isDeepSeek", isDeepSeek, "isOpenRouter", isOpenRouter)
+
+	// Build prompt for completion endpoints
+	if isCompletion {
+		var sb strings.Builder
+		for _, m := range messages {
+			sb.WriteString(m.ToPrompt())
+			sb.WriteString("\n")
+		}
+		prompt := strings.TrimSpace(sb.String())
+
+		if isDeepSeek {
+			// DeepSeek completion
+			req := models.NewDSCompletionReq(prompt, model, defaultProps["temperature"], []string{})
+			req.Stream = false // Agents don't need streaming
+			return json.Marshal(req)
+		} else if isOpenRouter {
+			// OpenRouter completion
+			req := models.NewOpenRouterCompletionReq(model, prompt, defaultProps, []string{})
+			req.Stream = false // Agents don't need streaming
+			return json.Marshal(req)
+		} else {
+			// Assume llama.cpp completion
+			req := models.NewLCPReq(prompt, model, nil, defaultProps, []string{})
+			req.Stream = false // Agents don't need streaming
+			return json.Marshal(req)
+		}
 	}
-	b, err := json.Marshal(agentChat)
-	if err != nil {
-		ag.log.Error("failed to form agent msg", "error", err)
-		return nil, err
+
+	// Chat completions endpoints
+	if isChat || !isCompletion {
+		chatBody := &models.ChatBody{
+			Model:    model,
+			Stream:   false, // Agents don't need streaming
+			Messages: messages,
+		}
+
+		if isDeepSeek {
+			// DeepSeek chat
+			req := models.NewDSChatReq(*chatBody)
+			return json.Marshal(req)
+		} else if isOpenRouter {
+			// OpenRouter chat
+			req := models.NewOpenRouterChatReq(*chatBody, defaultProps)
+			return json.Marshal(req)
+		} else {
+			// Assume llama.cpp chat (OpenAI format)
+			req := models.OpenAIReq{
+				ChatBody: chatBody,
+				Tools:    nil,
+			}
+			return json.Marshal(req)
+		}
 	}
-	return bytes.NewReader(b), nil
+
+	// Fallback (should not reach here)
+	ag.log.Warn("unknown API, using default chat completions format", "api", api)
+	chatBody := &models.ChatBody{
+		Model:    model,
+		Stream:   false, // Agents don't need streaming
+		Messages: messages,
+	}
+	return json.Marshal(chatBody)
 }
 
 func (ag *AgentClient) LLMRequest(body io.Reader) ([]byte, error) {
-	req, err := http.NewRequest("POST", ag.cfg.CurrentAPI, body)
+	// Read the body for debugging (but we need to recreate it for the request)
+	bodyBytes, err := io.ReadAll(body)
+	if err != nil {
+		ag.log.Error("failed to read request body", "error", err)
+		return nil, err
+	}
+
+	req, err := http.NewRequest("POST", ag.cfg.CurrentAPI, bytes.NewReader(bodyBytes))
 	if err != nil {
-		ag.log.Error("llamacpp api", "error", err)
+		ag.log.Error("failed to create request", "error", err)
 		return nil, err
 	}
 	req.Header.Add("Accept", "application/json")
 	req.Header.Add("Content-Type", "application/json")
 	req.Header.Add("Authorization", "Bearer "+ag.getToken())
 	req.Header.Set("Accept-Encoding", "gzip")
+
+	ag.log.Debug("agent LLM request", "url", ag.cfg.CurrentAPI, "body_preview", string(bodyBytes[:min(len(bodyBytes), 500)]))
+
 	resp, err := httpClient.Do(req)
 	if err != nil {
-		ag.log.Error("llamacpp api", "error", err)
+		ag.log.Error("llamacpp api request failed", "error", err, "url", ag.cfg.CurrentAPI)
 		return nil, err
 	}
 	defer resp.Body.Close()
-	return io.ReadAll(resp.Body)
+
+	responseBytes, err := io.ReadAll(resp.Body)
+	if err != nil {
+		ag.log.Error("failed to read response", "error", err)
+		return nil, err
+	}
+
+	if resp.StatusCode >= 400 {
+		ag.log.Error("agent LLM request failed", "status", resp.StatusCode, "response", string(responseBytes[:min(len(responseBytes), 1000)]))
+		return responseBytes, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(responseBytes[:min(len(responseBytes), 200)]))
+	}
+
+	// Parse response and extract text content
+	text, err := extractTextFromResponse(responseBytes)
+	if err != nil {
+		ag.log.Error("failed to extract text from response", "error", err, "response_preview", string(responseBytes[:min(len(responseBytes), 500)]))
+		// Return raw response as fallback
+		return responseBytes, nil
+	}
+
+	return []byte(text), nil
+}
+
+// extractTextFromResponse parses common LLM response formats and extracts the text content.
+func extractTextFromResponse(data []byte) (string, error) {
+	// Try to parse as generic JSON first
+	var genericResp map[string]interface{}
+	if err := json.Unmarshal(data, &genericResp); err != nil {
+		// Not JSON, return as string
+		return string(data), nil
+	}
+
+	// Check for OpenAI chat completion format
+	if choices, ok := genericResp["choices"].([]interface{}); ok && len(choices) > 0 {
+		if firstChoice, ok := choices[0].(map[string]interface{}); ok {
+			// Chat completion: choices[0].message.content
+			if message, ok := firstChoice["message"].(map[string]interface{}); ok {
+				if content, ok := message["content"].(string); ok {
+					return content, nil
+				}
+			}
+			// Completion: choices[0].text
+			if text, ok := firstChoice["text"].(string); ok {
+				return text, nil
+			}
+			// Delta format for streaming (should not happen with stream: false)
+			if delta, ok := firstChoice["delta"].(map[string]interface{}); ok {
+				if content, ok := delta["content"].(string); ok {
+					return content, nil
+				}
+			}
+		}
+	}
+
+	// Check for llama.cpp completion format
+	if content, ok := genericResp["content"].(string); ok {
+		return content, nil
+	}
+
+	// Unknown format, return pretty-printed JSON
+	prettyJSON, err := json.MarshalIndent(genericResp, "", "  ")
+	if err != nil {
+		return string(data), nil
+	}
+	return string(prettyJSON), nil
+}
+
+func min(a, b int) int {
+	if a < b {
+		return a
+	}
+	return b
 }
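
`LLMRequest` now extracts text from the response instead of returning raw bytes. The cascade in `extractTextFromResponse` tries the OpenAI chat shape, then completion text, then a streaming delta, then llama.cpp's bare `content` field. A self-contained sketch of the same cascade, using only the standard library and a made-up sample payload:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// extractText mirrors the cascade in extractTextFromResponse: OpenAI chat
// message content first, then completion text, then llama.cpp's bare
// "content" field, and finally the raw payload as a fallback.
func extractText(data []byte) string {
	var resp map[string]interface{}
	if err := json.Unmarshal(data, &resp); err != nil {
		return string(data) // not JSON at all
	}
	if choices, ok := resp["choices"].([]interface{}); ok && len(choices) > 0 {
		if first, ok := choices[0].(map[string]interface{}); ok {
			if msg, ok := first["message"].(map[string]interface{}); ok {
				if s, ok := msg["content"].(string); ok {
					return s // chat completion
				}
			}
			if s, ok := first["text"].(string); ok {
				return s // plain completion
			}
		}
	}
	if s, ok := resp["content"].(string); ok {
		return s // llama.cpp /completion
	}
	return string(data)
}

func main() {
	// Illustrative response body, not captured from a real endpoint.
	payload := []byte(`{"choices":[{"message":{"role":"assistant","content":"hello from the model"}}]}`)
	fmt.Println(extractText(payload))
}
```

Falling back to the raw payload when nothing matches means the caller never loses information on an unknown format, which mirrors the patch's own fallback behaviour.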

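For completeness, a hypothetical in-package usage of the new flow. The diff does not show how `AgentClient` is constructed, so the sketch below assumes it lives in package `agent`, takes an already-built client, and the helper name `runAgentTurn` is invented for illustration:

```go
// runAgentTurn is a hypothetical helper showing the new flow: FormMsg builds
// an endpoint-appropriate body via buildRequest, and LLMRequest now returns
// the extracted text rather than the raw JSON response.
func runAgentTurn(ag *AgentClient, sysprompt, userMsg string) (string, error) {
	body, err := ag.FormMsg(sysprompt, userMsg)
	if err != nil {
		return "", err
	}
	answer, err := ag.LLMRequest(body)
	if err != nil {
		return "", err
	}
	return string(answer), nil
}
```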