| author | Grail Finder <wohilas@gmail.com> | 2025-12-21 11:39:36 +0300 |
|---|---|---|
| committer | Grail Finder <wohilas@gmail.com> | 2025-12-21 11:39:36 +0300 |
| commit | 75fde2a575697f8f46ee9676c0ed228e5315a4e5 (patch) | |
| tree | 64e02a6afef049eb2ca79a3a5d2b0beb8ba26385 /bot.go | |
| parent | 1ca75a00642c4e0a6eea3117e3b4ebaacfdcfa7a (diff) | |
| parent | 5525c946613a6f726cd116d79f1505a63ab25806 (diff) | |
Merge branch 'master' into doc/tutorial
Diffstat (limited to 'bot.go')
| -rw-r--r-- | bot.go | 378 |
1 file changed, 280 insertions(+), 98 deletions(-)
@@ -16,9 +16,12 @@ import (
 	"log/slog"
 	"net"
 	"net/http"
+	"net/url"
 	"os"
 	"path"
+	"strconv"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/neurosnap/sentences/english"
@@ -47,10 +50,10 @@ var (
 	ragger       *rag.RAG
 	chunkParser  ChunkParser
 	lastToolCall *models.FuncCall
-	lastToolCallID string // Store the ID of the most recent tool call
 	//nolint:unused // TTS_ENABLED conditionally uses this
 	orator extra.Orator
 	asr    extra.STT
+	localModelsMu sync.RWMutex
 	defaultLCPProps = map[string]float32{
 		"temperature":    0.8,
 		"dry_multiplier": 0.0,
@@ -84,19 +87,31 @@ func cleanNullMessages(messages []models.RoleMsg) []models.RoleMsg {
 	return consolidateConsecutiveAssistantMessages(messages)
 }
 
+func cleanToolCalls(messages []models.RoleMsg) []models.RoleMsg {
+	cleaned := make([]models.RoleMsg, 0, len(messages))
+	for i, msg := range messages {
+		// recognize the message as the tool call and remove it
+		if msg.ToolCallID == "" {
+			cleaned = append(cleaned, msg)
+		}
+		// tool call in last msg should stay
+		if i == len(messages)-1 {
+			cleaned = append(cleaned, msg)
+		}
+	}
+	return consolidateConsecutiveAssistantMessages(cleaned)
+}
+
 // consolidateConsecutiveAssistantMessages merges consecutive assistant messages into a single message
 func consolidateConsecutiveAssistantMessages(messages []models.RoleMsg) []models.RoleMsg {
 	if len(messages) == 0 {
 		return messages
 	}
-
 	consolidated := make([]models.RoleMsg, 0, len(messages))
 	currentAssistantMsg := models.RoleMsg{}
 	isBuildingAssistantMsg := false
-
 	for i := 0; i < len(messages); i++ {
 		msg := messages[i]
-
 		if msg.Role == cfg.AssistantRole || msg.Role == cfg.WriteNextMsgAsCompletionAgent {
 			// If this is an assistant message, start or continue building
 			if !isBuildingAssistantMsg {
@@ -141,12 +156,10 @@ func consolidateConsecutiveAssistantMessages(messages []models.RoleMsg) []models
 			consolidated = append(consolidated, msg)
 		}
 	}
-
 	// Don't forget the last assistant message if we were building one
 	if isBuildingAssistantMsg {
 		consolidated = append(consolidated, currentAssistantMsg)
 	}
-
 	return consolidated
 }
 
@@ -188,6 +201,72 @@ func createClient(connectTimeout time.Duration) *http.Client {
 	}
 }
 
+func warmUpModel() {
+	u, err := url.Parse(cfg.CurrentAPI)
+	if err != nil {
+		return
+	}
+	host := u.Hostname()
+	if host != "localhost" && host != "127.0.0.1" && host != "::1" {
+		return
+	}
+	// Check if model is already loaded
+	loaded, err := isModelLoaded(chatBody.Model)
+	if err != nil {
+		logger.Debug("failed to check model status", "model", chatBody.Model, "error", err)
+		// Continue with warmup attempt anyway
+	}
+	if loaded {
+		if err := notifyUser("model already loaded", "Model "+chatBody.Model+" is already loaded."); err != nil {
+			logger.Debug("failed to notify user", "error", err)
+		}
+		return
+	}
+	go func() {
+		var data []byte
+		var err error
+		if strings.HasSuffix(cfg.CurrentAPI, "/completion") {
+			// Old completion endpoint
+			req := models.NewLCPReq(".", chatBody.Model, nil, map[string]float32{
+				"temperature":    0.8,
+				"dry_multiplier": 0.0,
+				"min_p":          0.05,
+				"n_predict":      0,
+			}, []string{})
+			req.Stream = false
+			data, err = json.Marshal(req)
+		} else if strings.Contains(cfg.CurrentAPI, "/v1/chat/completions") {
+			// OpenAI-compatible chat endpoint
+			req := models.OpenAIReq{
+				ChatBody: &models.ChatBody{
+					Model: chatBody.Model,
+					Messages: []models.RoleMsg{
+						{Role: "system", Content: "."},
+					},
+					Stream: false,
+				},
+				Tools: nil,
+			}
+			data, err = json.Marshal(req)
+		} else {
+			// Unknown local endpoint, skip
+			return
+		}
+		if err != nil {
+			logger.Debug("failed to marshal warmup request", "error", err)
+			return
+		}
+		resp, err := httpClient.Post(cfg.CurrentAPI, "application/json", bytes.NewReader(data))
+		if err != nil {
+			logger.Debug("warmup request failed", "error", err)
+			return
+		}
+		resp.Body.Close()
+		// Start monitoring for model load completion
+		monitorModelLoad(chatBody.Model)
+	}()
+}
+
 func fetchLCPModelName() *models.LCPModels {
 	//nolint
 	resp, err := httpClient.Get(cfg.FetchModelNameAPI)
@@ -210,6 +289,7 @@ func fetchLCPModelName() *models.LCPModels {
 		return nil
 	}
 	chatBody.Model = path.Base(llmModel.Data[0].ID)
+	cfg.CurrentModel = chatBody.Model
 	return &llmModel
 }
 
@@ -274,64 +354,82 @@ func fetchLCPModels() ([]string, error) {
 	return localModels, nil
 }
 
-func sendMsgToLLM(body io.Reader) {
-	choseChunkParser()
-
-	var req *http.Request
-	var err error
+// fetchLCPModelsWithStatus returns the full LCPModels struct including status information.
+func fetchLCPModelsWithStatus() (*models.LCPModels, error) {
+	resp, err := http.Get(cfg.FetchModelNameAPI)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != 200 {
+		err := fmt.Errorf("failed to fetch llama.cpp models; status: %s", resp.Status)
+		return nil, err
+	}
+	data := &models.LCPModels{}
+	if err := json.NewDecoder(resp.Body).Decode(data); err != nil {
+		return nil, err
+	}
+	return data, nil
+}
 
-	// Capture and log the request body for debugging
-	if _, ok := body.(*io.LimitedReader); ok {
-		// If it's a LimitedReader, we need to handle it differently
-		logger.Debug("request body type is LimitedReader", "parser", chunkParser, "link", cfg.CurrentAPI)
-		req, err = http.NewRequest("POST", cfg.CurrentAPI, body)
-		if err != nil {
-			logger.Error("newreq error", "error", err)
-			if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil {
-				logger.Error("failed to notify", "error", err)
-			}
-			streamDone <- true
-			return
+// isModelLoaded checks if the given model ID is currently loaded in llama.cpp server.
+func isModelLoaded(modelID string) (bool, error) {
+	models, err := fetchLCPModelsWithStatus()
+	if err != nil {
+		return false, err
+	}
+	for _, m := range models.Data {
+		if m.ID == modelID {
+			return m.Status.Value == "loaded", nil
 		}
-		req.Header.Add("Accept", "application/json")
-		req.Header.Add("Content-Type", "application/json")
-		req.Header.Add("Authorization", "Bearer "+chunkParser.GetToken())
-		req.Header.Set("Accept-Encoding", "gzip")
-	} else {
-		// For other reader types, capture and log the body content
-		bodyBytes, err := io.ReadAll(body)
-		if err != nil {
-			logger.Error("failed to read request body for logging", "error", err)
-			// Create request with original body if reading fails
-			req, err = http.NewRequest("POST", cfg.CurrentAPI, bytes.NewReader(bodyBytes))
-			if err != nil {
-				logger.Error("newreq error", "error", err)
-				if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil {
-					logger.Error("failed to notify", "error", err)
-				}
-				streamDone <- true
+	}
+	return false, nil
+}
+
+// monitorModelLoad starts a goroutine that periodically checks if the specified model is loaded.
+func monitorModelLoad(modelID string) {
+	go func() {
+		timeout := time.After(2 * time.Minute) // max wait 2 minutes
+		ticker := time.NewTicker(2 * time.Second)
+		defer ticker.Stop()
+		for {
+			select {
+			case <-timeout:
+				logger.Debug("model load monitoring timeout", "model", modelID)
 				return
-			}
-		} else {
-			// Log the request body for debugging
-			logger.Debug("sending request to API", "api", cfg.CurrentAPI, "body", string(bodyBytes))
-			// Create request with the captured body
-			req, err = http.NewRequest("POST", cfg.CurrentAPI, bytes.NewReader(bodyBytes))
-			if err != nil {
-				logger.Error("newreq error", "error", err)
-				if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil {
-					logger.Error("failed to notify", "error", err)
+			case <-ticker.C:
+				loaded, err := isModelLoaded(modelID)
+				if err != nil {
+					logger.Debug("failed to check model status", "model", modelID, "error", err)
					continue
+				}
+				if loaded {
+					if err := notifyUser("model loaded", "Model "+modelID+" is now loaded and ready."); err != nil {
+						logger.Debug("failed to notify user", "error", err)
+					}
+					return
 				}
-				streamDone <- true
-				return
 			}
 		}
+	}()
+}
 
-		req.Header.Add("Accept", "application/json")
-		req.Header.Add("Content-Type", "application/json")
-		req.Header.Add("Authorization", "Bearer "+chunkParser.GetToken())
-		req.Header.Set("Accept-Encoding", "gzip")
+// sendMsgToLLM expects streaming resp
+func sendMsgToLLM(body io.Reader) {
+	choseChunkParser()
+	req, err := http.NewRequest("POST", cfg.CurrentAPI, body)
+	if err != nil {
+		logger.Error("newreq error", "error", err)
+		if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil {
+			logger.Error("failed to notify", "error", err)
+		}
+		streamDone <- true
+		return
 	}
+	req.Header.Add("Accept", "application/json")
+	req.Header.Add("Content-Type", "application/json")
+	req.Header.Add("Authorization", "Bearer "+chunkParser.GetToken())
+	req.Header.Set("Accept-Encoding", "gzip")
 	// nolint
 	resp, err := httpClient.Do(req)
 	if err != nil {
@@ -396,6 +494,7 @@ func sendMsgToLLM(body io.Reader) {
 			streamDone <- true
 			break
 		}
+		// // problem: this catches any mention of the word 'error'
 		// Handle error messages in response content
 		// example needed, since llm could use the word error in the normal msg
 		// if string(line) != "" && strings.Contains(strings.ToLower(string(line)), "error") {
@@ -422,7 +521,7 @@
 		if chunk.FuncName != "" {
 			lastToolCall.Name = chunk.FuncName
 			// Store the tool call ID for the response
-			lastToolCallID = chunk.ToolID
+			lastToolCall.ID = chunk.ToolID
 		}
 	interrupt:
 		if interruptResp { // read bytes, so it would not get into beginning of the next req
@@ -604,20 +703,16 @@ out:
 			Role: botPersona, Content: respText.String(),
 		})
 	}
-
 	logger.Debug("chatRound: before cleanChatBody", "messages_before_clean", len(chatBody.Messages))
 	for i, msg := range chatBody.Messages {
 		logger.Debug("chatRound: before cleaning", "index", i, "role", msg.Role, "content_len", len(msg.Content), "has_content", msg.HasContent(), "tool_call_id", msg.ToolCallID)
 	}
-
-	//
 	// Clean null/empty messages to prevent API issues with endpoints like llama.cpp jinja template
 	cleanChatBody()
-
 	logger.Debug("chatRound: after cleanChatBody", "messages_after_clean", len(chatBody.Messages))
 	for i, msg := range chatBody.Messages {
 		logger.Debug("chatRound: after cleaning", "index", i, "role", msg.Role, "content_len", len(msg.Content), "has_content", msg.HasContent(), "tool_call_id", msg.ToolCallID)
 	}
-
 	colorText()
 	updateStatusLine()
 	// bot msg is done;
@@ -631,20 +726,84 @@ out:
 
 // cleanChatBody removes messages with null or empty content to prevent API issues
 func cleanChatBody() {
-	if chatBody != nil && chatBody.Messages != nil {
-		originalLen := len(chatBody.Messages)
-		logger.Debug("cleanChatBody: before cleaning", "message_count", originalLen)
-		for i, msg := range chatBody.Messages {
-			logger.Debug("cleanChatBody: before clean", "index", i, "role", msg.Role, "content_len", len(msg.Content), "has_content", msg.HasContent(), "tool_call_id", msg.ToolCallID)
-		}
+	if chatBody == nil || chatBody.Messages == nil {
+		return
+	}
+	originalLen := len(chatBody.Messages)
+	logger.Debug("cleanChatBody: before cleaning", "message_count", originalLen)
+	for i, msg := range chatBody.Messages {
+		logger.Debug("cleanChatBody: before clean", "index", i, "role", msg.Role, "content_len", len(msg.Content), "has_content", msg.HasContent(), "tool_call_id", msg.ToolCallID)
+	}
+	// TODO: consider case where we keep tool requests
+	// /completion msg where part meant for user and other part tool call
+	chatBody.Messages = cleanToolCalls(chatBody.Messages)
+	chatBody.Messages = cleanNullMessages(chatBody.Messages)
+	logger.Debug("cleanChatBody: after cleaning", "original_len", originalLen, "new_len", len(chatBody.Messages))
+	for i, msg := range chatBody.Messages {
+		logger.Debug("cleanChatBody: after clean", "index", i, "role", msg.Role, "content_len", len(msg.Content), "has_content", msg.HasContent(), "tool_call_id", msg.ToolCallID)
+	}
+}
 
-		chatBody.Messages = cleanNullMessages(chatBody.Messages)
+// convertJSONToMapStringString unmarshals JSON into map[string]interface{} and converts all values to strings.
+func convertJSONToMapStringString(jsonStr string) (map[string]string, error) {
+	var raw map[string]interface{}
+	if err := json.Unmarshal([]byte(jsonStr), &raw); err != nil {
+		return nil, err
+	}
+	result := make(map[string]string, len(raw))
+	for k, v := range raw {
+		switch val := v.(type) {
+		case string:
+			result[k] = val
+		case float64:
+			result[k] = strconv.FormatFloat(val, 'f', -1, 64)
+		case int, int64, int32:
+			// json.Unmarshal converts numbers to float64, but handle other integer types if they appear
+			result[k] = fmt.Sprintf("%v", val)
+		case bool:
+			result[k] = strconv.FormatBool(val)
+		case nil:
+			result[k] = ""
+		default:
+			result[k] = fmt.Sprintf("%v", val)
+		}
+	}
+	return result, nil
+}
 
-		logger.Debug("cleanChatBody: after cleaning", "original_len", originalLen, "new_len", len(chatBody.Messages))
-		for i, msg := range chatBody.Messages {
-			logger.Debug("cleanChatBody: after clean", "index", i, "role", msg.Role, "content_len", len(msg.Content), "has_content", msg.HasContent(), "tool_call_id", msg.ToolCallID)
+// unmarshalFuncCall unmarshals a JSON tool call, converting numeric arguments to strings.
+func unmarshalFuncCall(jsonStr string) (*models.FuncCall, error) {
+	type tempFuncCall struct {
+		ID   string                 `json:"id,omitempty"`
+		Name string                 `json:"name"`
+		Args map[string]interface{} `json:"args"`
+	}
+	var temp tempFuncCall
+	if err := json.Unmarshal([]byte(jsonStr), &temp); err != nil {
+		return nil, err
+	}
+	fc := &models.FuncCall{
+		ID:   temp.ID,
+		Name: temp.Name,
+		Args: make(map[string]string, len(temp.Args)),
+	}
+	for k, v := range temp.Args {
+		switch val := v.(type) {
+		case string:
+			fc.Args[k] = val
+		case float64:
+			fc.Args[k] = strconv.FormatFloat(val, 'f', -1, 64)
+		case int, int64, int32:
+			fc.Args[k] = fmt.Sprintf("%v", val)
+		case bool:
+			fc.Args[k] = strconv.FormatBool(val)
+		case nil:
+			fc.Args[k] = ""
+		default:
+			fc.Args[k] = fmt.Sprintf("%v", val)
 		}
 	}
+	return fc, nil
 }
 
 func findCall(msg, toolCall string, tv *tview.TextView) {
@@ -652,30 +811,28 @@ func findCall(msg, toolCall string, tv *tview.TextView) {
 	if toolCall != "" {
 		// HTML-decode the tool call string to handle encoded characters like &lt; -> <
 		decodedToolCall := html.UnescapeString(toolCall)
-		openAIToolMap := make(map[string]string)
-		// respect tool call
-		if err := json.Unmarshal([]byte(decodedToolCall), &openAIToolMap); err != nil {
+		openAIToolMap, err := convertJSONToMapStringString(decodedToolCall)
+		if err != nil {
 			logger.Error("failed to unmarshal openai tool call", "call", decodedToolCall, "error", err)
+			// Ensure lastToolCall.ID is set for the error response (already set from chunk)
 			// Send error response to LLM so it can retry or handle the error
 			toolResponseMsg := models.RoleMsg{
 				Role:       cfg.ToolRole,
 				Content:    fmt.Sprintf("Error processing tool call: %v. Please check the JSON format and try again.", err),
-				ToolCallID: lastToolCallID, // Use the stored tool call ID
+				ToolCallID: lastToolCall.ID, // Use the stored tool call ID
 			}
 			chatBody.Messages = append(chatBody.Messages, toolResponseMsg)
-			// Clear the stored tool call ID after using it
-			lastToolCallID = ""
+			// Clear the stored tool call ID after using it (no longer needed)
 			// Trigger the assistant to continue processing with the error message
 			chatRound("", cfg.AssistantRole, tv, false, false)
 			return
 		}
 		lastToolCall.Args = openAIToolMap
 		fc = lastToolCall
-		// Ensure lastToolCallID is set if it's available in the tool call
-		if lastToolCallID == "" && len(openAIToolMap) > 0 {
-			// Attempt to extract ID from the parsed tool call if not already set
+		// Set lastToolCall.ID from parsed tool call ID if available
+		if len(openAIToolMap) > 0 {
 			if id, exists := openAIToolMap["id"]; exists {
-				lastToolCallID = id
+				lastToolCall.ID = id
 			}
 		}
 	} else {
@@ -688,7 +845,9 @@ func findCall(msg, toolCall string, tv *tview.TextView) {
 		jsStr = strings.TrimSuffix(strings.TrimPrefix(jsStr, prefix), suffix)
 		// HTML-decode the JSON string to handle encoded characters like &lt; -> <
 		decodedJsStr := html.UnescapeString(jsStr)
-		if err := json.Unmarshal([]byte(decodedJsStr), &fc); err != nil {
+		var err error
+		fc, err = unmarshalFuncCall(decodedJsStr)
+		if err != nil {
 			logger.Error("failed to unmarshal tool call", "error", err, "json_string", decodedJsStr)
 			// Send error response to LLM so it can retry or handle the error
 			toolResponseMsg := models.RoleMsg{
@@ -701,28 +860,40 @@
 			chatRound("", cfg.AssistantRole, tv, false, false)
 			return
 		}
+		// Update lastToolCall with parsed function call
+		lastToolCall.ID = fc.ID
+		lastToolCall.Name = fc.Name
+		lastToolCall.Args = fc.Args
+	}
+	// we got here => last msg recognized as a tool call (correct or not)
+	// make sure it has ToolCallID
+	if chatBody.Messages[len(chatBody.Messages)-1].ToolCallID == "" {
+		chatBody.Messages[len(chatBody.Messages)-1].ToolCallID = randString(6)
+	}
+	// Ensure lastToolCall.ID is set, fallback to assistant message's ToolCallID
+	if lastToolCall.ID == "" {
+		lastToolCall.ID = chatBody.Messages[len(chatBody.Messages)-1].ToolCallID
 	}
 	// call a func
-	f, ok := fnMap[fc.Name]
+	_, ok := fnMap[fc.Name]
 	if !ok {
 		m := fc.Name + " is not implemented"
 		// Create tool response message with the proper tool_call_id
 		toolResponseMsg := models.RoleMsg{
 			Role:       cfg.ToolRole,
 			Content:    m,
-			ToolCallID: lastToolCallID, // Use the stored tool call ID
+			ToolCallID: lastToolCall.ID, // Use the stored tool call ID
 		}
 		chatBody.Messages = append(chatBody.Messages, toolResponseMsg)
 		logger.Debug("findCall: added tool not implemented response", "role", toolResponseMsg.Role, "content_len", len(toolResponseMsg.Content), "tool_call_id", toolResponseMsg.ToolCallID, "message_count_after_add", len(chatBody.Messages))
 		// Clear the stored tool call ID after using it
-		lastToolCallID = ""
-
+		lastToolCall.ID = ""
 		// Trigger the assistant to continue processing with the new tool response
 		// by calling chatRound with empty content to continue the assistant's response
 		chatRound("", cfg.AssistantRole, tv, false, false)
 		return
 	}
-	resp := f(fc.Args)
+	resp := callToolWithAgent(fc.Name, fc.Args)
 	toolMsg := string(resp) // Remove the "tool response: " prefix and %+v formatting
 	logger.Info("llm used tool call", "tool_resp", toolMsg, "tool_attrs", fc)
 	fmt.Fprintf(tv, "%s[-:-:b](%d) <%s>: [-:-:-]\n%s\n",
@@ -731,12 +902,12 @@
 	toolResponseMsg := models.RoleMsg{
 		Role:       cfg.ToolRole,
 		Content:    toolMsg,
-		ToolCallID: lastToolCallID, // Use the stored tool call ID
+		ToolCallID: lastToolCall.ID, // Use the stored tool call ID
 	}
 	chatBody.Messages = append(chatBody.Messages, toolResponseMsg)
 	logger.Debug("findCall: added actual tool response", "role", toolResponseMsg.Role, "content_len", len(toolResponseMsg.Content), "tool_call_id", toolResponseMsg.ToolCallID, "message_count_after_add", len(chatBody.Messages))
 	// Clear the stored tool call ID after using it
-	lastToolCallID = ""
+	lastToolCall.ID = ""
 	// Trigger the assistant to continue processing with the new tool response
 	// by calling chatRound with empty content to continue the assistant's response
 	chatRound("", cfg.AssistantRole, tv, false, false)
@@ -756,7 +927,7 @@ func chatToTextSlice(showSys bool) []string {
 
 func chatToText(showSys bool) string {
 	s := chatToTextSlice(showSys)
-	return strings.Join(s, "")
+	return strings.Join(s, "\n")
 }
 
 func removeThinking(chatBody *models.ChatBody) {
@@ -835,19 +1006,30 @@ func updateModelLists() {
 		}
 	}
 	// if llama.cpp started after gf-lt?
+	localModelsMu.Lock()
 	LocalModels, err = fetchLCPModels()
+	localModelsMu.Unlock()
 	if err != nil {
 		logger.Warn("failed to fetch llama.cpp models", "error", err)
 	}
 }
 
-func updateModelListsTicker() {
-	updateModelLists() // run on the start
-	ticker := time.NewTicker(time.Minute * 1)
-	for {
-		<-ticker.C
-		updateModelLists()
+func refreshLocalModelsIfEmpty() {
+	localModelsMu.RLock()
+	if len(LocalModels) > 0 {
+		localModelsMu.RUnlock()
+		return
+	}
+	localModelsMu.RUnlock()
+	// try to fetch
+	models, err := fetchLCPModels()
+	if err != nil {
+		logger.Warn("failed to fetch llama.cpp models", "error", err)
+		return
 	}
+	localModelsMu.Lock()
+	LocalModels = models
+	localModelsMu.Unlock()
 }
 
 func init() {
@@ -903,12 +1085,12 @@ func init() {
 		cluedoState = extra.CluedoPrepCards(playerOrder)
 	}
 	choseChunkParser()
-	httpClient = createClient(time.Second * 15)
+	httpClient = createClient(time.Second * 90)
 	if cfg.TTS_ENABLED {
 		orator = extra.NewOrator(logger, cfg)
 	}
	if cfg.STT_ENABLED {
 		asr = extra.NewSTT(logger, cfg)
 	}
-	go updateModelListsTicker()
+	go updateModelLists()
 }
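
Context for the new JSON helpers introduced above: when `encoding/json` unmarshals into `map[string]interface{}`, every JSON number arrives as `float64` (as the in-diff comment notes), which is why `convertJSONToMapStringString` and `unmarshalFuncCall` format numeric tool-call arguments back into strings explicitly. A minimal standalone sketch of that behaviour, not part of this commit; the payload and keys are invented for illustration:

```go
package main

import (
	"encoding/json"
	"fmt"
	"strconv"
)

func main() {
	// Hypothetical tool-call arguments; "page" is a JSON number, "strict" a bool.
	payload := `{"page": 3, "query": "cats", "strict": true}`
	var raw map[string]interface{}
	if err := json.Unmarshal([]byte(payload), &raw); err != nil {
		panic(err)
	}
	args := make(map[string]string, len(raw))
	for k, v := range raw {
		switch val := v.(type) {
		case float64:
			// 3 arrives as float64(3); precision -1 renders "3", not "3.000000"
			args[k] = strconv.FormatFloat(val, 'f', -1, 64)
		case bool:
			args[k] = strconv.FormatBool(val)
		case string:
			args[k] = val
		default:
			args[k] = fmt.Sprintf("%v", val)
		}
	}
	fmt.Println(args) // map[page:3 query:cats strict:true]
}
```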
