summaryrefslogtreecommitdiff
path: root/bot.go
diff options
context:
space:
mode:
authorGrail Finder <wohilas@gmail.com>2025-12-17 13:03:40 +0300
committerGrail Finder <wohilas@gmail.com>2025-12-17 13:03:40 +0300
commitd73c3abd6bda8690e8b5e57342221c8cb2cc88b3 (patch)
tree4715cec24efc6e39bf7eb46d13fc1cbf60ff2e29 /bot.go
parent35851647a191f779943530591610a9b22ffaeff9 (diff)
Feat: preload lcp model
Diffstat (limited to 'bot.go')
-rw-r--r--bot.go55
1 files changed, 54 insertions, 1 deletions
diff --git a/bot.go b/bot.go
index e2f03b8..8a0ba0a 100644
--- a/bot.go
+++ b/bot.go
@@ -16,6 +16,7 @@ import (
"log/slog"
"net"
"net/http"
+ "net/url"
"os"
"path"
"strings"
@@ -188,6 +189,58 @@ func createClient(connectTimeout time.Duration) *http.Client {
}
}
+// warmUpModel sends one minimal, non-streaming request to the currently
+// configured API so that a locally hosted backend has the model loaded before
+// the first real prompt arrives.
+//
+// It does nothing when cfg.CurrentAPI cannot be parsed, when its host is not
+// "localhost", "127.0.0.1" or "::1", or when the URL is neither the legacy
+// "/completion" endpoint nor an OpenAI-style "/v1/chat/completions" endpoint.
+// The HTTP call runs in a background goroutine and is best-effort: every
+// failure is only logged at debug level.
+func warmUpModel() {
+	// Snapshot the endpoint once so the URL that is validated below is the
+	// same one the goroutine posts to, even if cfg changes in the meantime.
+	api := cfg.CurrentAPI
+	u, err := url.Parse(api)
+	if err != nil {
+		return
+	}
+	switch u.Hostname() {
+	case "localhost", "127.0.0.1", "::1":
+		// Local backend: go on and warm it up.
+	default:
+		return
+	}
+	// The payload is built on the caller's goroutine so the background
+	// goroutine never reads cfg or chatBody.
+	var data []byte
+	switch {
+	case strings.HasSuffix(api, "/completion"):
+		// Old completion endpoint
+		req := models.NewLCPReq(".", chatBody.Model, nil, map[string]float32{
+			"temperature":    0.8,
+			"dry_multiplier": 0.0,
+			"min_p":          0.05,
+			"n_predict":      0,
+		}, []string{})
+		req.Stream = false
+		data, err = json.Marshal(req)
+	case strings.Contains(api, "/v1/chat/completions"):
+		// OpenAI-compatible chat endpoint
+		req := models.OpenAIReq{
+			ChatBody: &models.ChatBody{
+				Model: chatBody.Model,
+				Messages: []models.RoleMsg{
+					{Role: "system", Content: "."},
+				},
+				Stream: false,
+			},
+			Tools: nil,
+		}
+		data, err = json.Marshal(req)
+	default:
+		// Unknown local endpoint, skip
+		return
+	}
+	if err != nil {
+		logger.Debug("failed to marshal warmup request", "error", err)
+		return
+	}
+	go func() {
+		resp, err := httpClient.Post(api, "application/json", bytes.NewReader(data))
+		if err != nil {
+			logger.Debug("warmup request failed", "error", err)
+			return
+		}
+		defer resp.Body.Close()
+		// A rejected warm-up would otherwise look identical to a successful one.
+		if resp.StatusCode != http.StatusOK {
+			logger.Debug("warmup request returned non-OK status", "status", resp.Status)
+		}
+	}()
+}
+
func fetchLCPModelName() *models.LCPModels {
//nolint
resp, err := httpClient.Get(cfg.FetchModelNameAPI)
@@ -894,7 +947,7 @@ func init() {
cluedoState = extra.CluedoPrepCards(playerOrder)
}
choseChunkParser()
- httpClient = createClient(time.Second * 15)
+ httpClient = createClient(time.Second * 90)
if cfg.TTS_ENABLED {
orator = extra.NewOrator(logger, cfg)
}