diff options
| -rw-r--r-- | .gitignore | 1 | ||||
| -rw-r--r-- | agent/agent.go | 45 | ||||
| -rw-r--r-- | agent/request.go | 232 | ||||
| -rw-r--r-- | agent/webagent.go | 32 | ||||
| -rw-r--r-- | bot.go | 378 | ||||
| -rw-r--r-- | bot_test.go | 134 | ||||
| -rw-r--r-- | config/config.go | 2 | ||||
| -rw-r--r-- | helpfuncs.go | 24 | ||||
| -rw-r--r-- | llm.go | 8 | ||||
| -rw-r--r-- | main.go | 10 | ||||
| -rw-r--r-- | props_table.go | 109 | ||||
| -rw-r--r-- | tables.go | 65 | ||||
| -rw-r--r-- | tools.go | 56 | ||||
| -rw-r--r-- | tui.go | 99 |
14 files changed, 1029 insertions, 166 deletions
@@ -13,3 +13,4 @@ gf-lt gflt chat_exports/*.json ragimport +.env diff --git a/agent/agent.go b/agent/agent.go new file mode 100644 index 0000000..8824ecb --- /dev/null +++ b/agent/agent.go @@ -0,0 +1,45 @@ +package agent + +// I see two types of agents possible: +// ones who do their own tools calls +// ones that works only with the output + +// A: main chat -> agent (handles everything: tool + processing) +// B: main chat -> tool -> agent (process tool output) + +// AgenterA gets a task "find out weather in london" +// proceeds to make tool calls on its own +type AgenterA interface { + ProcessTask(task string) []byte +} + +// AgenterB defines an interface for processing tool outputs +type AgenterB interface { + // Process takes the original tool arguments and the raw output from the tool, + // and returns a cleaned/summarized version suitable for the main LLM context + Process(args map[string]string, rawOutput []byte) []byte +} + +// registry holds mapping from tool names to agents +var RegistryB = make(map[string]AgenterB) +var RegistryA = make(map[AgenterA][]string) + +// Register adds an agent for a specific tool name +// If an agent already exists for the tool, it will be replaced +func RegisterB(toolName string, a AgenterB) { + RegistryB[toolName] = a +} + +func RegisterA(toolNames []string, a AgenterA) { + RegistryA[a] = toolNames +} + +// Get returns the agent registered for the given tool name, or nil if none. +func Get(toolName string) AgenterB { + return RegistryB[toolName] +} + +// Register is a convenience wrapper for RegisterB. +func Register(toolName string, a AgenterB) { + RegisterB(toolName, a) +} diff --git a/agent/request.go b/agent/request.go new file mode 100644 index 0000000..bb4a80d --- /dev/null +++ b/agent/request.go @@ -0,0 +1,232 @@ +package agent + +import ( + "bytes" + "encoding/json" + "fmt" + "gf-lt/config" + "gf-lt/models" + "io" + "log/slog" + "net/http" + "strings" +) + +var httpClient = &http.Client{} + +var defaultProps = map[string]float32{ + "temperature": 0.8, + "dry_multiplier": 0.0, + "min_p": 0.05, + "n_predict": -1.0, +} + +func detectAPI(api string) (isCompletion, isChat, isDeepSeek, isOpenRouter bool) { + isCompletion = strings.Contains(api, "/completion") && !strings.Contains(api, "/chat/completions") + isChat = strings.Contains(api, "/chat/completions") + isDeepSeek = strings.Contains(api, "deepseek.com") + isOpenRouter = strings.Contains(api, "openrouter.ai") + return +} + +type AgentClient struct { + cfg *config.Config + getToken func() string + log slog.Logger +} + +func NewAgentClient(cfg *config.Config, log slog.Logger, gt func() string) *AgentClient { + return &AgentClient{ + cfg: cfg, + getToken: gt, + log: log, + } +} + +func (ag *AgentClient) Log() *slog.Logger { + return &ag.log +} + +func (ag *AgentClient) FormMsg(sysprompt, msg string) (io.Reader, error) { + b, err := ag.buildRequest(sysprompt, msg) + if err != nil { + return nil, err + } + return bytes.NewReader(b), nil +} + +// buildRequest creates the appropriate LLM request based on the current API endpoint. +func (ag *AgentClient) buildRequest(sysprompt, msg string) ([]byte, error) { + api := ag.cfg.CurrentAPI + model := ag.cfg.CurrentModel + messages := []models.RoleMsg{ + {Role: "system", Content: sysprompt}, + {Role: "user", Content: msg}, + } + + // Determine API type + isCompletion, isChat, isDeepSeek, isOpenRouter := detectAPI(api) + ag.log.Debug("agent building request", "api", api, "isCompletion", isCompletion, "isChat", isChat, "isDeepSeek", isDeepSeek, "isOpenRouter", isOpenRouter) + + // Build prompt for completion endpoints + if isCompletion { + var sb strings.Builder + for _, m := range messages { + sb.WriteString(m.ToPrompt()) + sb.WriteString("\n") + } + prompt := strings.TrimSpace(sb.String()) + + if isDeepSeek { + // DeepSeek completion + req := models.NewDSCompletionReq(prompt, model, defaultProps["temperature"], []string{}) + req.Stream = false // Agents don't need streaming + return json.Marshal(req) + } else if isOpenRouter { + // OpenRouter completion + req := models.NewOpenRouterCompletionReq(model, prompt, defaultProps, []string{}) + req.Stream = false // Agents don't need streaming + return json.Marshal(req) + } else { + // Assume llama.cpp completion + req := models.NewLCPReq(prompt, model, nil, defaultProps, []string{}) + req.Stream = false // Agents don't need streaming + return json.Marshal(req) + } + } + + // Chat completions endpoints + if isChat || !isCompletion { + chatBody := &models.ChatBody{ + Model: model, + Stream: false, // Agents don't need streaming + Messages: messages, + } + + if isDeepSeek { + // DeepSeek chat + req := models.NewDSChatReq(*chatBody) + return json.Marshal(req) + } else if isOpenRouter { + // OpenRouter chat + req := models.NewOpenRouterChatReq(*chatBody, defaultProps) + return json.Marshal(req) + } else { + // Assume llama.cpp chat (OpenAI format) + req := models.OpenAIReq{ + ChatBody: chatBody, + Tools: nil, + } + return json.Marshal(req) + } + } + + // Fallback (should not reach here) + ag.log.Warn("unknown API, using default chat completions format", "api", api) + chatBody := &models.ChatBody{ + Model: model, + Stream: false, // Agents don't need streaming + Messages: messages, + } + return json.Marshal(chatBody) +} + +func (ag *AgentClient) LLMRequest(body io.Reader) ([]byte, error) { + // Read the body for debugging (but we need to recreate it for the request) + bodyBytes, err := io.ReadAll(body) + if err != nil { + ag.log.Error("failed to read request body", "error", err) + return nil, err + } + + req, err := http.NewRequest("POST", ag.cfg.CurrentAPI, bytes.NewReader(bodyBytes)) + if err != nil { + ag.log.Error("failed to create request", "error", err) + return nil, err + } + req.Header.Add("Accept", "application/json") + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Authorization", "Bearer "+ag.getToken()) + req.Header.Set("Accept-Encoding", "gzip") + + ag.log.Debug("agent LLM request", "url", ag.cfg.CurrentAPI, "body_preview", string(bodyBytes[:min(len(bodyBytes), 500)])) + + resp, err := httpClient.Do(req) + if err != nil { + ag.log.Error("llamacpp api request failed", "error", err, "url", ag.cfg.CurrentAPI) + return nil, err + } + defer resp.Body.Close() + + responseBytes, err := io.ReadAll(resp.Body) + if err != nil { + ag.log.Error("failed to read response", "error", err) + return nil, err + } + + if resp.StatusCode >= 400 { + ag.log.Error("agent LLM request failed", "status", resp.StatusCode, "response", string(responseBytes[:min(len(responseBytes), 1000)])) + return responseBytes, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(responseBytes[:min(len(responseBytes), 200)])) + } + + // Parse response and extract text content + text, err := extractTextFromResponse(responseBytes) + if err != nil { + ag.log.Error("failed to extract text from response", "error", err, "response_preview", string(responseBytes[:min(len(responseBytes), 500)])) + // Return raw response as fallback + return responseBytes, nil + } + + return []byte(text), nil +} + +// extractTextFromResponse parses common LLM response formats and extracts the text content. +func extractTextFromResponse(data []byte) (string, error) { + // Try to parse as generic JSON first + var genericResp map[string]interface{} + if err := json.Unmarshal(data, &genericResp); err != nil { + // Not JSON, return as string + return string(data), nil + } + + // Check for OpenAI chat completion format + if choices, ok := genericResp["choices"].([]interface{}); ok && len(choices) > 0 { + if firstChoice, ok := choices[0].(map[string]interface{}); ok { + // Chat completion: choices[0].message.content + if message, ok := firstChoice["message"].(map[string]interface{}); ok { + if content, ok := message["content"].(string); ok { + return content, nil + } + } + // Completion: choices[0].text + if text, ok := firstChoice["text"].(string); ok { + return text, nil + } + // Delta format for streaming (should not happen with stream: false) + if delta, ok := firstChoice["delta"].(map[string]interface{}); ok { + if content, ok := delta["content"].(string); ok { + return content, nil + } + } + } + } + + // Check for llama.cpp completion format + if content, ok := genericResp["content"].(string); ok { + return content, nil + } + + // Unknown format, return pretty-printed JSON + prettyJSON, err := json.MarshalIndent(genericResp, "", " ") + if err != nil { + return string(data), nil + } + return string(prettyJSON), nil +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/agent/webagent.go b/agent/webagent.go new file mode 100644 index 0000000..ff6cd86 --- /dev/null +++ b/agent/webagent.go @@ -0,0 +1,32 @@ +package agent + +import ( + "fmt" +) + +// WebAgentB is a simple agent that applies formatting functions +type WebAgentB struct { + *AgentClient + sysprompt string +} + +// NewWebAgentB creates a WebAgentB that uses the given formatting function +func NewWebAgentB(client *AgentClient, sysprompt string) *WebAgentB { + return &WebAgentB{AgentClient: client, sysprompt: sysprompt} +} + +// Process applies the formatting function to raw output +func (a *WebAgentB) Process(args map[string]string, rawOutput []byte) []byte { + msg, err := a.FormMsg(a.sysprompt, + fmt.Sprintf("request:\n%+v\ntool response:\n%v", args, string(rawOutput))) + if err != nil { + a.Log().Error("failed to process the request", "error", err) + return []byte("failed to process the request; err: " + err.Error()) + } + resp, err := a.LLMRequest(msg) + if err != nil { + a.Log().Error("failed to process the request", "error", err) + return []byte("failed to process the request; err: " + err.Error()) + } + return resp +} @@ -16,9 +16,12 @@ import ( "log/slog" "net" "net/http" + "net/url" "os" "path" + "strconv" "strings" + "sync" "time" "github.com/neurosnap/sentences/english" @@ -47,10 +50,10 @@ var ( ragger *rag.RAG chunkParser ChunkParser lastToolCall *models.FuncCall - lastToolCallID string // Store the ID of the most recent tool call //nolint:unused // TTS_ENABLED conditionally uses this orator extra.Orator asr extra.STT + localModelsMu sync.RWMutex defaultLCPProps = map[string]float32{ "temperature": 0.8, "dry_multiplier": 0.0, @@ -84,19 +87,31 @@ func cleanNullMessages(messages []models.RoleMsg) []models.RoleMsg { return consolidateConsecutiveAssistantMessages(messages) } +func cleanToolCalls(messages []models.RoleMsg) []models.RoleMsg { + cleaned := make([]models.RoleMsg, 0, len(messages)) + for i, msg := range messages { + // recognize the message as the tool call and remove it + if msg.ToolCallID == "" { + cleaned = append(cleaned, msg) + } + // tool call in last msg should stay + if i == len(messages)-1 { + cleaned = append(cleaned, msg) + } + } + return consolidateConsecutiveAssistantMessages(cleaned) +} + // consolidateConsecutiveAssistantMessages merges consecutive assistant messages into a single message func consolidateConsecutiveAssistantMessages(messages []models.RoleMsg) []models.RoleMsg { if len(messages) == 0 { return messages } - consolidated := make([]models.RoleMsg, 0, len(messages)) currentAssistantMsg := models.RoleMsg{} isBuildingAssistantMsg := false - for i := 0; i < len(messages); i++ { msg := messages[i] - if msg.Role == cfg.AssistantRole || msg.Role == cfg.WriteNextMsgAsCompletionAgent { // If this is an assistant message, start or continue building if !isBuildingAssistantMsg { @@ -141,12 +156,10 @@ func consolidateConsecutiveAssistantMessages(messages []models.RoleMsg) []models consolidated = append(consolidated, msg) } } - // Don't forget the last assistant message if we were building one if isBuildingAssistantMsg { consolidated = append(consolidated, currentAssistantMsg) } - return consolidated } @@ -188,6 +201,72 @@ func createClient(connectTimeout time.Duration) *http.Client { } } +func warmUpModel() { + u, err := url.Parse(cfg.CurrentAPI) + if err != nil { + return + } + host := u.Hostname() + if host != "localhost" && host != "127.0.0.1" && host != "::1" { + return + } + // Check if model is already loaded + loaded, err := isModelLoaded(chatBody.Model) + if err != nil { + logger.Debug("failed to check model status", "model", chatBody.Model, "error", err) + // Continue with warmup attempt anyway + } + if loaded { + if err := notifyUser("model already loaded", "Model "+chatBody.Model+" is already loaded."); err != nil { + logger.Debug("failed to notify user", "error", err) + } + return + } + go func() { + var data []byte + var err error + if strings.HasSuffix(cfg.CurrentAPI, "/completion") { + // Old completion endpoint + req := models.NewLCPReq(".", chatBody.Model, nil, map[string]float32{ + "temperature": 0.8, + "dry_multiplier": 0.0, + "min_p": 0.05, + "n_predict": 0, + }, []string{}) + req.Stream = false + data, err = json.Marshal(req) + } else if strings.Contains(cfg.CurrentAPI, "/v1/chat/completions") { + // OpenAI-compatible chat endpoint + req := models.OpenAIReq{ + ChatBody: &models.ChatBody{ + Model: chatBody.Model, + Messages: []models.RoleMsg{ + {Role: "system", Content: "."}, + }, + Stream: false, + }, + Tools: nil, + } + data, err = json.Marshal(req) + } else { + // Unknown local endpoint, skip + return + } + if err != nil { + logger.Debug("failed to marshal warmup request", "error", err) + return + } + resp, err := httpClient.Post(cfg.CurrentAPI, "application/json", bytes.NewReader(data)) + if err != nil { + logger.Debug("warmup request failed", "error", err) + return + } + resp.Body.Close() + // Start monitoring for model load completion + monitorModelLoad(chatBody.Model) + }() +} + func fetchLCPModelName() *models.LCPModels { //nolint resp, err := httpClient.Get(cfg.FetchModelNameAPI) @@ -210,6 +289,7 @@ func fetchLCPModelName() *models.LCPModels { return nil } chatBody.Model = path.Base(llmModel.Data[0].ID) + cfg.CurrentModel = chatBody.Model return &llmModel } @@ -274,64 +354,82 @@ func fetchLCPModels() ([]string, error) { return localModels, nil } -func sendMsgToLLM(body io.Reader) { - choseChunkParser() - - var req *http.Request - var err error +// fetchLCPModelsWithStatus returns the full LCPModels struct including status information. +func fetchLCPModelsWithStatus() (*models.LCPModels, error) { + resp, err := http.Get(cfg.FetchModelNameAPI) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + err := fmt.Errorf("failed to fetch llama.cpp models; status: %s", resp.Status) + return nil, err + } + data := &models.LCPModels{} + if err := json.NewDecoder(resp.Body).Decode(data); err != nil { + return nil, err + } + return data, nil +} - // Capture and log the request body for debugging - if _, ok := body.(*io.LimitedReader); ok { - // If it's a LimitedReader, we need to handle it differently - logger.Debug("request body type is LimitedReader", "parser", chunkParser, "link", cfg.CurrentAPI) - req, err = http.NewRequest("POST", cfg.CurrentAPI, body) - if err != nil { - logger.Error("newreq error", "error", err) - if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil { - logger.Error("failed to notify", "error", err) - } - streamDone <- true - return +// isModelLoaded checks if the given model ID is currently loaded in llama.cpp server. +func isModelLoaded(modelID string) (bool, error) { + models, err := fetchLCPModelsWithStatus() + if err != nil { + return false, err + } + for _, m := range models.Data { + if m.ID == modelID { + return m.Status.Value == "loaded", nil } - req.Header.Add("Accept", "application/json") - req.Header.Add("Content-Type", "application/json") - req.Header.Add("Authorization", "Bearer "+chunkParser.GetToken()) - req.Header.Set("Accept-Encoding", "gzip") - } else { - // For other reader types, capture and log the body content - bodyBytes, err := io.ReadAll(body) - if err != nil { - logger.Error("failed to read request body for logging", "error", err) - // Create request with original body if reading fails - req, err = http.NewRequest("POST", cfg.CurrentAPI, bytes.NewReader(bodyBytes)) - if err != nil { - logger.Error("newreq error", "error", err) - if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil { - logger.Error("failed to notify", "error", err) - } - streamDone <- true + } + return false, nil +} + +// monitorModelLoad starts a goroutine that periodically checks if the specified model is loaded. +func monitorModelLoad(modelID string) { + go func() { + timeout := time.After(2 * time.Minute) // max wait 2 minutes + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + for { + select { + case <-timeout: + logger.Debug("model load monitoring timeout", "model", modelID) return - } - } else { - // Log the request body for debugging - logger.Debug("sending request to API", "api", cfg.CurrentAPI, "body", string(bodyBytes)) - // Create request with the captured body - req, err = http.NewRequest("POST", cfg.CurrentAPI, bytes.NewReader(bodyBytes)) - if err != nil { - logger.Error("newreq error", "error", err) - if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil { - logger.Error("failed to notify", "error", err) + case <-ticker.C: + loaded, err := isModelLoaded(modelID) + if err != nil { + logger.Debug("failed to check model status", "model", modelID, "error", err) + continue + } + if loaded { + if err := notifyUser("model loaded", "Model "+modelID+" is now loaded and ready."); err != nil { + logger.Debug("failed to notify user", "error", err) + } + return } - streamDone <- true - return } } + }() +} - req.Header.Add("Accept", "application/json") - req.Header.Add("Content-Type", "application/json") - req.Header.Add("Authorization", "Bearer "+chunkParser.GetToken()) - req.Header.Set("Accept-Encoding", "gzip") +// sendMsgToLLM expects streaming resp +func sendMsgToLLM(body io.Reader) { + choseChunkParser() + req, err := http.NewRequest("POST", cfg.CurrentAPI, body) + if err != nil { + logger.Error("newreq error", "error", err) + if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil { + logger.Error("failed to notify", "error", err) + } + streamDone <- true + return } + req.Header.Add("Accept", "application/json") + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Authorization", "Bearer "+chunkParser.GetToken()) + req.Header.Set("Accept-Encoding", "gzip") // nolint resp, err := httpClient.Do(req) if err != nil { @@ -396,6 +494,7 @@ func sendMsgToLLM(body io.Reader) { streamDone <- true break } + // // problem: this catches any mention of the word 'error' // Handle error messages in response content // example needed, since llm could use the word error in the normal msg // if string(line) != "" && strings.Contains(strings.ToLower(string(line)), "error") { @@ -422,7 +521,7 @@ func sendMsgToLLM(body io.Reader) { if chunk.FuncName != "" { lastToolCall.Name = chunk.FuncName // Store the tool call ID for the response - lastToolCallID = chunk.ToolID + lastToolCall.ID = chunk.ToolID } interrupt: if interruptResp { // read bytes, so it would not get into beginning of the next req @@ -604,20 +703,16 @@ out: Role: botPersona, Content: respText.String(), }) } - logger.Debug("chatRound: before cleanChatBody", "messages_before_clean", len(chatBody.Messages)) for i, msg := range chatBody.Messages { logger.Debug("chatRound: before cleaning", "index", i, "role", msg.Role, "content_len", len(msg.Content), "has_content", msg.HasContent(), "tool_call_id", msg.ToolCallID) } - // // Clean null/empty messages to prevent API issues with endpoints like llama.cpp jinja template cleanChatBody() - logger.Debug("chatRound: after cleanChatBody", "messages_after_clean", len(chatBody.Messages)) for i, msg := range chatBody.Messages { logger.Debug("chatRound: after cleaning", "index", i, "role", msg.Role, "content_len", len(msg.Content), "has_content", msg.HasContent(), "tool_call_id", msg.ToolCallID) } - colorText() updateStatusLine() // bot msg is done; @@ -631,20 +726,84 @@ out: // cleanChatBody removes messages with null or empty content to prevent API issues func cleanChatBody() { - if chatBody != nil && chatBody.Messages != nil { - originalLen := len(chatBody.Messages) - logger.Debug("cleanChatBody: before cleaning", "message_count", originalLen) - for i, msg := range chatBody.Messages { - logger.Debug("cleanChatBody: before clean", "index", i, "role", msg.Role, "content_len", len(msg.Content), "has_content", msg.HasContent(), "tool_call_id", msg.ToolCallID) - } + if chatBody == nil || chatBody.Messages == nil { + return + } + originalLen := len(chatBody.Messages) + logger.Debug("cleanChatBody: before cleaning", "message_count", originalLen) + for i, msg := range chatBody.Messages { + logger.Debug("cleanChatBody: before clean", "index", i, "role", msg.Role, "content_len", len(msg.Content), "has_content", msg.HasContent(), "tool_call_id", msg.ToolCallID) + } + // TODO: consider case where we keep tool requests + // /completion msg where part meant for user and other part tool call + chatBody.Messages = cleanToolCalls(chatBody.Messages) + chatBody.Messages = cleanNullMessages(chatBody.Messages) + logger.Debug("cleanChatBody: after cleaning", "original_len", originalLen, "new_len", len(chatBody.Messages)) + for i, msg := range chatBody.Messages { + logger.Debug("cleanChatBody: after clean", "index", i, "role", msg.Role, "content_len", len(msg.Content), "has_content", msg.HasContent(), "tool_call_id", msg.ToolCallID) + } +} - chatBody.Messages = cleanNullMessages(chatBody.Messages) +// convertJSONToMapStringString unmarshals JSON into map[string]interface{} and converts all values to strings. +func convertJSONToMapStringString(jsonStr string) (map[string]string, error) { + var raw map[string]interface{} + if err := json.Unmarshal([]byte(jsonStr), &raw); err != nil { + return nil, err + } + result := make(map[string]string, len(raw)) + for k, v := range raw { + switch val := v.(type) { + case string: + result[k] = val + case float64: + result[k] = strconv.FormatFloat(val, 'f', -1, 64) + case int, int64, int32: + // json.Unmarshal converts numbers to float64, but handle other integer types if they appear + result[k] = fmt.Sprintf("%v", val) + case bool: + result[k] = strconv.FormatBool(val) + case nil: + result[k] = "" + default: + result[k] = fmt.Sprintf("%v", val) + } + } + return result, nil +} - logger.Debug("cleanChatBody: after cleaning", "original_len", originalLen, "new_len", len(chatBody.Messages)) - for i, msg := range chatBody.Messages { - logger.Debug("cleanChatBody: after clean", "index", i, "role", msg.Role, "content_len", len(msg.Content), "has_content", msg.HasContent(), "tool_call_id", msg.ToolCallID) +// unmarshalFuncCall unmarshals a JSON tool call, converting numeric arguments to strings. +func unmarshalFuncCall(jsonStr string) (*models.FuncCall, error) { + type tempFuncCall struct { + ID string `json:"id,omitempty"` + Name string `json:"name"` + Args map[string]interface{} `json:"args"` + } + var temp tempFuncCall + if err := json.Unmarshal([]byte(jsonStr), &temp); err != nil { + return nil, err + } + fc := &models.FuncCall{ + ID: temp.ID, + Name: temp.Name, + Args: make(map[string]string, len(temp.Args)), + } + for k, v := range temp.Args { + switch val := v.(type) { + case string: + fc.Args[k] = val + case float64: + fc.Args[k] = strconv.FormatFloat(val, 'f', -1, 64) + case int, int64, int32: + fc.Args[k] = fmt.Sprintf("%v", val) + case bool: + fc.Args[k] = strconv.FormatBool(val) + case nil: + fc.Args[k] = "" + default: + fc.Args[k] = fmt.Sprintf("%v", val) } } + return fc, nil } func findCall(msg, toolCall string, tv *tview.TextView) { @@ -652,30 +811,28 @@ func findCall(msg, toolCall string, tv *tview.TextView) { if toolCall != "" { // HTML-decode the tool call string to handle encoded characters like < -> <= decodedToolCall := html.UnescapeString(toolCall) - openAIToolMap := make(map[string]string) - // respect tool call - if err := json.Unmarshal([]byte(decodedToolCall), &openAIToolMap); err != nil { + openAIToolMap, err := convertJSONToMapStringString(decodedToolCall) + if err != nil { logger.Error("failed to unmarshal openai tool call", "call", decodedToolCall, "error", err) + // Ensure lastToolCall.ID is set for the error response (already set from chunk) // Send error response to LLM so it can retry or handle the error toolResponseMsg := models.RoleMsg{ Role: cfg.ToolRole, Content: fmt.Sprintf("Error processing tool call: %v. Please check the JSON format and try again.", err), - ToolCallID: lastToolCallID, // Use the stored tool call ID + ToolCallID: lastToolCall.ID, // Use the stored tool call ID } chatBody.Messages = append(chatBody.Messages, toolResponseMsg) - // Clear the stored tool call ID after using it - lastToolCallID = "" + // Clear the stored tool call ID after using it (no longer needed) // Trigger the assistant to continue processing with the error message chatRound("", cfg.AssistantRole, tv, false, false) return } lastToolCall.Args = openAIToolMap fc = lastToolCall - // Ensure lastToolCallID is set if it's available in the tool call - if lastToolCallID == "" && len(openAIToolMap) > 0 { - // Attempt to extract ID from the parsed tool call if not already set + // Set lastToolCall.ID from parsed tool call ID if available + if len(openAIToolMap) > 0 { if id, exists := openAIToolMap["id"]; exists { - lastToolCallID = id + lastToolCall.ID = id } } } else { @@ -688,7 +845,9 @@ func findCall(msg, toolCall string, tv *tview.TextView) { jsStr = strings.TrimSuffix(strings.TrimPrefix(jsStr, prefix), suffix) // HTML-decode the JSON string to handle encoded characters like < -> <= decodedJsStr := html.UnescapeString(jsStr) - if err := json.Unmarshal([]byte(decodedJsStr), &fc); err != nil { + var err error + fc, err = unmarshalFuncCall(decodedJsStr) + if err != nil { logger.Error("failed to unmarshal tool call", "error", err, "json_string", decodedJsStr) // Send error response to LLM so it can retry or handle the error toolResponseMsg := models.RoleMsg{ @@ -701,28 +860,40 @@ func findCall(msg, toolCall string, tv *tview.TextView) { chatRound("", cfg.AssistantRole, tv, false, false) return } + // Update lastToolCall with parsed function call + lastToolCall.ID = fc.ID + lastToolCall.Name = fc.Name + lastToolCall.Args = fc.Args + } + // we got here => last msg recognized as a tool call (correct or not) + // make sure it has ToolCallID + if chatBody.Messages[len(chatBody.Messages)-1].ToolCallID == "" { + chatBody.Messages[len(chatBody.Messages)-1].ToolCallID = randString(6) + } + // Ensure lastToolCall.ID is set, fallback to assistant message's ToolCallID + if lastToolCall.ID == "" { + lastToolCall.ID = chatBody.Messages[len(chatBody.Messages)-1].ToolCallID } // call a func - f, ok := fnMap[fc.Name] + _, ok := fnMap[fc.Name] if !ok { m := fc.Name + " is not implemented" // Create tool response message with the proper tool_call_id toolResponseMsg := models.RoleMsg{ Role: cfg.ToolRole, Content: m, - ToolCallID: lastToolCallID, // Use the stored tool call ID + ToolCallID: lastToolCall.ID, // Use the stored tool call ID } chatBody.Messages = append(chatBody.Messages, toolResponseMsg) logger.Debug("findCall: added tool not implemented response", "role", toolResponseMsg.Role, "content_len", len(toolResponseMsg.Content), "tool_call_id", toolResponseMsg.ToolCallID, "message_count_after_add", len(chatBody.Messages)) // Clear the stored tool call ID after using it - lastToolCallID = "" - + lastToolCall.ID = "" // Trigger the assistant to continue processing with the new tool response // by calling chatRound with empty content to continue the assistant's response chatRound("", cfg.AssistantRole, tv, false, false) return } - resp := f(fc.Args) + resp := callToolWithAgent(fc.Name, fc.Args) toolMsg := string(resp) // Remove the "tool response: " prefix and %+v formatting logger.Info("llm used tool call", "tool_resp", toolMsg, "tool_attrs", fc) fmt.Fprintf(tv, "%s[-:-:b](%d) <%s>: [-:-:-]\n%s\n", @@ -731,12 +902,12 @@ func findCall(msg, toolCall string, tv *tview.TextView) { toolResponseMsg := models.RoleMsg{ Role: cfg.ToolRole, Content: toolMsg, - ToolCallID: lastToolCallID, // Use the stored tool call ID + ToolCallID: lastToolCall.ID, // Use the stored tool call ID } chatBody.Messages = append(chatBody.Messages, toolResponseMsg) logger.Debug("findCall: added actual tool response", "role", toolResponseMsg.Role, "content_len", len(toolResponseMsg.Content), "tool_call_id", toolResponseMsg.ToolCallID, "message_count_after_add", len(chatBody.Messages)) // Clear the stored tool call ID after using it - lastToolCallID = "" + lastToolCall.ID = "" // Trigger the assistant to continue processing with the new tool response // by calling chatRound with empty content to continue the assistant's response chatRound("", cfg.AssistantRole, tv, false, false) @@ -756,7 +927,7 @@ func chatToTextSlice(showSys bool) []string { func chatToText(showSys bool) string { s := chatToTextSlice(showSys) - return strings.Join(s, "") + return strings.Join(s, "\n") } func removeThinking(chatBody *models.ChatBody) { @@ -835,19 +1006,30 @@ func updateModelLists() { } } // if llama.cpp started after gf-lt? + localModelsMu.Lock() LocalModels, err = fetchLCPModels() + localModelsMu.Unlock() if err != nil { logger.Warn("failed to fetch llama.cpp models", "error", err) } } -func updateModelListsTicker() { - updateModelLists() // run on the start - ticker := time.NewTicker(time.Minute * 1) - for { - <-ticker.C - updateModelLists() +func refreshLocalModelsIfEmpty() { + localModelsMu.RLock() + if len(LocalModels) > 0 { + localModelsMu.RUnlock() + return + } + localModelsMu.RUnlock() + // try to fetch + models, err := fetchLCPModels() + if err != nil { + logger.Warn("failed to fetch llama.cpp models", "error", err) + return } + localModelsMu.Lock() + LocalModels = models + localModelsMu.Unlock() } func init() { @@ -903,12 +1085,12 @@ func init() { cluedoState = extra.CluedoPrepCards(playerOrder) } choseChunkParser() - httpClient = createClient(time.Second * 15) + httpClient = createClient(time.Second * 90) if cfg.TTS_ENABLED { orator = extra.NewOrator(logger, cfg) } if cfg.STT_ENABLED { asr = extra.NewSTT(logger, cfg) } - go updateModelListsTicker() + go updateModelLists() } diff --git a/bot_test.go b/bot_test.go index 2d59c3c..d2956a9 100644 --- a/bot_test.go +++ b/bot_test.go @@ -152,4 +152,138 @@ func TestConsolidateConsecutiveAssistantMessages(t *testing.T) { } }) } +} + +func TestUnmarshalFuncCall(t *testing.T) { + tests := []struct { + name string + jsonStr string + want *models.FuncCall + wantErr bool + }{ + { + name: "simple websearch with numeric limit", + jsonStr: `{"name": "websearch", "args": {"query": "current weather in London", "limit": 3}}`, + want: &models.FuncCall{ + Name: "websearch", + Args: map[string]string{"query": "current weather in London", "limit": "3"}, + }, + wantErr: false, + }, + { + name: "string limit", + jsonStr: `{"name": "websearch", "args": {"query": "test", "limit": "5"}}`, + want: &models.FuncCall{ + Name: "websearch", + Args: map[string]string{"query": "test", "limit": "5"}, + }, + wantErr: false, + }, + { + name: "boolean arg", + jsonStr: `{"name": "test", "args": {"flag": true}}`, + want: &models.FuncCall{ + Name: "test", + Args: map[string]string{"flag": "true"}, + }, + wantErr: false, + }, + { + name: "null arg", + jsonStr: `{"name": "test", "args": {"opt": null}}`, + want: &models.FuncCall{ + Name: "test", + Args: map[string]string{"opt": ""}, + }, + wantErr: false, + }, + { + name: "float arg", + jsonStr: `{"name": "test", "args": {"ratio": 0.5}}`, + want: &models.FuncCall{ + Name: "test", + Args: map[string]string{"ratio": "0.5"}, + }, + wantErr: false, + }, + { + name: "invalid JSON", + jsonStr: `{invalid}`, + want: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := unmarshalFuncCall(tt.jsonStr) + if (err != nil) != tt.wantErr { + t.Errorf("unmarshalFuncCall() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + return + } + if got.Name != tt.want.Name { + t.Errorf("unmarshalFuncCall() name = %v, want %v", got.Name, tt.want.Name) + } + if len(got.Args) != len(tt.want.Args) { + t.Errorf("unmarshalFuncCall() args length = %v, want %v", len(got.Args), len(tt.want.Args)) + } + for k, v := range tt.want.Args { + if got.Args[k] != v { + t.Errorf("unmarshalFuncCall() args[%v] = %v, want %v", k, got.Args[k], v) + } + } + }) + } +} + +func TestConvertJSONToMapStringString(t *testing.T) { + tests := []struct { + name string + jsonStr string + want map[string]string + wantErr bool + }{ + { + name: "simple map", + jsonStr: `{"query": "weather", "limit": 5}`, + want: map[string]string{"query": "weather", "limit": "5"}, + wantErr: false, + }, + { + name: "boolean and null", + jsonStr: `{"flag": true, "opt": null}`, + want: map[string]string{"flag": "true", "opt": ""}, + wantErr: false, + }, + { + name: "invalid JSON", + jsonStr: `{invalid`, + want: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := convertJSONToMapStringString(tt.jsonStr) + if (err != nil) != tt.wantErr { + t.Errorf("convertJSONToMapStringString() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + return + } + if len(got) != len(tt.want) { + t.Errorf("convertJSONToMapStringString() length = %v, want %v", len(got), len(tt.want)) + } + for k, v := range tt.want { + if got[k] != v { + t.Errorf("convertJSONToMapStringString()[%v] = %v, want %v", k, got[k], v) + } + } + }) + } }
\ No newline at end of file diff --git a/config/config.go b/config/config.go index eef8035..5b7cc35 100644 --- a/config/config.go +++ b/config/config.go @@ -12,7 +12,7 @@ type Config struct { ChatAPI string `toml:"ChatAPI"` CompletionAPI string `toml:"CompletionAPI"` CurrentAPI string - CurrentProvider string + CurrentModel string `toml:"CurrentModel"` APIMap map[string]string FetchModelNameAPI string `toml:"FetchModelNameAPI"` // ToolsAPI list? diff --git a/helpfuncs.go b/helpfuncs.go index df49ae5..30d9967 100644 --- a/helpfuncs.go +++ b/helpfuncs.go @@ -8,8 +8,20 @@ import ( "os" "path" "strings" + "unicode" + + "math/rand/v2" ) +func isASCII(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] > unicode.MaxASCII { + return false + } + } + return true +} + func colorText() { text := textView.GetText(false) quoteReplacer := strings.NewReplacer( @@ -63,7 +75,7 @@ func colorText() { } func updateStatusLine() { - position.SetText(makeStatusLine()) + statusLineWidget.SetText(makeStatusLine()) helpView.SetText(fmt.Sprintf(helpText, makeStatusLine())) } @@ -229,3 +241,13 @@ func makeStatusLine() string { isRecording, persona, botPersona, injectRole) return statusLine + imageInfo + shellModeInfo } + +var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + +func randString(n int) string { + b := make([]rune, n) + for i := range b { + b[i] = letters[rand.IntN(len(letters))] + } + return string(b) +} @@ -122,7 +122,7 @@ func (lcp LCPCompletion) FormMsg(msg, role string, resume bool) (io.Reader, erro logger.Debug("RAG message added to chat body", "message_count", len(chatBody.Messages)) } } - if cfg.ToolUse && !resume { + if cfg.ToolUse && !resume && role == cfg.UserRole { // add to chat body chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg}) } @@ -358,7 +358,7 @@ func (ds DeepSeekerCompletion) FormMsg(msg, role string, resume bool) (io.Reader logger.Debug("DeepSeekerCompletion: RAG message added to chat body", "message_count", len(chatBody.Messages)) } } - if cfg.ToolUse && !resume { + if cfg.ToolUse && !resume && role == cfg.UserRole { // add to chat body chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg}) } @@ -420,7 +420,7 @@ func (ds DeepSeekerChat) GetToken() string { func (ds DeepSeekerChat) FormMsg(msg, role string, resume bool) (io.Reader, error) { logger.Debug("formmsg deepseekerchat", "link", cfg.CurrentAPI) - if cfg.ToolUse && !resume { + if cfg.ToolUse && !resume && role == cfg.UserRole { // prompt += "\n" + cfg.ToolRole + ":\n" + toolSysMsg // add to chat body chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg}) @@ -516,7 +516,7 @@ func (or OpenRouterCompletion) FormMsg(msg, role string, resume bool) (io.Reader logger.Debug("RAG message added to chat body", "message_count", len(chatBody.Messages)) } } - if cfg.ToolUse && !resume { + if cfg.ToolUse && !resume && role == cfg.UserRole { // add to chat body chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg}) } @@ -3,7 +3,6 @@ package main import ( "flag" "strconv" - "unicode" "github.com/rivo/tview" ) @@ -23,15 +22,6 @@ var ( focusSwitcher = map[tview.Primitive]tview.Primitive{} ) -func isASCII(s string) bool { - for i := 0; i < len(s); i++ { - if s[i] > unicode.MaxASCII { - return false - } - } - return true -} - func main() { apiPort := flag.Int("port", 0, "port to host api") flag.Parse() diff --git a/props_table.go b/props_table.go index dd359f4..0c49056 100644 --- a/props_table.go +++ b/props_table.go @@ -5,11 +5,14 @@ import ( "slices" "strconv" "strings" + "sync" "github.com/gdamore/tcell/v2" "github.com/rivo/tview" ) +var _ = sync.RWMutex{} + // Define constants for cell types const ( CellTypeCheckbox = "checkbox" @@ -50,6 +53,7 @@ func makePropsTable(props map[string]float32) *tview.Table { row++ // Store cell data for later use in selection functions cellData := make(map[string]*CellData) + var modelCellID string // will be set for the model selection row // Helper function to add a checkbox-like row addCheckboxRow := func(label string, initialValue bool, onChange func(bool)) { table.SetCell(row, 0, @@ -130,23 +134,60 @@ func makePropsTable(props map[string]float32) *tview.Table { addListPopupRow("Set log level", logLevels, GetLogLevel(), func(option string) { setLogLevel(option) }) - // Prepare API links dropdown - insert current API at the beginning - apiLinks := slices.Insert(cfg.ApiLinks, 0, cfg.CurrentAPI) + // Helper function to get model list for a given API + getModelListForAPI := func(api string) []string { + if strings.Contains(api, "api.deepseek.com/") { + return []string{"deepseek-chat", "deepseek-reasoner"} + } else if strings.Contains(api, "openrouter.ai") { + return ORFreeModels + } + // Assume local llama.cpp + refreshLocalModelsIfEmpty() + localModelsMu.RLock() + defer localModelsMu.RUnlock() + return LocalModels + } + var modelRowIndex int // will be set before model row is added + // Prepare API links dropdown - ensure current API is first, avoid duplicates + apiLinks := make([]string, 0, len(cfg.ApiLinks)+1) + apiLinks = append(apiLinks, cfg.CurrentAPI) + for _, api := range cfg.ApiLinks { + if api != cfg.CurrentAPI { + apiLinks = append(apiLinks, api) + } + } addListPopupRow("Select an api", apiLinks, cfg.CurrentAPI, func(option string) { cfg.CurrentAPI = option + // Update model list based on new API + newModelList := getModelListForAPI(cfg.CurrentAPI) + if modelCellID != "" { + if data := cellData[modelCellID]; data != nil { + data.Options = newModelList + } + } + // Ensure chatBody.Model is in the new list; if not, set to first available model + if len(newModelList) > 0 && !slices.Contains(newModelList, chatBody.Model) { + chatBody.Model = newModelList[0] + cfg.CurrentModel = chatBody.Model + // Update the displayed cell text - need to find model row + // Search for model row by label + for r := 0; r < table.GetRowCount(); r++ { + if cell := table.GetCell(r, 0); cell != nil && cell.Text == "Select a model" { + if valueCell := table.GetCell(r, 1); valueCell != nil { + valueCell.SetText(chatBody.Model) + } + break + } + } + } }) - var modelList []string - // INFO: modelList is chosen based on current api link - if strings.Contains(cfg.CurrentAPI, "api.deepseek.com/") { - modelList = []string{chatBody.Model, "deepseek-chat", "deepseek-reasoner"} - } else if strings.Contains(cfg.CurrentAPI, "opentouter.ai") { - modelList = ORFreeModels - } else { // would match on localhost but what if llama.cpp served non localy? - modelList = LocalModels - } // Prepare model list dropdown + modelRowIndex = row + modelCellID = fmt.Sprintf("listpopup_%d", modelRowIndex) + modelList := getModelListForAPI(cfg.CurrentAPI) addListPopupRow("Select a model", modelList, chatBody.Model, func(option string) { chatBody.Model = option + cfg.CurrentModel = chatBody.Model }) // Role selection dropdown addListPopupRow("Write next message as", listRolesWithUser(), cfg.WriteNextMsgAs, func(option string) { @@ -228,11 +269,53 @@ func makePropsTable(props map[string]float32) *tview.Table { listPopupCellID := fmt.Sprintf("listpopup_%d", selectedRow) if cellData[listPopupCellID] != nil && cellData[listPopupCellID].Type == CellTypeListPopup { data := cellData[listPopupCellID] - if onChange, ok := data.OnChange.(func(string)); ok && data.Options != nil { + if onChange, ok := data.OnChange.(func(string)); ok { + // Get label for context + labelCell := table.GetCell(selectedRow, 0) + label := "item" + if labelCell != nil { + label = labelCell.Text + } + + // For model selection, always compute fresh options from current API + if label == "Select a model" { + freshOptions := getModelListForAPI(cfg.CurrentAPI) + data.Options = freshOptions + // Also update the cell data map + cellData[listPopupCellID].Options = freshOptions + } + + // Handle nil options + if data.Options == nil { + logger.Error("options list is nil for", "label", label) + if err := notifyUser("Configuration error", "Options list is nil for "+label); err != nil { + logger.Error("failed to send notification", "error", err) + } + return + } + + // Check for empty options list + if len(data.Options) == 0 { + logger.Warn("empty options list for", "label", label, "api", cfg.CurrentAPI, "localModelsLen", len(LocalModels), "orModelsLen", len(ORFreeModels)) + message := "No options available for " + label + if label == "Select a model" { + if strings.Contains(cfg.CurrentAPI, "openrouter.ai") { + message = "No OpenRouter models available. Check token and connection." + } else if strings.Contains(cfg.CurrentAPI, "api.deepseek.com") { + message = "DeepSeek models should be available. Please report bug." + } else { + message = "No llama.cpp models loaded. Ensure llama.cpp server is running with models." + } + } + if err := notifyUser("Empty list", message); err != nil { + logger.Error("failed to send notification", "error", err) + } + return + } // Create a list primitive apiList := tview.NewList().ShowSecondaryText(false). SetSelectedBackgroundColor(tcell.ColorGray) - apiList.SetTitle("Select an API").SetBorder(true) + apiList.SetTitle("Select " + label).SetBorder(true) for i, api := range data.Options { if api == cell.Text { apiList.SetCurrentItem(i) @@ -23,43 +23,92 @@ func makeChatTable(chatMap map[string]models.Chat) *tview.Table { chatList[i] = name i++ } - rows, cols := len(chatMap), len(actions)+2 + + // Add 1 extra row for header + rows, cols := len(chatMap)+1, len(actions)+4 // +2 for name, +2 for timestamps chatActTable := tview.NewTable(). SetBorders(true) - for r := 0; r < rows; r++ { + + // Add header row (row 0) + for c := 0; c < cols; c++ { + color := tcell.ColorWhite + headerText := "" + switch c { + case 0: + headerText = "Chat Name" + case 1: + headerText = "Preview" + case 2: + headerText = "Created At" + case 3: + headerText = "Updated At" + default: + headerText = actions[c-4] + } + chatActTable.SetCell(0, c, + tview.NewTableCell(headerText). + SetSelectable(false). + SetTextColor(color). + SetAlign(tview.AlignCenter). + SetAttributes(tcell.AttrBold)) + } + + // Add data rows (starting from row 1) + for r := 0; r < rows-1; r++ { // rows-1 because we added a header row for c := 0; c < cols; c++ { color := tcell.ColorWhite switch c { case 0: - chatActTable.SetCell(r, c, + chatActTable.SetCell(r+1, c, // +1 to account for header row tview.NewTableCell(chatList[r]). SetSelectable(false). SetTextColor(color). SetAlign(tview.AlignCenter)) case 1: - chatActTable.SetCell(r, c, + chatActTable.SetCell(r+1, c, // +1 to account for header row tview.NewTableCell(chatMap[chatList[r]].Msgs[len(chatMap[chatList[r]].Msgs)-30:]). SetSelectable(false). SetTextColor(color). SetAlign(tview.AlignCenter)) + case 2: + // Created At column + chatActTable.SetCell(r+1, c, // +1 to account for header row + tview.NewTableCell(chatMap[chatList[r]].CreatedAt.Format("2006-01-02 15:04")). + SetSelectable(false). + SetTextColor(color). + SetAlign(tview.AlignCenter)) + case 3: + // Updated At column + chatActTable.SetCell(r+1, c, // +1 to account for header row + tview.NewTableCell(chatMap[chatList[r]].UpdatedAt.Format("2006-01-02 15:04")). + SetSelectable(false). + SetTextColor(color). + SetAlign(tview.AlignCenter)) default: - chatActTable.SetCell(r, c, - tview.NewTableCell(actions[c-2]). + chatActTable.SetCell(r+1, c, // +1 to account for header row + tview.NewTableCell(actions[c-4]). // Adjusted offset to account for 2 new timestamp columns SetTextColor(color). SetAlign(tview.AlignCenter)) } } } - chatActTable.Select(0, 0).SetSelectable(true, true).SetFixed(1, 1).SetDoneFunc(func(key tcell.Key) { + chatActTable.Select(1, 0).SetSelectable(true, true).SetFixed(1, 1).SetDoneFunc(func(key tcell.Key) { if key == tcell.KeyEsc || key == tcell.KeyF1 || key == tcell.Key('x') { pages.RemovePage(historyPage) return } }).SetSelectedFunc(func(row int, column int) { + // Skip header row (row 0) for selection + if row == 0 { + // If user clicks on header, just return without action + chatActTable.Select(1, column) // Move selection to first data row + return + } + tc := chatActTable.GetCell(row, column) tc.SetTextColor(tcell.ColorRed) chatActTable.SetSelectable(false, false) - selectedChat := chatList[row] + selectedChat := chatList[row-1] // -1 to account for header row defer pages.RemovePage(historyPage) switch tc.Text { case "load": @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "gf-lt/agent" "gf-lt/extra" "gf-lt/models" "io" @@ -12,6 +13,7 @@ import ( "regexp" "strconv" "strings" + "sync" "time" ) @@ -125,7 +127,9 @@ under the topic: Adam's number is stored: </example_response> After that you are free to respond to the user. ` - basicCard = &models.CharCard{ + webSearchSysPrompt = `Summarize the web search results, extracting key information and presenting a concise answer. Provide sources and URLs where relevant.` + readURLSysPrompt = `Extract and summarize the content from the webpage. Provide key information, main points, and any relevant details.` + basicCard = &models.CharCard{ SysPrompt: basicSysMsg, FirstMsg: defaultFirstMsg, Role: "", @@ -140,8 +144,43 @@ After that you are free to respond to the user. // sysMap = map[string]string{"basic_sys": basicSysMsg, "tool_sys": toolSysMsg} sysMap = map[string]*models.CharCard{"basic_sys": basicCard} sysLabels = []string{"basic_sys"} + + webAgentClient *agent.AgentClient + webAgentClientOnce sync.Once + webAgentsOnce sync.Once ) +// getWebAgentClient returns a singleton AgentClient for web agents. +func getWebAgentClient() *agent.AgentClient { + webAgentClientOnce.Do(func() { + if cfg == nil { + panic("cfg not initialized") + } + if logger == nil { + panic("logger not initialized") + } + getToken := func() string { + if chunkParser == nil { + return "" + } + return chunkParser.GetToken() + } + webAgentClient = agent.NewAgentClient(cfg, *logger, getToken) + }) + return webAgentClient +} + +// registerWebAgents registers WebAgentB instances for websearch and read_url tools. +func registerWebAgents() { + webAgentsOnce.Do(func() { + client := getWebAgentClient() + // Register websearch agent + agent.Register("websearch", agent.NewWebAgentB(client, webSearchSysPrompt)) + // Register read_url agent + agent.Register("read_url", agent.NewWebAgentB(client, readURLSysPrompt)) + }) +} + // web search (depends on extra server) func websearch(args map[string]string) []byte { // make http request return bytes @@ -596,7 +635,6 @@ var globalTodoList = TodoList{ Items: []TodoItem{}, } - // Todo Management Tools func todoCreate(args map[string]string) []byte { task, ok := args["task"] @@ -848,6 +886,20 @@ var fnMap = map[string]fnSig{ "todo_delete": todoDelete, } +// callToolWithAgent calls the tool and applies any registered agent. +func callToolWithAgent(name string, args map[string]string) []byte { + registerWebAgents() + f, ok := fnMap[name] + if !ok { + return []byte(fmt.Sprintf("tool %s not found", name)) + } + raw := f(args) + if a := agent.Get(name); a != nil { + return a.Process(args, raw) + } + return raw +} + // openai style def var baseTools = []models.Tool{ // websearch @@ -12,26 +12,30 @@ import ( "path" "strconv" "strings" + "sync" "github.com/gdamore/tcell/v2" "github.com/rivo/tview" ) +var _ = sync.RWMutex{} + var ( - app *tview.Application - pages *tview.Pages - textArea *tview.TextArea - editArea *tview.TextArea - textView *tview.TextView - position *tview.TextView - helpView *tview.TextView - flex *tview.Flex - imgView *tview.Image - defaultImage = "sysprompts/llama.png" - indexPickWindow *tview.InputField - renameWindow *tview.InputField - roleEditWindow *tview.InputField - fullscreenMode bool + app *tview.Application + pages *tview.Pages + textArea *tview.TextArea + editArea *tview.TextArea + textView *tview.TextView + statusLineWidget *tview.TextView + helpView *tview.TextView + flex *tview.Flex + imgView *tview.Image + defaultImage = "sysprompts/llama.png" + indexPickWindow *tview.InputField + renameWindow *tview.InputField + roleEditWindow *tview.InputField + fullscreenMode bool + positionVisible bool = true // pages historyPage = "historyPage" agentPage = "agentPage" @@ -87,6 +91,8 @@ var ( [yellow]Alt+1[white]: toggle shell mode (execute commands locally) [yellow]Alt+4[white]: edit msg role [yellow]Alt+5[white]: toggle system and tool messages display +[yellow]Alt+6[white]: toggle status line visibility +[yellow]Alt+9[white]: warm up (load) selected llama.cpp model === scrolling chat window (some keys similar to vim) === [yellow]arrows up/down and j/k[white]: scroll up and down @@ -171,6 +177,26 @@ func toggleShellMode() { updateStatusLine() } +func updateFlexLayout() { + if fullscreenMode { + // flex already contains only focused widget; do nothing + return + } + flex.Clear() + flex.AddItem(textView, 0, 40, false) + flex.AddItem(textArea, 0, 10, false) + if positionVisible { + flex.AddItem(statusLineWidget, 0, 2, false) + } + // Keep focus on currently focused widget + focused := app.GetFocus() + if focused == textView { + app.SetFocus(textView) + } else { + app.SetFocus(textArea) + } +} + func executeCommandAndDisplay(cmdText string) { // Parse the command (split by spaces, but handle quoted arguments) cmdParts := parseCommand(cmdText) @@ -456,8 +482,10 @@ func init() { // flex = tview.NewFlex().SetDirection(tview.FlexRow). AddItem(textView, 0, 40, false). - AddItem(textArea, 0, 10, true). // Restore original height - AddItem(position, 0, 2, false) + AddItem(textArea, 0, 10, true) // Restore original height + if positionVisible { + flex.AddItem(statusLineWidget, 0, 2, false) + } // textView.SetBorder(true).SetTitle("chat") textView.SetDoneFunc(func(key tcell.Key) { if key == tcell.KeyEnter { @@ -516,14 +544,16 @@ func init() { }) focusSwitcher[textArea] = textView focusSwitcher[textView] = textArea - position = tview.NewTextView(). + statusLineWidget = tview.NewTextView(). SetDynamicColors(true). SetTextAlign(tview.AlignCenter) // Initially set up flex without search bar flex = tview.NewFlex().SetDirection(tview.FlexRow). AddItem(textView, 0, 40, false). - AddItem(textArea, 0, 10, true). // Restore original height - AddItem(position, 0, 2, false) + AddItem(textArea, 0, 10, true) // Restore original height + if positionVisible { + flex.AddItem(statusLineWidget, 0, 2, false) + } editArea = tview.NewTextArea(). SetPlaceholder("Replace msg...") editArea.SetBorder(true).SetTitle("input") @@ -749,6 +779,14 @@ func init() { textView.SetText(chatToText(cfg.ShowSys)) colorText() } + if event.Key() == tcell.KeyRune && event.Rune() == '6' && event.Modifiers()&tcell.ModAlt != 0 { + // toggle status line visibility + if name, _ := pages.GetFrontPage(); name != "main" { + return event + } + positionVisible = !positionVisible + updateFlexLayout() + } if event.Key() == tcell.KeyF1 { // chatList, err := loadHistoryChats() chatList, err := store.GetChatByChar(cfg.AssistantRole) @@ -841,16 +879,7 @@ func init() { } } else { // focused is the fullscreened widget here - flex.Clear(). - AddItem(textView, 0, 40, false). - AddItem(textArea, 0, 10, false). - AddItem(position, 0, 2, false) - - if focused == textView { - app.SetFocus(textView) - } else { // default to textArea - app.SetFocus(textArea) - } + updateFlexLayout() } return nil } @@ -958,13 +987,17 @@ func init() { if len(ORFreeModels) > 0 { currentORModelIndex = (currentORModelIndex + 1) % len(ORFreeModels) chatBody.Model = ORFreeModels[currentORModelIndex] + cfg.CurrentModel = chatBody.Model } updateStatusLine() } else { + localModelsMu.RLock() if len(LocalModels) > 0 { currentLocalModelIndex = (currentLocalModelIndex + 1) % len(LocalModels) chatBody.Model = LocalModels[currentLocalModelIndex] + cfg.CurrentModel = chatBody.Model } + localModelsMu.RUnlock() updateStatusLine() // // For non-OpenRouter APIs, use the old logic // go func() { @@ -1210,6 +1243,14 @@ func init() { toggleShellMode() return nil } + if event.Key() == tcell.KeyRune && event.Modifiers() == tcell.ModAlt && event.Rune() == '9' { + // Warm up (load) the currently selected model + go warmUpModel() + if err := notifyUser("model warmup", "loading model: "+chatBody.Model); err != nil { + logger.Debug("failed to notify user", "error", err) + } + return nil + } // cannot send msg in editMode or botRespMode if event.Key() == tcell.KeyEscape && !editMode && !botRespMode { msgText := textArea.GetText() |
