summaryrefslogtreecommitdiff
path: root/bot.go
diff options
context:
space:
mode:
Diffstat (limited to 'bot.go')
-rw-r--r--bot.go74
1 files changed, 74 insertions, 0 deletions
diff --git a/bot.go b/bot.go
index 779278e..1603e0d 100644
--- a/bot.go
+++ b/bot.go
@@ -199,6 +199,18 @@ func warmUpModel() {
if host != "localhost" && host != "127.0.0.1" && host != "::1" {
return
}
+ // Check if model is already loaded
+ loaded, err := isModelLoaded(chatBody.Model)
+ if err != nil {
+ logger.Debug("failed to check model status", "model", chatBody.Model, "error", err)
+ // Continue with warmup attempt anyway
+ }
+ if loaded {
+ if err := notifyUser("model already loaded", "Model "+chatBody.Model+" is already loaded."); err != nil {
+ logger.Debug("failed to notify user", "error", err)
+ }
+ return
+ }
go func() {
var data []byte
var err error
@@ -239,6 +251,8 @@ func warmUpModel() {
return
}
resp.Body.Close()
+ // Start monitoring for model load completion
+ monitorModelLoad(chatBody.Model)
}()
}
@@ -329,6 +343,66 @@ func fetchLCPModels() ([]string, error) {
return localModels, nil
}
+// fetchLCPModelsWithStatus returns the full LCPModels struct including status information.
+func fetchLCPModelsWithStatus() (*models.LCPModels, error) {
+ resp, err := http.Get(cfg.FetchModelNameAPI)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != 200 {
+ err := fmt.Errorf("failed to fetch llama.cpp models; status: %s", resp.Status)
+ return nil, err
+ }
+ data := &models.LCPModels{}
+ if err := json.NewDecoder(resp.Body).Decode(data); err != nil {
+ return nil, err
+ }
+ return data, nil
+}
+
+// isModelLoaded checks if the given model ID is currently loaded in llama.cpp server.
+func isModelLoaded(modelID string) (bool, error) {
+ models, err := fetchLCPModelsWithStatus()
+ if err != nil {
+ return false, err
+ }
+ for _, m := range models.Data {
+ if m.ID == modelID {
+ return m.Status.Value == "loaded", nil
+ }
+ }
+ return false, nil
+}
+
+// monitorModelLoad starts a goroutine that periodically checks if the specified model is loaded.
+func monitorModelLoad(modelID string) {
+ go func() {
+ timeout := time.After(2 * time.Minute) // max wait 2 minutes
+ ticker := time.NewTicker(2 * time.Second)
+ defer ticker.Stop()
+ for {
+ select {
+ case <-timeout:
+ logger.Debug("model load monitoring timeout", "model", modelID)
+ return
+ case <-ticker.C:
+ loaded, err := isModelLoaded(modelID)
+ if err != nil {
+ logger.Debug("failed to check model status", "model", modelID, "error", err)
+ continue
+ }
+ if loaded {
+ if err := notifyUser("model loaded", "Model "+modelID+" is now loaded and ready."); err != nil {
+ logger.Debug("failed to notify user", "error", err)
+ }
+ return
+ }
+ }
+ }
+ }()
+}
+
// sendMsgToLLM expects streaming resp
func sendMsgToLLM(body io.Reader) {
choseChunkParser()