diff options
-rw-r--r-- | bot.go | 93 | ||||
-rw-r--r-- | config/config.go | 29 | ||||
-rw-r--r-- | llm.go | 90 | ||||
-rw-r--r-- | models/models.go | 130 | ||||
-rw-r--r-- | tui.go | 10 |
5 files changed, 329 insertions, 23 deletions
@@ -2,6 +2,8 @@ package main import ( "bufio" + "bytes" + "context" "elefant/config" "elefant/models" "elefant/rag" @@ -10,6 +12,7 @@ import ( "fmt" "io" "log/slog" + "net" "net/http" "os" "path" @@ -20,7 +23,30 @@ import ( "github.com/rivo/tview" ) -var httpClient = http.Client{} +var httpClient = &http.Client{} + +func createClient(connectTimeout time.Duration) *http.Client { + // Custom transport with connection timeout + transport := &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + // Create a dialer with connection timeout + dialer := &net.Dialer{ + Timeout: connectTimeout, + KeepAlive: 30 * time.Second, // Optional + } + return dialer.DialContext(ctx, network, addr) + }, + // Other transport settings (optional) + TLSHandshakeTimeout: connectTimeout, + ResponseHeaderTimeout: connectTimeout, + } + + // Client with no overall timeout (or set to streaming-safe duration) + return &http.Client{ + Transport: transport, + Timeout: 0, // No overall timeout (for streaming) + } +} var ( cfg *config.Config @@ -36,7 +62,6 @@ var ( defaultStarterBytes = []byte{} interruptResp = false ragger *rag.RAG - currentModel = "none" chunkParser ChunkParser defaultLCPProps = map[string]float32{ "temperature": 0.8, @@ -47,6 +72,7 @@ var ( ) func fetchModelName() *models.LLMModels { + // TODO: to config api := "http://localhost:8080/v1/models" //nolint resp, err := httpClient.Get(api) @@ -61,16 +87,44 @@ func fetchModelName() *models.LLMModels { return nil } if resp.StatusCode != 200 { - currentModel = "disconnected" + chatBody.Model = "disconnected" return nil } - currentModel = path.Base(llmModel.Data[0].ID) + chatBody.Model = path.Base(llmModel.Data[0].ID) return &llmModel } +func fetchDSBalance() *models.DSBalance { + url := "https://api.deepseek.com/user/balance" + method := "GET" + req, err := http.NewRequest(method, url, nil) + if err != nil { + logger.Warn("failed to create request", "error", err) + return nil + } + req.Header.Add("Accept", "application/json") + req.Header.Add("Authorization", "Bearer "+cfg.DeepSeekToken) + res, err := httpClient.Do(req) + if err != nil { + logger.Warn("failed to make request", "error", err) + return nil + } + defer res.Body.Close() + resp := models.DSBalance{} + if err := json.NewDecoder(res.Body).Decode(&resp); err != nil { + return nil + } + return &resp +} + func sendMsgToLLM(body io.Reader) { + choseChunkParser() + req, err := http.NewRequest("POST", cfg.CurrentAPI, body) + req.Header.Add("Accept", "application/json") + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Authorization", "Bearer "+cfg.DeepSeekToken) // nolint - resp, err := httpClient.Post(cfg.CurrentAPI, "application/json", body) + // resp, err := httpClient.Post(cfg.CurrentAPI, "application/json", body) if err != nil { logger.Error("llamacpp api", "error", err) if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil { @@ -79,6 +133,16 @@ func sendMsgToLLM(body io.Reader) { streamDone <- true return } + resp, err := httpClient.Do(req) + if err != nil { + bodyBytes, _ := io.ReadAll(body) + logger.Error("llamacpp api", "error", err, "body", string(bodyBytes)) + if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil { + logger.Error("failed to notify", "error", err) + } + streamDone <- true + return + } defer resp.Body.Close() reader := bufio.NewReader(resp.Body) counter := uint32(0) @@ -113,6 +177,10 @@ func sendMsgToLLM(body io.Reader) { // starts with -> data: line = line[6:] logger.Debug("debugging resp", "line", string(line)) + if bytes.Equal(line, []byte("[DONE]\n")) { + streamDone <- true + break + } content, stop, err = chunkParser.ParseChunk(line) if err != nil { logger.Error("error parsing response body", "error", err, "line", string(line), "url", cfg.CurrentAPI) @@ -185,7 +253,17 @@ func roleToIcon(role string) string { func chatRound(userMsg, role string, tv *tview.TextView, regen, resume bool) { botRespMode = true - // reader := formMsg(chatBody, userMsg, role) + defer func() { botRespMode = false }() + // check that there is a model set to use if is not local + if cfg.CurrentAPI == cfg.DeepSeekChatAPI || cfg.CurrentAPI == cfg.DeepSeekCompletionAPI { + if chatBody.Model != "deepseek-chat" && chatBody.Model != "deepseek-reasoner" { + if err := notifyUser("bad request", "wrong deepseek model name"); err != nil { + logger.Warn("failed ot notify user", "error", err) + return + } + return + } + } reader, err := chunkParser.FormMsg(userMsg, role, resume) if reader == nil || err != nil { logger.Error("empty reader from msgs", "role", role, "error", err) @@ -369,7 +447,8 @@ func init() { Stream: true, Messages: lastChat, } - initChunkParser() + choseChunkParser() + httpClient = createClient(time.Second * 15) // go runModelNameTicker(time.Second * 120) // tempLoad() } diff --git a/config/config.go b/config/config.go index 63495b5..026fdab 100644 --- a/config/config.go +++ b/config/config.go @@ -7,10 +7,11 @@ import ( ) type Config struct { - ChatAPI string `toml:"ChatAPI"` - CompletionAPI string `toml:"CompletionAPI"` - CurrentAPI string - APIMap map[string]string + ChatAPI string `toml:"ChatAPI"` + CompletionAPI string `toml:"CompletionAPI"` + CurrentAPI string + CurrentProvider string + APIMap map[string]string // ShowSys bool `toml:"ShowSys"` LogFile string `toml:"LogFile"` @@ -30,6 +31,12 @@ type Config struct { RAGWorkers uint32 `toml:"RAGWorkers"` RAGBatchSize int `toml:"RAGBatchSize"` RAGWordLimit uint32 `toml:"RAGWordLimit"` + // deepseek + DeepSeekChatAPI string `toml:"DeepSeekChatAPI"` + DeepSeekCompletionAPI string `toml:"DeepSeekCompletionAPI"` + DeepSeekToken string `toml:"DeepSeekToken"` + DeepSeekModel string `toml:"DeepSeekModel"` + ApiLinks []string } func LoadConfigOrDefault(fn string) *Config { @@ -39,9 +46,11 @@ func LoadConfigOrDefault(fn string) *Config { config := &Config{} _, err := toml.DecodeFile(fn, &config) if err != nil { - fmt.Println("failed to read config from file, loading default") + fmt.Println("failed to read config from file, loading default", "error", err) config.ChatAPI = "http://localhost:8080/v1/chat/completions" config.CompletionAPI = "http://localhost:8080/completion" + config.DeepSeekCompletionAPI = "https://api.deepseek.com/beta/completions" + config.DeepSeekChatAPI = "https://api.deepseek.com/chat/completions" config.RAGEnabled = false config.EmbedURL = "http://localhost:8080/v1/embiddings" config.ShowSys = true @@ -58,12 +67,16 @@ func LoadConfigOrDefault(fn string) *Config { } config.CurrentAPI = config.ChatAPI config.APIMap = map[string]string{ - config.ChatAPI: config.CompletionAPI, + config.ChatAPI: config.CompletionAPI, + config.DeepSeekChatAPI: config.DeepSeekCompletionAPI, } if config.CompletionAPI != "" { config.CurrentAPI = config.CompletionAPI - config.APIMap = map[string]string{ - config.CompletionAPI: config.ChatAPI, + config.APIMap[config.CompletionAPI] = config.ChatAPI + } + for _, el := range []string{config.ChatAPI, config.CompletionAPI, config.DeepSeekChatAPI, config.DeepSeekCompletionAPI} { + if el != "" { + config.ApiLinks = append(config.ApiLinks, el) } } // if any value is empty fill with default @@ -13,20 +13,32 @@ type ChunkParser interface { FormMsg(msg, role string, cont bool) (io.Reader, error) } -func initChunkParser() { +func choseChunkParser() { chunkParser = LlamaCPPeer{} - if strings.Contains(cfg.CurrentAPI, "v1") { - logger.Debug("chosen /v1/chat parser") + switch cfg.CurrentAPI { + case "http://localhost:8080/completion": + chunkParser = LlamaCPPeer{} + case "http://localhost:8080/v1/chat/completions": chunkParser = OpenAIer{} - return + case "https://api.deepseek.com/beta/completions": + chunkParser = DeepSeeker{} + default: + chunkParser = LlamaCPPeer{} } - logger.Debug("chosen llamacpp /completion parser") + // if strings.Contains(cfg.CurrentAPI, "chat") { + // logger.Debug("chosen chat parser") + // chunkParser = OpenAIer{} + // return + // } + // logger.Debug("chosen llamacpp /completion parser") } type LlamaCPPeer struct { } type OpenAIer struct { } +type DeepSeeker struct { +} func (lcp LlamaCPPeer) FormMsg(msg, role string, resume bool) (io.Reader, error) { if msg != "" { // otherwise let the bot to continue @@ -62,7 +74,12 @@ func (lcp LlamaCPPeer) FormMsg(msg, role string, resume bool) (io.Reader, error) } logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse, "msg", msg, "resume", resume, "prompt", prompt) - payload := models.NewLCPReq(prompt, cfg, defaultLCPProps) + var payload any + payload = models.NewLCPReq(prompt, cfg, defaultLCPProps) + if strings.Contains(chatBody.Model, "deepseek") { + payload = models.NewDSCompletionReq(prompt, chatBody.Model, + defaultLCPProps["temp"], cfg) + } data, err := json.Marshal(payload) if err != nil { logger.Error("failed to form a msg", "error", err) @@ -129,3 +146,64 @@ func (op OpenAIer) FormMsg(msg, role string, resume bool) (io.Reader, error) { } return bytes.NewReader(data), nil } + +// deepseek +func (ds DeepSeeker) ParseChunk(data []byte) (string, bool, error) { + llmchunk := models.DSCompletionResp{} + if err := json.Unmarshal(data, &llmchunk); err != nil { + logger.Error("failed to decode", "error", err, "line", string(data)) + return "", false, err + } + if llmchunk.Choices[0].FinishReason != "" { + if llmchunk.Choices[0].Text != "" { + logger.Error("text inside of finish llmchunk", "chunk", llmchunk) + } + return llmchunk.Choices[0].Text, true, nil + } + return llmchunk.Choices[0].Text, false, nil +} + +func (ds DeepSeeker) FormMsg(msg, role string, resume bool) (io.Reader, error) { + if msg != "" { // otherwise let the bot to continue + newMsg := models.RoleMsg{Role: role, Content: msg} + chatBody.Messages = append(chatBody.Messages, newMsg) + // if rag + if cfg.RAGEnabled { + ragResp, err := chatRagUse(newMsg.Content) + if err != nil { + logger.Error("failed to form a rag msg", "error", err) + return nil, err + } + ragMsg := models.RoleMsg{Role: cfg.ToolRole, Content: ragResp} + chatBody.Messages = append(chatBody.Messages, ragMsg) + } + } + if cfg.ToolUse && !resume { + // add to chat body + chatBody.Messages = append(chatBody.Messages, models.RoleMsg{Role: cfg.ToolRole, Content: toolSysMsg}) + } + messages := make([]string, len(chatBody.Messages)) + for i, m := range chatBody.Messages { + messages[i] = m.ToPrompt() + } + prompt := strings.Join(messages, "\n") + // strings builder? + if !resume { + botMsgStart := "\n" + cfg.AssistantRole + ":\n" + prompt += botMsgStart + } + if cfg.ThinkUse && !cfg.ToolUse { + prompt += "<think>" + } + logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse, + "msg", msg, "resume", resume, "prompt", prompt) + var payload any + payload = models.NewDSCompletionReq(prompt, chatBody.Model, + defaultLCPProps["temp"], cfg) + data, err := json.Marshal(payload) + if err != nil { + logger.Error("failed to form a msg", "error", err) + return nil, err + } + return bytes.NewReader(data), nil +} diff --git a/models/models.go b/models/models.go index bb61abf..574be1c 100644 --- a/models/models.go +++ b/models/models.go @@ -103,6 +103,126 @@ type ChatToolsBody struct { ToolChoice string `json:"tool_choice"` } +type DSChatReq struct { + Messages []RoleMsg `json:"messages"` + Model string `json:"model"` + Stream bool `json:"stream"` + FrequencyPenalty int `json:"frequency_penalty"` + MaxTokens int `json:"max_tokens"` + PresencePenalty int `json:"presence_penalty"` + Temperature float32 `json:"temperature"` + TopP float32 `json:"top_p"` + // ResponseFormat struct { + // Type string `json:"type"` + // } `json:"response_format"` + // Stop any `json:"stop"` + // StreamOptions any `json:"stream_options"` + // Tools any `json:"tools"` + // ToolChoice string `json:"tool_choice"` + // Logprobs bool `json:"logprobs"` + // TopLogprobs any `json:"top_logprobs"` +} + +func NewDSCharReq(cb *ChatBody) DSChatReq { + return DSChatReq{ + Messages: cb.Messages, + Model: cb.Model, + Stream: cb.Stream, + MaxTokens: 2048, + PresencePenalty: 0, + FrequencyPenalty: 0, + Temperature: 1.0, + TopP: 1.0, + } +} + +type DSCompletionReq struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + Echo bool `json:"echo"` + FrequencyPenalty int `json:"frequency_penalty"` + // Logprobs int `json:"logprobs"` + MaxTokens int `json:"max_tokens"` + PresencePenalty int `json:"presence_penalty"` + Stop any `json:"stop"` + Stream bool `json:"stream"` + StreamOptions any `json:"stream_options"` + Suffix any `json:"suffix"` + Temperature float32 `json:"temperature"` + TopP float32 `json:"top_p"` +} + +func NewDSCompletionReq(prompt, model string, temp float32, cfg *config.Config) DSCompletionReq { + return DSCompletionReq{ + Model: model, + Prompt: prompt, + Temperature: temp, + Stream: true, + Echo: false, + MaxTokens: 2048, + PresencePenalty: 0, + FrequencyPenalty: 0, + TopP: 1.0, + Stop: []string{ + cfg.UserRole + ":\n", "<|im_end|>", + cfg.ToolRole + ":\n", + cfg.AssistantRole + ":\n", + }, + } +} + +type DSCompletionResp struct { + ID string `json:"id"` + Choices []struct { + FinishReason string `json:"finish_reason"` + Index int `json:"index"` + Logprobs struct { + TextOffset []int `json:"text_offset"` + TokenLogprobs []int `json:"token_logprobs"` + Tokens []string `json:"tokens"` + TopLogprobs []struct { + } `json:"top_logprobs"` + } `json:"logprobs"` + Text string `json:"text"` + } `json:"choices"` + Created int `json:"created"` + Model string `json:"model"` + SystemFingerprint string `json:"system_fingerprint"` + Object string `json:"object"` + Usage struct { + CompletionTokens int `json:"completion_tokens"` + PromptTokens int `json:"prompt_tokens"` + PromptCacheHitTokens int `json:"prompt_cache_hit_tokens"` + PromptCacheMissTokens int `json:"prompt_cache_miss_tokens"` + TotalTokens int `json:"total_tokens"` + CompletionTokensDetails struct { + ReasoningTokens int `json:"reasoning_tokens"` + } `json:"completion_tokens_details"` + } `json:"usage"` +} + +type DSChatResp struct { + Choices []struct { + Delta struct { + Content string `json:"content"` + Role any `json:"role"` + } `json:"delta"` + FinishReason string `json:"finish_reason"` + Index int `json:"index"` + Logprobs any `json:"logprobs"` + } `json:"choices"` + Created int `json:"created"` + ID string `json:"id"` + Model string `json:"model"` + Object string `json:"object"` + SystemFingerprint string `json:"system_fingerprint"` + Usage struct { + CompletionTokens int `json:"completion_tokens"` + PromptTokens int `json:"prompt_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` +} + type EmbeddingResp struct { Embedding []float32 `json:"embedding"` Index uint32 `json:"index"` @@ -190,3 +310,13 @@ type LlamaCPPResp struct { Content string `json:"content"` Stop bool `json:"stop"` } + +type DSBalance struct { + IsAvailable bool `json:"is_available"` + BalanceInfos []struct { + Currency string `json:"currency"` + TotalBalance string `json:"total_balance"` + GrantedBalance string `json:"granted_balance"` + ToppedUpBalance string `json:"topped_up_balance"` + } `json:"balance_infos"` +} @@ -136,7 +136,7 @@ func colorText() { } func updateStatusLine() { - position.SetText(fmt.Sprintf(indexLine, botRespMode, cfg.AssistantRole, activeChatName, cfg.RAGEnabled, cfg.ToolUse, currentModel, cfg.CurrentAPI, cfg.ThinkUse, logLevel.Level())) + position.SetText(fmt.Sprintf(indexLine, botRespMode, cfg.AssistantRole, activeChatName, cfg.RAGEnabled, cfg.ToolUse, chatBody.Model, cfg.CurrentAPI, cfg.ThinkUse, logLevel.Level())) } func initSysCards() ([]string, error) { @@ -202,6 +202,12 @@ func makePropsForm(props map[string]float32) *tview.Form { }).AddDropDown("Set log level (Enter): ", []string{"Debug", "Info", "Warn"}, 1, func(option string, optionIndex int) { setLogLevel(option) + }).AddDropDown("Select an api: ", cfg.ApiLinks, 1, + func(option string, optionIndex int) { + cfg.CurrentAPI = option + }).AddDropDown("Select a model: ", []string{chatBody.Model, "deepseek-chat", "deepseek-reasoner"}, 0, + func(option string, optionIndex int) { + chatBody.Model = option }). AddButton("Quit", func() { pages.RemovePage(propsPage) @@ -600,7 +606,7 @@ func init() { } cfg.APIMap[newAPI] = prevAPI cfg.CurrentAPI = newAPI - initChunkParser() + choseChunkParser() updateStatusLine() return nil } |