summaryrefslogtreecommitdiff
path: root/bot.go
diff options
context:
space:
mode:
Diffstat (limited to 'bot.go')
-rw-r--r--bot.go93
1 files changed, 86 insertions, 7 deletions
diff --git a/bot.go b/bot.go
index 7090e37..ec59db3 100644
--- a/bot.go
+++ b/bot.go
@@ -2,6 +2,8 @@ package main
import (
"bufio"
+ "bytes"
+ "context"
"elefant/config"
"elefant/models"
"elefant/rag"
@@ -10,6 +12,7 @@ import (
"fmt"
"io"
"log/slog"
+ "net"
"net/http"
"os"
"path"
@@ -20,7 +23,30 @@ import (
"github.com/rivo/tview"
)
-var httpClient = http.Client{}
+var httpClient = &http.Client{}
+
+func createClient(connectTimeout time.Duration) *http.Client {
+ // Custom transport with connection timeout
+ transport := &http.Transport{
+ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ // Create a dialer with connection timeout
+ dialer := &net.Dialer{
+ Timeout: connectTimeout,
+ KeepAlive: 30 * time.Second, // Optional
+ }
+ return dialer.DialContext(ctx, network, addr)
+ },
+ // Other transport settings (optional)
+ TLSHandshakeTimeout: connectTimeout,
+ ResponseHeaderTimeout: connectTimeout,
+ }
+
+ // Client with no overall timeout (or set to streaming-safe duration)
+ return &http.Client{
+ Transport: transport,
+ Timeout: 0, // No overall timeout (for streaming)
+ }
+}
var (
cfg *config.Config
@@ -36,7 +62,6 @@ var (
defaultStarterBytes = []byte{}
interruptResp = false
ragger *rag.RAG
- currentModel = "none"
chunkParser ChunkParser
defaultLCPProps = map[string]float32{
"temperature": 0.8,
@@ -47,6 +72,7 @@ var (
)
func fetchModelName() *models.LLMModels {
+ // TODO: to config
api := "http://localhost:8080/v1/models"
//nolint
resp, err := httpClient.Get(api)
@@ -61,16 +87,44 @@ func fetchModelName() *models.LLMModels {
return nil
}
if resp.StatusCode != 200 {
- currentModel = "disconnected"
+ chatBody.Model = "disconnected"
return nil
}
- currentModel = path.Base(llmModel.Data[0].ID)
+ chatBody.Model = path.Base(llmModel.Data[0].ID)
return &llmModel
}
+func fetchDSBalance() *models.DSBalance {
+ url := "https://api.deepseek.com/user/balance"
+ method := "GET"
+ req, err := http.NewRequest(method, url, nil)
+ if err != nil {
+ logger.Warn("failed to create request", "error", err)
+ return nil
+ }
+ req.Header.Add("Accept", "application/json")
+ req.Header.Add("Authorization", "Bearer "+cfg.DeepSeekToken)
+ res, err := httpClient.Do(req)
+ if err != nil {
+ logger.Warn("failed to make request", "error", err)
+ return nil
+ }
+ defer res.Body.Close()
+ resp := models.DSBalance{}
+ if err := json.NewDecoder(res.Body).Decode(&resp); err != nil {
+ return nil
+ }
+ return &resp
+}
+
func sendMsgToLLM(body io.Reader) {
+ choseChunkParser()
+ req, err := http.NewRequest("POST", cfg.CurrentAPI, body)
+ req.Header.Add("Accept", "application/json")
+ req.Header.Add("Content-Type", "application/json")
+ req.Header.Add("Authorization", "Bearer "+cfg.DeepSeekToken)
// nolint
- resp, err := httpClient.Post(cfg.CurrentAPI, "application/json", body)
+ // resp, err := httpClient.Post(cfg.CurrentAPI, "application/json", body)
if err != nil {
logger.Error("llamacpp api", "error", err)
if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil {
@@ -79,6 +133,16 @@ func sendMsgToLLM(body io.Reader) {
streamDone <- true
return
}
+ resp, err := httpClient.Do(req)
+ if err != nil {
+ bodyBytes, _ := io.ReadAll(body)
+ logger.Error("llamacpp api", "error", err, "body", string(bodyBytes))
+ if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil {
+ logger.Error("failed to notify", "error", err)
+ }
+ streamDone <- true
+ return
+ }
defer resp.Body.Close()
reader := bufio.NewReader(resp.Body)
counter := uint32(0)
@@ -113,6 +177,10 @@ func sendMsgToLLM(body io.Reader) {
// starts with -> data:
line = line[6:]
logger.Debug("debugging resp", "line", string(line))
+ if bytes.Equal(line, []byte("[DONE]\n")) {
+ streamDone <- true
+ break
+ }
content, stop, err = chunkParser.ParseChunk(line)
if err != nil {
logger.Error("error parsing response body", "error", err, "line", string(line), "url", cfg.CurrentAPI)
@@ -185,7 +253,17 @@ func roleToIcon(role string) string {
func chatRound(userMsg, role string, tv *tview.TextView, regen, resume bool) {
botRespMode = true
- // reader := formMsg(chatBody, userMsg, role)
+ defer func() { botRespMode = false }()
+ // check that there is a model set to use if is not local
+ if cfg.CurrentAPI == cfg.DeepSeekChatAPI || cfg.CurrentAPI == cfg.DeepSeekCompletionAPI {
+ if chatBody.Model != "deepseek-chat" && chatBody.Model != "deepseek-reasoner" {
+ if err := notifyUser("bad request", "wrong deepseek model name"); err != nil {
+ logger.Warn("failed ot notify user", "error", err)
+ return
+ }
+ return
+ }
+ }
reader, err := chunkParser.FormMsg(userMsg, role, resume)
if reader == nil || err != nil {
logger.Error("empty reader from msgs", "role", role, "error", err)
@@ -369,7 +447,8 @@ func init() {
Stream: true,
Messages: lastChat,
}
- initChunkParser()
+ choseChunkParser()
+ httpClient = createClient(time.Second * 15)
// go runModelNameTicker(time.Second * 120)
// tempLoad()
}