summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bot.go93
-rw-r--r--config/config.go29
-rw-r--r--llm.go90
-rw-r--r--models/models.go130
-rw-r--r--tui.go10
5 files changed, 329 insertions, 23 deletions
diff --git a/bot.go b/bot.go
index 7090e37..ec59db3 100644
--- a/bot.go
+++ b/bot.go
@@ -2,6 +2,8 @@ package main
import (
"bufio"
+ "bytes"
+ "context"
"elefant/config"
"elefant/models"
"elefant/rag"
@@ -10,6 +12,7 @@ import (
"fmt"
"io"
"log/slog"
+ "net"
"net/http"
"os"
"path"
@@ -20,7 +23,30 @@ import (
"github.com/rivo/tview"
)
-var httpClient = http.Client{}
+var httpClient = &http.Client{}
+
+func createClient(connectTimeout time.Duration) *http.Client {
+ // Custom transport with connection timeout
+ transport := &http.Transport{
+ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ // Create a dialer with connection timeout
+ dialer := &net.Dialer{
+ Timeout: connectTimeout,
+ KeepAlive: 30 * time.Second, // Optional
+ }
+ return dialer.DialContext(ctx, network, addr)
+ },
+ // Other transport settings (optional)
+ TLSHandshakeTimeout: connectTimeout,
+ ResponseHeaderTimeout: connectTimeout,
+ }
+
+ // Client with no overall timeout (or set to streaming-safe duration)
+ return &http.Client{
+ Transport: transport,
+ Timeout: 0, // No overall timeout (for streaming)
+ }
+}
var (
cfg *config.Config
@@ -36,7 +62,6 @@ var (
defaultStarterBytes = []byte{}
interruptResp = false
ragger *rag.RAG
- currentModel = "none"
chunkParser ChunkParser
defaultLCPProps = map[string]float32{
"temperature": 0.8,
@@ -47,6 +72,7 @@ var (
)
func fetchModelName() *models.LLMModels {
+ // TODO: to config
api := "http://localhost:8080/v1/models"
//nolint
resp, err := httpClient.Get(api)
@@ -61,16 +87,44 @@ func fetchModelName() *models.LLMModels {
return nil
}
if resp.StatusCode != 200 {
- currentModel = "disconnected"
+ chatBody.Model = "disconnected"
return nil
}
- currentModel = path.Base(llmModel.Data[0].ID)
+ chatBody.Model = path.Base(llmModel.Data[0].ID)
return &llmModel
}
+func fetchDSBalance() *models.DSBalance {
+ url := "https://api.deepseek.com/user/balance"
+ method := "GET"
+ req, err := http.NewRequest(method, url, nil)
+ if err != nil {
+ logger.Warn("failed to create request", "error", err)
+ return nil
+ }
+ req.Header.Add("Accept", "application/json")
+ req.Header.Add("Authorization", "Bearer "+cfg.DeepSeekToken)
+ res, err := httpClient.Do(req)
+ if err != nil {
+ logger.Warn("failed to make request", "error", err)
+ return nil
+ }
+ defer res.Body.Close()
+ resp := models.DSBalance{}
+ if err := json.NewDecoder(res.Body).Decode(&resp); err != nil {
+ return nil
+ }
+ return &resp
+}
+
func sendMsgToLLM(body io.Reader) {
+ choseChunkParser()
+ req, err := http.NewRequest("POST", cfg.CurrentAPI, body)
+ req.Header.Add("Accept", "application/json")
+ req.Header.Add("Content-Type", "application/json")
+ req.Header.Add("Authorization", "Bearer "+cfg.DeepSeekToken)
// nolint
- resp, err := httpClient.Post(cfg.CurrentAPI, "application/json", body)
+ // resp, err := httpClient.Post(cfg.CurrentAPI, "application/json", body)
if err != nil {
logger.Error("llamacpp api", "error", err)
if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil {
@@ -79,6 +133,16 @@ func sendMsgToLLM(body io.Reader) {
streamDone <- true
return
}
+ resp, err := httpClient.Do(req)
+ if err != nil {
+ bodyBytes, _ := io.ReadAll(body)
+ logger.Error("llamacpp api", "error", err, "body", string(bodyBytes))
+ if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil {
+ logger.Error("failed to notify", "error", err)
+ }
+ streamDone <- true
+ return
+ }
defer resp.Body.Close()
reader := bufio.NewReader(resp.Body)
counter := uint32(0)
@@ -113,6 +177,10 @@ func sendMsgToLLM(body io.Reader) {
// starts with -> data:
line = line[6:]
logger.Debug("debugging resp", "line", string(line))
+ if bytes.Equal(line, []byte("[DONE]\n")) {
+ streamDone <- true
+ break
+ }
content, stop, err = chunkParser.ParseChunk(line)
if err != nil {
logger.Error("error parsing response body", "error", err, "line", string(line), "url", cfg.CurrentAPI)
@@ -185,7 +253,17 @@ func roleToIcon(role string) string {
func chatRound(userMsg, role string, tv *tview.TextView, regen, resume bool) {
botRespMode = true
- // reader := formMsg(chatBody, userMsg, role)
+ defer func() { botRespMode = false }()
+ // check that there is a model set to use if is not local
+ if cfg.CurrentAPI == cfg.DeepSeekChatAPI || cfg.CurrentAPI == cfg.DeepSeekCompletionAPI {
+ if chatBody.Model != "deepseek-chat" && chatBody.Model != "deepseek-reasoner" {
+ if err := notifyUser("bad request", "wrong deepseek model name"); err != nil {
+ logger.Warn("failed ot notify user", "error", err)
+ return
+ }
+ return
+ }
+ }
reader, err := chunkParser.FormMsg(userMsg, role, resume)
if reader == nil || err != nil {
logger.Error("empty reader from msgs", "role", role, "error", err)
@@ -369,7 +447,8 @@ func init() {
Stream: true,
Messages: lastChat,
}
- initChunkParser()
+ choseChunkParser()
+ httpClient = createClient(time.Second * 15)
// go runModelNameTicker(time.Second * 120)
// tempLoad()
}
diff --git a/config/config.go b/config/config.go
index 63495b5..026fdab 100644
--- a/config/config.go
+++ b/config/config.go
@@ -7,10 +7,11 @@ import (
)
type Config struct {
- ChatAPI string `toml:"ChatAPI"`
- CompletionAPI string `toml:"CompletionAPI"`
- CurrentAPI string
- APIMap map[string]string
+ ChatAPI string `toml:"ChatAPI"`
+ CompletionAPI string `toml:"CompletionAPI"`
+ CurrentAPI string
+ CurrentProvider string
+ APIMap map[string]string
//
ShowSys bool `toml:"ShowSys"`
LogFile string `toml:"LogFile"`
@@ -30,6 +31,12 @@ type Config struct {
RAGWorkers uint32 `toml:"RAGWorkers"`
RAGBatchSize int `toml:"RAGBatchSize"`
RAGWordLimit uint32 `toml:"RAGWordLimit"`
+ // deepseek
+ DeepSeekChatAPI string `toml:"DeepSeekChatAPI"`
+ DeepSeekCompletionAPI string `toml:"DeepSeekCompletionAPI"`
+ DeepSeekToken string `toml:"DeepSeekToken"`
+ DeepSeekModel string `toml:"DeepSeekModel"`
+ ApiLinks []string
}
func LoadConfigOrDefault(fn string) *Config {
@@ -39,9 +46,11 @@ func LoadConfigOrDefault(fn string) *Config {
config := &Config{}
_, err := toml.DecodeFile(fn, &config)
if err != nil {
- fmt.Println("failed to read config from file, loading default")
+ fmt.Println("failed to read config from file, loading default", "error", err)
config.ChatAPI = "http://localhost:8080/v1/chat/completions"
config.CompletionAPI = "http://localhost:8080/completion"
+ config.DeepSeekCompletionAPI = "https://api.deepseek.com/beta/completions"
+ config.DeepSeekChatAPI = "https://api.deepseek.com/chat/completions"
config.RAGEnabled = false
config.EmbedURL = "http://localhost:8080/v1/embiddings"
config.ShowSys = true
@@ -58,12 +67,16 @@ func LoadConfigOrDefault(fn string) *Config {
}
config.CurrentAPI = config.ChatAPI
config.APIMap = map[string]string{
- config.ChatAPI: config.CompletionAPI,
+ config.ChatAPI: config.CompletionAPI,
+ config.DeepSeekChatAPI: config.DeepSeekCompletionAPI,
}
if config.CompletionAPI != "" {
config.CurrentAPI = config.CompletionAPI
- config.APIMap = map[string]string{
- config.CompletionAPI: config.ChatAPI,
+ config.APIMap[config.CompletionAPI] = config.ChatAPI
+ }
+ for _, el := range []string{config.ChatAPI, config.CompletionAPI, config.DeepSeekChatAPI, config.DeepSeekCompletionAPI} {
+ if el != "" {
+ config.ApiLinks = append(config.ApiLinks, el)
}
}
// if any value is empty fill with default
diff --git a/llm.go b/llm.go
index f328038..b62411d 100644
--- a/llm.go
+++ b/llm.go
@@ -13,20 +13,32 @@ type ChunkParser interface {
FormMsg(msg, role string, cont bool) (io.Reader, error)
}
-func initChunkParser() {
+func choseChunkParser() {
chunkParser = LlamaCPPeer{}
- if strings.Contains(cfg.CurrentAPI, "v1") {
- logger.Debug("chosen /v1/chat parser")
+ switch cfg.CurrentAPI {
+ case "http://localhost:8080/completion":
+ chunkParser = LlamaCPPeer{}
+ case "http://localhost:8080/v1/chat/completions":
chunkParser = OpenAIer{}
- return
+ case "https://api.deepseek.com/beta/completions":
+ chunkParser = DeepSeeker{}
+ default:
+ chunkParser = LlamaCPPeer{}
}
- logger.Debug("chosen llamacpp /completion parser")
+ // if strings.Contains(cfg.CurrentAPI, "chat") {
+ // logger.Debug("chosen chat parser")
+ // chunkParser = OpenAIer{}
+ // return
+ // }
+ // logger.Debug("chosen llamacpp /completion parser")
}
type LlamaCPPeer struct {
}
type OpenAIer struct {
}
+type DeepSeeker struct {
+}
func (lcp LlamaCPPeer) FormMsg(msg, role string, resume bool) (io.Reader, error) {
if msg != "" { // otherwise let the bot to continue
@@ -62,7 +74,12 @@ func (lcp LlamaCPPeer) FormMsg(msg, role string, resume bool) (io.Reader, error)
}
logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse,
"msg", msg, "resume", resume, "prompt", prompt)
- payload := models.NewLCPReq(prompt, cfg, defaultLCPProps)
+ var payload any
+ payload = models.NewLCPReq(prompt, cfg, defaultLCPProps)
+ if strings.Contains(chatBody.Model, "deepseek") {
+ payload = models.NewDSCompletionReq(prompt, chatBody.Model,
+ defaultLCPProps["temp"], cfg)
+ }
data, err := json.Marshal(payload)
if err != nil {
logger.Error("failed to form a msg", "error", err)
@@ -129,3 +146,64 @@ func (op OpenAIer) FormMsg(msg, role string, resume bool) (io.Reader, error) {
}
return bytes.NewReader(data), nil
}
+
+// deepseek
+func (ds DeepSeeker) ParseChunk(data []byte) (string, bool, error) {
+ llmchunk := models.DSCompletionResp{}
+ if err := json.Unmarshal(data, &llmchunk); err != nil {
+ logger.Error("failed to decode", "error", err, "line", string(data))
+ return "", false, err
+ }
+ if llmchunk.Choices[0].FinishReason != "" {
+ if llmchunk.Choices[0].Text != "" {
+ logger.Error("text inside of finish llmchunk", "chunk", llmchunk)
+ }
+ return llmchunk.Choices[0].Text, true, nil
+ }
+ return llmchunk.Choices[0].Text, false, nil
+}
+
+func (ds DeepSeeker) FormMsg(msg, role string, resume bool) (io.Reader, error) {
+ if msg != "" { // otherwise let the bot to continue
+ newMsg := models.RoleMsg{Role: role, Content: msg}
+ chatBody.Messages = append(chatBody.Messages, newMsg)
+ // if rag
+ if cfg.RAGEnabled {
+ ragResp, err := chatRagUse(newMsg.Content)
+ if err != nil {
+ logger.Error("failed to form a rag msg", "error", err)
+ return nil, err
+ }
+ ragMsg := models.RoleMsg{Role: cfg.ToolRole, Content: ragResp}
+ chatBody.Messages = append(chatBody.Messages, ragMsg)
+ }
+ }
+ if cfg.ToolUse && !resume {
+ // add to chat body
+ chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg})
+ }
+ messages := make([]string, len(chatBody.Messages))
+ for i, m := range chatBody.Messages {
+ messages[i] = m.ToPrompt()
+ }
+ prompt := strings.Join(messages, "\n")
+ // strings builder?
+ if !resume {
+ botMsgStart := "\n" + cfg.AssistantRole + ":\n"
+ prompt += botMsgStart
+ }
+ if cfg.ThinkUse && !cfg.ToolUse {
+ prompt += "<think>"
+ }
+ logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse,
+ "msg", msg, "resume", resume, "prompt", prompt)
+ var payload any
+ payload = models.NewDSCompletionReq(prompt, chatBody.Model,
+ defaultLCPProps["temp"], cfg)
+ data, err := json.Marshal(payload)
+ if err != nil {
+ logger.Error("failed to form a msg", "error", err)
+ return nil, err
+ }
+ return bytes.NewReader(data), nil
+}
diff --git a/models/models.go b/models/models.go
index bb61abf..574be1c 100644
--- a/models/models.go
+++ b/models/models.go
@@ -103,6 +103,126 @@ type ChatToolsBody struct {
ToolChoice string `json:"tool_choice"`
}
+type DSChatReq struct {
+ Messages []RoleMsg `json:"messages"`
+ Model string `json:"model"`
+ Stream bool `json:"stream"`
+ FrequencyPenalty int `json:"frequency_penalty"`
+ MaxTokens int `json:"max_tokens"`
+ PresencePenalty int `json:"presence_penalty"`
+ Temperature float32 `json:"temperature"`
+ TopP float32 `json:"top_p"`
+ // ResponseFormat struct {
+ // Type string `json:"type"`
+ // } `json:"response_format"`
+ // Stop any `json:"stop"`
+ // StreamOptions any `json:"stream_options"`
+ // Tools any `json:"tools"`
+ // ToolChoice string `json:"tool_choice"`
+ // Logprobs bool `json:"logprobs"`
+ // TopLogprobs any `json:"top_logprobs"`
+}
+
+func NewDSCharReq(cb *ChatBody) DSChatReq {
+ return DSChatReq{
+ Messages: cb.Messages,
+ Model: cb.Model,
+ Stream: cb.Stream,
+ MaxTokens: 2048,
+ PresencePenalty: 0,
+ FrequencyPenalty: 0,
+ Temperature: 1.0,
+ TopP: 1.0,
+ }
+}
+
+type DSCompletionReq struct {
+ Model string `json:"model"`
+ Prompt string `json:"prompt"`
+ Echo bool `json:"echo"`
+ FrequencyPenalty int `json:"frequency_penalty"`
+ // Logprobs int `json:"logprobs"`
+ MaxTokens int `json:"max_tokens"`
+ PresencePenalty int `json:"presence_penalty"`
+ Stop any `json:"stop"`
+ Stream bool `json:"stream"`
+ StreamOptions any `json:"stream_options"`
+ Suffix any `json:"suffix"`
+ Temperature float32 `json:"temperature"`
+ TopP float32 `json:"top_p"`
+}
+
+func NewDSCompletionReq(prompt, model string, temp float32, cfg *config.Config) DSCompletionReq {
+ return DSCompletionReq{
+ Model: model,
+ Prompt: prompt,
+ Temperature: temp,
+ Stream: true,
+ Echo: false,
+ MaxTokens: 2048,
+ PresencePenalty: 0,
+ FrequencyPenalty: 0,
+ TopP: 1.0,
+ Stop: []string{
+ cfg.UserRole + ":\n", "<|im_end|>",
+ cfg.ToolRole + ":\n",
+ cfg.AssistantRole + ":\n",
+ },
+ }
+}
+
+type DSCompletionResp struct {
+ ID string `json:"id"`
+ Choices []struct {
+ FinishReason string `json:"finish_reason"`
+ Index int `json:"index"`
+ Logprobs struct {
+ TextOffset []int `json:"text_offset"`
+ TokenLogprobs []int `json:"token_logprobs"`
+ Tokens []string `json:"tokens"`
+ TopLogprobs []struct {
+ } `json:"top_logprobs"`
+ } `json:"logprobs"`
+ Text string `json:"text"`
+ } `json:"choices"`
+ Created int `json:"created"`
+ Model string `json:"model"`
+ SystemFingerprint string `json:"system_fingerprint"`
+ Object string `json:"object"`
+ Usage struct {
+ CompletionTokens int `json:"completion_tokens"`
+ PromptTokens int `json:"prompt_tokens"`
+ PromptCacheHitTokens int `json:"prompt_cache_hit_tokens"`
+ PromptCacheMissTokens int `json:"prompt_cache_miss_tokens"`
+ TotalTokens int `json:"total_tokens"`
+ CompletionTokensDetails struct {
+ ReasoningTokens int `json:"reasoning_tokens"`
+ } `json:"completion_tokens_details"`
+ } `json:"usage"`
+}
+
+type DSChatResp struct {
+ Choices []struct {
+ Delta struct {
+ Content string `json:"content"`
+ Role any `json:"role"`
+ } `json:"delta"`
+ FinishReason string `json:"finish_reason"`
+ Index int `json:"index"`
+ Logprobs any `json:"logprobs"`
+ } `json:"choices"`
+ Created int `json:"created"`
+ ID string `json:"id"`
+ Model string `json:"model"`
+ Object string `json:"object"`
+ SystemFingerprint string `json:"system_fingerprint"`
+ Usage struct {
+ CompletionTokens int `json:"completion_tokens"`
+ PromptTokens int `json:"prompt_tokens"`
+ TotalTokens int `json:"total_tokens"`
+ } `json:"usage"`
+}
+
type EmbeddingResp struct {
Embedding []float32 `json:"embedding"`
Index uint32 `json:"index"`
@@ -190,3 +310,13 @@ type LlamaCPPResp struct {
Content string `json:"content"`
Stop bool `json:"stop"`
}
+
+type DSBalance struct {
+ IsAvailable bool `json:"is_available"`
+ BalanceInfos []struct {
+ Currency string `json:"currency"`
+ TotalBalance string `json:"total_balance"`
+ GrantedBalance string `json:"granted_balance"`
+ ToppedUpBalance string `json:"topped_up_balance"`
+ } `json:"balance_infos"`
+}
diff --git a/tui.go b/tui.go
index 5e64e5e..30fc241 100644
--- a/tui.go
+++ b/tui.go
@@ -136,7 +136,7 @@ func colorText() {
}
func updateStatusLine() {
- position.SetText(fmt.Sprintf(indexLine, botRespMode, cfg.AssistantRole, activeChatName, cfg.RAGEnabled, cfg.ToolUse, currentModel, cfg.CurrentAPI, cfg.ThinkUse, logLevel.Level()))
+ position.SetText(fmt.Sprintf(indexLine, botRespMode, cfg.AssistantRole, activeChatName, cfg.RAGEnabled, cfg.ToolUse, chatBody.Model, cfg.CurrentAPI, cfg.ThinkUse, logLevel.Level()))
}
func initSysCards() ([]string, error) {
@@ -202,6 +202,12 @@ func makePropsForm(props map[string]float32) *tview.Form {
}).AddDropDown("Set log level (Enter): ", []string{"Debug", "Info", "Warn"}, 1,
func(option string, optionIndex int) {
setLogLevel(option)
+ }).AddDropDown("Select an api: ", cfg.ApiLinks, 1,
+ func(option string, optionIndex int) {
+ cfg.CurrentAPI = option
+ }).AddDropDown("Select a model: ", []string{chatBody.Model, "deepseek-chat", "deepseek-reasoner"}, 0,
+ func(option string, optionIndex int) {
+ chatBody.Model = option
}).
AddButton("Quit", func() {
pages.RemovePage(propsPage)
@@ -600,7 +606,7 @@ func init() {
}
cfg.APIMap[newAPI] = prevAPI
cfg.CurrentAPI = newAPI
- initChunkParser()
+ choseChunkParser()
updateStatusLine()
return nil
}