path: root/bot.go
author    Grail Finder <wohilas@gmail.com>  2025-12-21 11:39:36 +0300
committer Grail Finder <wohilas@gmail.com>  2025-12-21 11:39:36 +0300
commit    75fde2a575697f8f46ee9676c0ed228e5315a4e5 (patch)
tree      64e02a6afef049eb2ca79a3a5d2b0beb8ba26385 /bot.go
parent    1ca75a00642c4e0a6eea3117e3b4ebaacfdcfa7a (diff)
parent    5525c946613a6f726cd116d79f1505a63ab25806 (diff)
Merge branch 'master' into doc/tutorial
Diffstat (limited to 'bot.go')
-rw-r--r--  bot.go  378
1 file changed, 280 insertions(+), 98 deletions(-)
diff --git a/bot.go b/bot.go
index 3242b88..8ddcee5 100644
--- a/bot.go
+++ b/bot.go
@@ -16,9 +16,12 @@ import (
"log/slog"
"net"
"net/http"
+ "net/url"
"os"
"path"
+ "strconv"
"strings"
+ "sync"
"time"
"github.com/neurosnap/sentences/english"
@@ -47,10 +50,10 @@ var (
ragger *rag.RAG
chunkParser ChunkParser
lastToolCall *models.FuncCall
- lastToolCallID string // Store the ID of the most recent tool call
//nolint:unused // TTS_ENABLED conditionally uses this
orator extra.Orator
asr extra.STT
+ localModelsMu sync.RWMutex
defaultLCPProps = map[string]float32{
"temperature": 0.8,
"dry_multiplier": 0.0,
@@ -84,19 +87,31 @@ func cleanNullMessages(messages []models.RoleMsg) []models.RoleMsg {
return consolidateConsecutiveAssistantMessages(messages)
}
+func cleanToolCalls(messages []models.RoleMsg) []models.RoleMsg {
+	cleaned := make([]models.RoleMsg, 0, len(messages))
+	for i, msg := range messages {
+		// drop earlier tool-call messages; a tool call in the last
+		// message must stay so the next request can answer it
+		if msg.ToolCallID == "" || i == len(messages)-1 {
+			cleaned = append(cleaned, msg)
+		}
+	}
+	return consolidateConsecutiveAssistantMessages(cleaned)
+}
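A minimal sketch of the intended behavior, with hypothetical messages (not from this repo's tests): intermediate tool-call messages are dropped, the trailing one survives.

	// Hypothetical input: ToolCallID is what marks a message as a tool call.
	msgs := []models.RoleMsg{
		{Role: "user", Content: "list files"},
		{Role: "tool", Content: "ls output", ToolCallID: "call_1"}, // dropped
		{Role: "assistant", Content: "done"},
		{Role: "tool", Content: "pending", ToolCallID: "call_2"}, // last msg: kept
	}
	cleaned := cleanToolCalls(msgs) // 3 messages remain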
+
// consolidateConsecutiveAssistantMessages merges consecutive assistant messages into a single message
func consolidateConsecutiveAssistantMessages(messages []models.RoleMsg) []models.RoleMsg {
if len(messages) == 0 {
return messages
}
-
consolidated := make([]models.RoleMsg, 0, len(messages))
currentAssistantMsg := models.RoleMsg{}
isBuildingAssistantMsg := false
-
for i := 0; i < len(messages); i++ {
msg := messages[i]
-
if msg.Role == cfg.AssistantRole || msg.Role == cfg.WriteNextMsgAsCompletionAgent {
// If this is an assistant message, start or continue building
if !isBuildingAssistantMsg {
@@ -141,12 +156,10 @@ func consolidateConsecutiveAssistantMessages(messages []models.RoleMsg) []models
consolidated = append(consolidated, msg)
}
}
-
// Don't forget the last assistant message if we were building one
if isBuildingAssistantMsg {
consolidated = append(consolidated, currentAssistantMsg)
}
-
return consolidated
}
@@ -188,6 +201,72 @@ func createClient(connectTimeout time.Duration) *http.Client {
}
}
+func warmUpModel() {
+ u, err := url.Parse(cfg.CurrentAPI)
+ if err != nil {
+ return
+ }
+ host := u.Hostname()
+ if host != "localhost" && host != "127.0.0.1" && host != "::1" {
+ return
+ }
+ // Check if model is already loaded
+ loaded, err := isModelLoaded(chatBody.Model)
+ if err != nil {
+ logger.Debug("failed to check model status", "model", chatBody.Model, "error", err)
+ // Continue with warmup attempt anyway
+ }
+ if loaded {
+ if err := notifyUser("model already loaded", "Model "+chatBody.Model+" is already loaded."); err != nil {
+ logger.Debug("failed to notify user", "error", err)
+ }
+ return
+ }
+ go func() {
+ var data []byte
+ var err error
+ if strings.HasSuffix(cfg.CurrentAPI, "/completion") {
+ // Old completion endpoint
+ req := models.NewLCPReq(".", chatBody.Model, nil, map[string]float32{
+ "temperature": 0.8,
+ "dry_multiplier": 0.0,
+ "min_p": 0.05,
+ "n_predict": 0,
+ }, []string{})
+ req.Stream = false
+ data, err = json.Marshal(req)
+ } else if strings.Contains(cfg.CurrentAPI, "/v1/chat/completions") {
+ // OpenAI-compatible chat endpoint
+ req := models.OpenAIReq{
+ ChatBody: &models.ChatBody{
+ Model: chatBody.Model,
+ Messages: []models.RoleMsg{
+ {Role: "system", Content: "."},
+ },
+ Stream: false,
+ },
+ Tools: nil,
+ }
+ data, err = json.Marshal(req)
+ } else {
+ // Unknown local endpoint, skip
+ return
+ }
+ if err != nil {
+ logger.Debug("failed to marshal warmup request", "error", err)
+ return
+ }
+ resp, err := httpClient.Post(cfg.CurrentAPI, "application/json", bytes.NewReader(data))
+ if err != nil {
+ logger.Debug("warmup request failed", "error", err)
+ return
+ }
+ io.Copy(io.Discard, resp.Body) //nolint:errcheck // drain so the connection can be reused
+ resp.Body.Close()
+ // Start monitoring for model load completion
+ monitorModelLoad(chatBody.Model)
+ }()
+}
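warmUpModel is fire-and-forget; a plausible call site (hypothetical, not part of this diff) is right after the user switches models, so the first real request is not stalled by a cold load:

	chatBody.Model = selectedModel // 'selectedModel' assumed to come from the model picker
	cfg.CurrentModel = selectedModel
	warmUpModel() // returns immediately; load progress is reported via notifyUser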
+
func fetchLCPModelName() *models.LCPModels {
//nolint
resp, err := httpClient.Get(cfg.FetchModelNameAPI)
@@ -210,6 +289,7 @@ func fetchLCPModelName() *models.LCPModels {
return nil
}
chatBody.Model = path.Base(llmModel.Data[0].ID)
+ cfg.CurrentModel = chatBody.Model
return &llmModel
}
@@ -274,64 +354,82 @@ func fetchLCPModels() ([]string, error) {
return localModels, nil
}
-func sendMsgToLLM(body io.Reader) {
- choseChunkParser()
-
- var req *http.Request
- var err error
+// fetchLCPModelsWithStatus returns the full LCPModels struct including status information.
+func fetchLCPModelsWithStatus() (*models.LCPModels, error) {
+ resp, err := http.Get(cfg.FetchModelNameAPI)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != 200 {
+ err := fmt.Errorf("failed to fetch llama.cpp models; status: %s", resp.Status)
+ return nil, err
+ }
+ data := &models.LCPModels{}
+ if err := json.NewDecoder(resp.Body).Decode(data); err != nil {
+ return nil, err
+ }
+ return data, nil
+}
- // Capture and log the request body for debugging
- if _, ok := body.(*io.LimitedReader); ok {
- // If it's a LimitedReader, we need to handle it differently
- logger.Debug("request body type is LimitedReader", "parser", chunkParser, "link", cfg.CurrentAPI)
- req, err = http.NewRequest("POST", cfg.CurrentAPI, body)
- if err != nil {
- logger.Error("newreq error", "error", err)
- if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil {
- logger.Error("failed to notify", "error", err)
- }
- streamDone <- true
- return
+// isModelLoaded checks if the given model ID is currently loaded in llama.cpp server.
+func isModelLoaded(modelID string) (bool, error) {
+ list, err := fetchLCPModelsWithStatus() // 'list' avoids shadowing the models package
+ if err != nil {
+ return false, err
+ }
+ for _, m := range list.Data {
+ if m.ID == modelID {
+ return m.Status.Value == "loaded", nil
}
- req.Header.Add("Accept", "application/json")
- req.Header.Add("Content-Type", "application/json")
- req.Header.Add("Authorization", "Bearer "+chunkParser.GetToken())
- req.Header.Set("Accept-Encoding", "gzip")
- } else {
- // For other reader types, capture and log the body content
- bodyBytes, err := io.ReadAll(body)
- if err != nil {
- logger.Error("failed to read request body for logging", "error", err)
- // Create request with original body if reading fails
- req, err = http.NewRequest("POST", cfg.CurrentAPI, bytes.NewReader(bodyBytes))
- if err != nil {
- logger.Error("newreq error", "error", err)
- if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil {
- logger.Error("failed to notify", "error", err)
- }
- streamDone <- true
+ }
+ return false, nil
+}
+
+// monitorModelLoad starts a goroutine that periodically checks if the specified model is loaded.
+func monitorModelLoad(modelID string) {
+ go func() {
+ timeout := time.After(2 * time.Minute) // max wait 2 minutes
+ ticker := time.NewTicker(2 * time.Second)
+ defer ticker.Stop()
+ for {
+ select {
+ case <-timeout:
+ logger.Debug("model load monitoring timeout", "model", modelID)
return
- }
- } else {
- // Log the request body for debugging
- logger.Debug("sending request to API", "api", cfg.CurrentAPI, "body", string(bodyBytes))
- // Create request with the captured body
- req, err = http.NewRequest("POST", cfg.CurrentAPI, bytes.NewReader(bodyBytes))
- if err != nil {
- logger.Error("newreq error", "error", err)
- if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil {
- logger.Error("failed to notify", "error", err)
+ case <-ticker.C:
+ loaded, err := isModelLoaded(modelID)
+ if err != nil {
+ logger.Debug("failed to check model status", "model", modelID, "error", err)
+ continue
+ }
+ if loaded {
+ if err := notifyUser("model loaded", "Model "+modelID+" is now loaded and ready."); err != nil {
+ logger.Debug("failed to notify user", "error", err)
+ }
+ return
}
- streamDone <- true
- return
}
}
+ }()
+}
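Usage sketch: monitorModelLoad encapsulates the polling, but isModelLoaded can also gate a one-off check (the model ID here is illustrative):

	loaded, err := isModelLoaded("qwen2.5-7b-instruct")
	if err != nil {
		logger.Debug("status check failed", "error", err)
	} else if !loaded {
		monitorModelLoad("qwen2.5-7b-instruct") // notifies once status flips to "loaded"
	}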
- req.Header.Add("Accept", "application/json")
- req.Header.Add("Content-Type", "application/json")
- req.Header.Add("Authorization", "Bearer "+chunkParser.GetToken())
- req.Header.Set("Accept-Encoding", "gzip")
+// sendMsgToLLM expects a streaming response
+func sendMsgToLLM(body io.Reader) {
+ choseChunkParser()
+ req, err := http.NewRequest("POST", cfg.CurrentAPI, body)
+ if err != nil {
+ logger.Error("newreq error", "error", err)
+ if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil {
+ logger.Error("failed to notify", "error", err)
+ }
+ streamDone <- true
+ return
}
+ req.Header.Add("Accept", "application/json")
+ req.Header.Add("Content-Type", "application/json")
+ req.Header.Add("Authorization", "Bearer "+chunkParser.GetToken())
+ req.Header.Set("Accept-Encoding", "gzip")
// nolint
resp, err := httpClient.Do(req)
if err != nil {
@@ -396,6 +494,7 @@ func sendMsgToLLM(body io.Reader) {
streamDone <- true
break
}
+ // problem: this check would catch any normal message that merely mentions the word 'error'
// Handle error messages in response content
// example needed, since llm could use the word error in the normal msg
// if string(line) != "" && strings.Contains(strings.ToLower(string(line)), "error") {
@@ -422,7 +521,7 @@ func sendMsgToLLM(body io.Reader) {
if chunk.FuncName != "" {
lastToolCall.Name = chunk.FuncName
// Store the tool call ID for the response
- lastToolCallID = chunk.ToolID
+ lastToolCall.ID = chunk.ToolID
}
interrupt:
if interruptResp { // read bytes, so it would not get into beginning of the next req
@@ -604,20 +703,16 @@ out:
Role: botPersona, Content: respText.String(),
})
}
-
logger.Debug("chatRound: before cleanChatBody", "messages_before_clean", len(chatBody.Messages))
for i, msg := range chatBody.Messages {
logger.Debug("chatRound: before cleaning", "index", i, "role", msg.Role, "content_len", len(msg.Content), "has_content", msg.HasContent(), "tool_call_id", msg.ToolCallID)
}
-
// // Clean null/empty messages to prevent API issues with endpoints like llama.cpp jinja template
cleanChatBody()
-
logger.Debug("chatRound: after cleanChatBody", "messages_after_clean", len(chatBody.Messages))
for i, msg := range chatBody.Messages {
logger.Debug("chatRound: after cleaning", "index", i, "role", msg.Role, "content_len", len(msg.Content), "has_content", msg.HasContent(), "tool_call_id", msg.ToolCallID)
}
-
colorText()
updateStatusLine()
// bot msg is done;
@@ -631,20 +726,84 @@ out:
// cleanChatBody removes messages with null or empty content to prevent API issues
func cleanChatBody() {
- if chatBody != nil && chatBody.Messages != nil {
- originalLen := len(chatBody.Messages)
- logger.Debug("cleanChatBody: before cleaning", "message_count", originalLen)
- for i, msg := range chatBody.Messages {
- logger.Debug("cleanChatBody: before clean", "index", i, "role", msg.Role, "content_len", len(msg.Content), "has_content", msg.HasContent(), "tool_call_id", msg.ToolCallID)
- }
+ if chatBody == nil || chatBody.Messages == nil {
+ return
+ }
+ originalLen := len(chatBody.Messages)
+ logger.Debug("cleanChatBody: before cleaning", "message_count", originalLen)
+ for i, msg := range chatBody.Messages {
+ logger.Debug("cleanChatBody: before clean", "index", i, "role", msg.Role, "content_len", len(msg.Content), "has_content", msg.HasContent(), "tool_call_id", msg.ToolCallID)
+ }
+ // TODO: consider the case where we should keep tool requests:
+ // a /completion msg where one part is meant for the user and another is a tool call
+ chatBody.Messages = cleanToolCalls(chatBody.Messages)
+ chatBody.Messages = cleanNullMessages(chatBody.Messages)
+ logger.Debug("cleanChatBody: after cleaning", "original_len", originalLen, "new_len", len(chatBody.Messages))
+ for i, msg := range chatBody.Messages {
+ logger.Debug("cleanChatBody: after clean", "index", i, "role", msg.Role, "content_len", len(msg.Content), "has_content", msg.HasContent(), "tool_call_id", msg.ToolCallID)
+ }
+}
- chatBody.Messages = cleanNullMessages(chatBody.Messages)
+// convertJSONToMapStringString unmarshals JSON into map[string]interface{} and converts all values to strings.
+func convertJSONToMapStringString(jsonStr string) (map[string]string, error) {
+ var raw map[string]interface{}
+ if err := json.Unmarshal([]byte(jsonStr), &raw); err != nil {
+ return nil, err
+ }
+ result := make(map[string]string, len(raw))
+ for k, v := range raw {
+ switch val := v.(type) {
+ case string:
+ result[k] = val
+ case float64:
+ result[k] = strconv.FormatFloat(val, 'f', -1, 64)
+ case int, int64, int32:
+ // json.Unmarshal converts numbers to float64, but handle other integer types if they appear
+ result[k] = fmt.Sprintf("%v", val)
+ case bool:
+ result[k] = strconv.FormatBool(val)
+ case nil:
+ result[k] = ""
+ default:
+ result[k] = fmt.Sprintf("%v", val)
+ }
+ }
+ return result, nil
+}
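For example, a tool call with mixed argument types flattens to a uniform string map (values illustrative):

	args, err := convertJSONToMapStringString(`{"path":"/tmp","limit":5,"recursive":true,"filter":null}`)
	// err == nil; args == map[string]string{"path": "/tmp", "limit": "5", "recursive": "true", "filter": ""}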
- logger.Debug("cleanChatBody: after cleaning", "original_len", originalLen, "new_len", len(chatBody.Messages))
- for i, msg := range chatBody.Messages {
- logger.Debug("cleanChatBody: after clean", "index", i, "role", msg.Role, "content_len", len(msg.Content), "has_content", msg.HasContent(), "tool_call_id", msg.ToolCallID)
+// unmarshalFuncCall unmarshals a JSON tool call, converting numeric arguments to strings.
+func unmarshalFuncCall(jsonStr string) (*models.FuncCall, error) {
+ type tempFuncCall struct {
+ ID string `json:"id,omitempty"`
+ Name string `json:"name"`
+ Args map[string]interface{} `json:"args"`
+ }
+ var temp tempFuncCall
+ if err := json.Unmarshal([]byte(jsonStr), &temp); err != nil {
+ return nil, err
+ }
+ fc := &models.FuncCall{
+ ID: temp.ID,
+ Name: temp.Name,
+ Args: make(map[string]string, len(temp.Args)),
+ }
+ for k, v := range temp.Args {
+ switch val := v.(type) {
+ case string:
+ fc.Args[k] = val
+ case float64:
+ fc.Args[k] = strconv.FormatFloat(val, 'f', -1, 64)
+ case int, int64, int32:
+ fc.Args[k] = fmt.Sprintf("%v", val)
+ case bool:
+ fc.Args[k] = strconv.FormatBool(val)
+ case nil:
+ fc.Args[k] = ""
+ default:
+ fc.Args[k] = fmt.Sprintf("%v", val)
}
}
+ return fc, nil
}
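A sketch of the shape unmarshalFuncCall accepts; numeric args arrive as float64 from encoding/json and are stringified:

	fc, err := unmarshalFuncCall(`{"id":"call_1","name":"read_file","args":{"path":"bot.go","lines":42}}`)
	// err == nil; fc.Name == "read_file"; fc.Args["lines"] == "42"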
func findCall(msg, toolCall string, tv *tview.TextView) {
@@ -652,30 +811,28 @@ func findCall(msg, toolCall string, tv *tview.TextView) {
if toolCall != "" {
// HTML-decode the tool call string to handle encoded characters like &lt; -> <
decodedToolCall := html.UnescapeString(toolCall)
- openAIToolMap := make(map[string]string)
- // respect tool call
- if err := json.Unmarshal([]byte(decodedToolCall), &openAIToolMap); err != nil {
+ openAIToolMap, err := convertJSONToMapStringString(decodedToolCall)
+ if err != nil {
logger.Error("failed to unmarshal openai tool call", "call", decodedToolCall, "error", err)
+ // Ensure lastToolCall.ID is set for the error response (already set from chunk)
// Send error response to LLM so it can retry or handle the error
toolResponseMsg := models.RoleMsg{
Role: cfg.ToolRole,
Content: fmt.Sprintf("Error processing tool call: %v. Please check the JSON format and try again.", err),
- ToolCallID: lastToolCallID, // Use the stored tool call ID
+ ToolCallID: lastToolCall.ID, // Use the stored tool call ID
}
chatBody.Messages = append(chatBody.Messages, toolResponseMsg)
- // Clear the stored tool call ID after using it
- lastToolCallID = ""
+ // Clear the stored tool call ID after using it (no longer needed)
// Trigger the assistant to continue processing with the error message
chatRound("", cfg.AssistantRole, tv, false, false)
return
}
lastToolCall.Args = openAIToolMap
fc = lastToolCall
- // Ensure lastToolCallID is set if it's available in the tool call
- if lastToolCallID == "" && len(openAIToolMap) > 0 {
- // Attempt to extract ID from the parsed tool call if not already set
+ // Set lastToolCall.ID from parsed tool call ID if available
+ if len(openAIToolMap) > 0 {
if id, exists := openAIToolMap["id"]; exists {
- lastToolCallID = id
+ lastToolCall.ID = id
}
}
} else {
@@ -688,7 +845,9 @@ func findCall(msg, toolCall string, tv *tview.TextView) {
jsStr = strings.TrimSuffix(strings.TrimPrefix(jsStr, prefix), suffix)
// HTML-decode the JSON string to handle encoded characters like &lt; -> <
decodedJsStr := html.UnescapeString(jsStr)
- if err := json.Unmarshal([]byte(decodedJsStr), &fc); err != nil {
+ var err error
+ fc, err = unmarshalFuncCall(decodedJsStr)
+ if err != nil {
logger.Error("failed to unmarshal tool call", "error", err, "json_string", decodedJsStr)
// Send error response to LLM so it can retry or handle the error
toolResponseMsg := models.RoleMsg{
@@ -701,28 +860,40 @@ func findCall(msg, toolCall string, tv *tview.TextView) {
chatRound("", cfg.AssistantRole, tv, false, false)
return
}
+ // Update lastToolCall with parsed function call
+ lastToolCall.ID = fc.ID
+ lastToolCall.Name = fc.Name
+ lastToolCall.Args = fc.Args
+ }
+ // we got here => last msg recognized as a tool call (correct or not)
+ // make sure it has ToolCallID
+ if chatBody.Messages[len(chatBody.Messages)-1].ToolCallID == "" {
+ chatBody.Messages[len(chatBody.Messages)-1].ToolCallID = randString(6)
+ }
+ // Ensure lastToolCall.ID is set, fallback to assistant message's ToolCallID
+ if lastToolCall.ID == "" {
+ lastToolCall.ID = chatBody.Messages[len(chatBody.Messages)-1].ToolCallID
}
// call a func
- f, ok := fnMap[fc.Name]
+ _, ok := fnMap[fc.Name]
if !ok {
m := fc.Name + " is not implemented"
// Create tool response message with the proper tool_call_id
toolResponseMsg := models.RoleMsg{
Role: cfg.ToolRole,
Content: m,
- ToolCallID: lastToolCallID, // Use the stored tool call ID
+ ToolCallID: lastToolCall.ID, // Use the stored tool call ID
}
chatBody.Messages = append(chatBody.Messages, toolResponseMsg)
logger.Debug("findCall: added tool not implemented response", "role", toolResponseMsg.Role, "content_len", len(toolResponseMsg.Content), "tool_call_id", toolResponseMsg.ToolCallID, "message_count_after_add", len(chatBody.Messages))
// Clear the stored tool call ID after using it
- lastToolCallID = ""
-
+ lastToolCall.ID = ""
// Trigger the assistant to continue processing with the new tool response
// by calling chatRound with empty content to continue the assistant's response
chatRound("", cfg.AssistantRole, tv, false, false)
return
}
- resp := f(fc.Args)
+ resp := callToolWithAgent(fc.Name, fc.Args)
toolMsg := string(resp) // Remove the "tool response: " prefix and %+v formatting
logger.Info("llm used tool call", "tool_resp", toolMsg, "tool_attrs", fc)
fmt.Fprintf(tv, "%s[-:-:b](%d) <%s>: [-:-:-]\n%s\n",
@@ -731,12 +902,12 @@ func findCall(msg, toolCall string, tv *tview.TextView) {
toolResponseMsg := models.RoleMsg{
Role: cfg.ToolRole,
Content: toolMsg,
- ToolCallID: lastToolCallID, // Use the stored tool call ID
+ ToolCallID: lastToolCall.ID, // Use the stored tool call ID
}
chatBody.Messages = append(chatBody.Messages, toolResponseMsg)
logger.Debug("findCall: added actual tool response", "role", toolResponseMsg.Role, "content_len", len(toolResponseMsg.Content), "tool_call_id", toolResponseMsg.ToolCallID, "message_count_after_add", len(chatBody.Messages))
// Clear the stored tool call ID after using it
- lastToolCallID = ""
+ lastToolCall.ID = ""
// Trigger the assistant to continue processing with the new tool response
// by calling chatRound with empty content to continue the assistant's response
chatRound("", cfg.AssistantRole, tv, false, false)
@@ -756,7 +927,7 @@ func chatToTextSlice(showSys bool) []string {
func chatToText(showSys bool) string {
s := chatToTextSlice(showSys)
- return strings.Join(s, "")
+ return strings.Join(s, "\n")
}
func removeThinking(chatBody *models.ChatBody) {
@@ -835,19 +1006,30 @@ func updateModelLists() {
}
}
// if llama.cpp started after gf-lt?
+ localModelsMu.Lock()
LocalModels, err = fetchLCPModels()
+ localModelsMu.Unlock()
if err != nil {
logger.Warn("failed to fetch llama.cpp models", "error", err)
}
}
-func updateModelListsTicker() {
- updateModelLists() // run on the start
- ticker := time.NewTicker(time.Minute * 1)
- for {
- <-ticker.C
- updateModelLists()
+func refreshLocalModelsIfEmpty() {
+ localModelsMu.RLock()
+ if len(LocalModels) > 0 {
+ localModelsMu.RUnlock()
+ return
+ }
+ localModelsMu.RUnlock()
+ // try to fetch
+ fetched, err := fetchLCPModels() // 'fetched' avoids shadowing the models package
+ if err != nil {
+ logger.Warn("failed to fetch llama.cpp models", "error", err)
+ return
}
+ localModelsMu.Lock()
+ LocalModels = fetched
+ localModelsMu.Unlock()
}
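refreshLocalModelsIfEmpty replaces the old one-minute ticker with lazy backfill; a plausible call site (hypothetical) before rendering the model list:

	refreshLocalModelsIfEmpty() // two concurrent callers may both fetch; the second write is harmless
	localModelsMu.RLock()
	items := append([]string(nil), LocalModels...) // copy under the read lock
	localModelsMu.RUnlock()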
func init() {
@@ -903,12 +1085,12 @@ func init() {
cluedoState = extra.CluedoPrepCards(playerOrder)
}
choseChunkParser()
- httpClient = createClient(time.Second * 15)
+ httpClient = createClient(time.Second * 90)
if cfg.TTS_ENABLED {
orator = extra.NewOrator(logger, cfg)
}
if cfg.STT_ENABLED {
asr = extra.NewSTT(logger, cfg)
}
- go updateModelListsTicker()
+ go updateModelLists()
}