Diffstat (limited to 'bot.go')
-rw-r--r-- | bot.go | 93
1 file changed, 86 insertions, 7 deletions
@@ -2,6 +2,8 @@ package main
 
 import (
 	"bufio"
+	"bytes"
+	"context"
 	"elefant/config"
 	"elefant/models"
 	"elefant/rag"
@@ -10,6 +12,7 @@ import (
 	"fmt"
 	"io"
 	"log/slog"
+	"net"
 	"net/http"
 	"os"
 	"path"
@@ -20,7 +23,30 @@ import (
 	"github.com/rivo/tview"
 )
 
-var httpClient = http.Client{}
+var httpClient = &http.Client{}
+
+func createClient(connectTimeout time.Duration) *http.Client {
+	// Custom transport with connection timeout
+	transport := &http.Transport{
+		DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+			// Create a dialer with connection timeout
+			dialer := &net.Dialer{
+				Timeout:   connectTimeout,
+				KeepAlive: 30 * time.Second, // Optional
+			}
+			return dialer.DialContext(ctx, network, addr)
+		},
+		// Other transport settings (optional)
+		TLSHandshakeTimeout:   connectTimeout,
+		ResponseHeaderTimeout: connectTimeout,
+	}
+
+	// Client with no overall timeout (or set to streaming-safe duration)
+	return &http.Client{
+		Transport: transport,
+		Timeout:   0, // No overall timeout (for streaming)
+	}
+}
 
 var (
 	cfg *config.Config
@@ -36,7 +62,6 @@ var (
 	defaultStarterBytes = []byte{}
 	interruptResp       = false
 	ragger              *rag.RAG
-	currentModel        = "none"
 	chunkParser         ChunkParser
 	defaultLCPProps     = map[string]float32{
 		"temperature": 0.8,
@@ -47,6 +72,7 @@ var (
 )
 
 func fetchModelName() *models.LLMModels {
+	// TODO: to config
 	api := "http://localhost:8080/v1/models"
 	//nolint
 	resp, err := httpClient.Get(api)
@@ -61,16 +87,44 @@ func fetchModelName() *models.LLMModels {
 		return nil
 	}
 	if resp.StatusCode != 200 {
-		currentModel = "disconnected"
+		chatBody.Model = "disconnected"
 		return nil
 	}
-	currentModel = path.Base(llmModel.Data[0].ID)
+	chatBody.Model = path.Base(llmModel.Data[0].ID)
 	return &llmModel
 }
 
+func fetchDSBalance() *models.DSBalance {
+	url := "https://api.deepseek.com/user/balance"
+	method := "GET"
+	req, err := http.NewRequest(method, url, nil)
+	if err != nil {
+		logger.Warn("failed to create request", "error", err)
+		return nil
+	}
+	req.Header.Add("Accept", "application/json")
+	req.Header.Add("Authorization", "Bearer "+cfg.DeepSeekToken)
+	res, err := httpClient.Do(req)
+	if err != nil {
+		logger.Warn("failed to make request", "error", err)
+		return nil
+	}
+	defer res.Body.Close()
+	resp := models.DSBalance{}
+	if err := json.NewDecoder(res.Body).Decode(&resp); err != nil {
+		return nil
+	}
+	return &resp
+}
+
 func sendMsgToLLM(body io.Reader) {
+	choseChunkParser()
+	req, err := http.NewRequest("POST", cfg.CurrentAPI, body)
+	req.Header.Add("Accept", "application/json")
+	req.Header.Add("Content-Type", "application/json")
+	req.Header.Add("Authorization", "Bearer "+cfg.DeepSeekToken)
 	// nolint
-	resp, err := httpClient.Post(cfg.CurrentAPI, "application/json", body)
+	// resp, err := httpClient.Post(cfg.CurrentAPI, "application/json", body)
 	if err != nil {
 		logger.Error("llamacpp api", "error", err)
 		if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil {
@@ -79,6 +133,16 @@ func sendMsgToLLM(body io.Reader) {
 		streamDone <- true
 		return
 	}
+	resp, err := httpClient.Do(req)
+	if err != nil {
+		bodyBytes, _ := io.ReadAll(body)
+		logger.Error("llamacpp api", "error", err, "body", string(bodyBytes))
+		if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil {
+			logger.Error("failed to notify", "error", err)
+		}
+		streamDone <- true
+		return
+	}
 	defer resp.Body.Close()
 	reader := bufio.NewReader(resp.Body)
 	counter := uint32(0)
@@ -113,6 +177,10 @@ func sendMsgToLLM(body io.Reader) {
 		// starts with -> data:
 		line = line[6:]
 		logger.Debug("debugging resp", "line", string(line))
+		if bytes.Equal(line, []byte("[DONE]\n")) {
+			streamDone <- true
+			break
+		}
 		content, stop, err = chunkParser.ParseChunk(line)
 		if err != nil {
 			logger.Error("error parsing response body", "error", err, "line", string(line), "url", cfg.CurrentAPI)
@@ -185,7 +253,17 @@ func roleToIcon(role string) string {
 
 func chatRound(userMsg, role string, tv *tview.TextView, regen, resume bool) {
 	botRespMode = true
-	// reader := formMsg(chatBody, userMsg, role)
+	defer func() { botRespMode = false }()
+	// check that there is a model set to use if is not local
+	if cfg.CurrentAPI == cfg.DeepSeekChatAPI || cfg.CurrentAPI == cfg.DeepSeekCompletionAPI {
+		if chatBody.Model != "deepseek-chat" && chatBody.Model != "deepseek-reasoner" {
+			if err := notifyUser("bad request", "wrong deepseek model name"); err != nil {
+				logger.Warn("failed ot notify user", "error", err)
+				return
+			}
+			return
+		}
+	}
 	reader, err := chunkParser.FormMsg(userMsg, role, resume)
 	if reader == nil || err != nil {
 		logger.Error("empty reader from msgs", "role", role, "error", err)
@@ -369,7 +447,8 @@ func init() {
 		Stream:   true,
 		Messages: lastChat,
 	}
-	initChunkParser()
+	choseChunkParser()
+	httpClient = createClient(time.Second * 15)
 	// go runModelNameTicker(time.Second * 120)
 	// tempLoad()
 }
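
Note: the central change in this commit is replacing the zero-value http.Client with one built by createClient, which puts a 15-second budget on dialing, the TLS handshake, and waiting for response headers while leaving Client.Timeout at zero so a long-running streamed response is never cut off mid-body. The following is a minimal standalone sketch of that pattern under stated assumptions: the function name newStreamingClient, the localhost URL, the JSON payload, and the line-by-line printing are illustrative and not code from this repository.

// Sketch: per-phase timeouts on the Transport, no overall Client.Timeout,
// so a streamed response body can take as long as it needs.
package main

import (
	"bufio"
	"context"
	"fmt"
	"log"
	"net"
	"net/http"
	"strings"
	"time"
)

func newStreamingClient(connectTimeout time.Duration) *http.Client {
	transport := &http.Transport{
		DialContext: (&net.Dialer{
			Timeout:   connectTimeout, // fail fast when the server is unreachable
			KeepAlive: 30 * time.Second,
		}).DialContext,
		TLSHandshakeTimeout:   connectTimeout,
		ResponseHeaderTimeout: connectTimeout, // applies to headers only, not the body
	}
	return &http.Client{
		Transport: transport,
		Timeout:   0, // zero: the streamed body may keep arriving for minutes
	}
}

func main() {
	client := newStreamingClient(15 * time.Second)

	// Illustrative endpoint and payload, not taken from the repository.
	payload := strings.NewReader(`{"stream": true}`)
	req, err := http.NewRequestWithContext(context.Background(),
		http.MethodPost, "http://localhost:8080/v1/chat/completions", payload)
	if err != nil {
		log.Fatal(err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")

	resp, err := client.Do(req)
	if err != nil {
		log.Fatal(err) // dial, TLS, or response-header timeouts surface here
	}
	defer resp.Body.Close()

	// Read the stream line by line, roughly as sendMsgToLLM does.
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		fmt.Println(scanner.Text())
	}
}

A non-zero Client.Timeout would bound the entire exchange, including reading the response body, which is why the diff pushes the limits into the Transport instead.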