-rw-r--r--   llm.go             8
-rw-r--r--   models/models.go  28
2 files changed, 20 insertions, 16 deletions
diff --git a/llm.go b/llm.go
--- a/llm.go
+++ b/llm.go
@@ -2,8 +2,8 @@ package main
 
 import (
 	"bytes"
-	"gf-lt/models"
 	"encoding/json"
+	"gf-lt/models"
 	"io"
 	"strings"
 )
@@ -88,10 +88,10 @@ func (lcp LlamaCPPeer) FormMsg(msg, role string, resume bool) (io.Reader, error)
 	logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse,
 		"msg", msg, "resume", resume, "prompt", prompt)
 	var payload any
-	payload = models.NewLCPReq(prompt, cfg, defaultLCPProps)
+	payload = models.NewLCPReq(prompt, cfg, defaultLCPProps, chatBody.MakeStopSlice())
 	if strings.Contains(chatBody.Model, "deepseek") {
 		payload = models.NewDSCompletionReq(prompt, chatBody.Model,
-			defaultLCPProps["temp"], cfg)
+			defaultLCPProps["temp"], cfg, chatBody.MakeStopSlice())
 	}
 	data, err := json.Marshal(payload)
 	if err != nil {
@@ -213,7 +213,7 @@ func (ds DeepSeekerCompletion) FormMsg(msg, role string, resume bool) (io.Reader
 	logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse,
 		"msg", msg, "resume", resume, "prompt", prompt)
 	payload := models.NewDSCompletionReq(prompt, chatBody.Model,
-		defaultLCPProps["temp"], cfg)
+		defaultLCPProps["temp"], cfg, chatBody.MakeStopSlice())
 	data, err := json.Marshal(payload)
 	if err != nil {
 		logger.Error("failed to form a msg", "error", err)
diff --git a/models/models.go b/models/models.go
index 918e35e..9514741 100644
--- a/models/models.go
+++ b/models/models.go
@@ -97,6 +97,18 @@ func (cb *ChatBody) ListRoles() []string {
 	return resp
 }
 
+func (cb *ChatBody) MakeStopSlice() []string {
+	namesMap := make(map[string]struct{})
+	for _, m := range cb.Messages {
+		namesMap[m.Role] = struct{}{}
+	}
+	ss := []string{"<|im_end|>"}
+	for k := range namesMap {
+		ss = append(ss, k+":\n")
+	}
+	return ss
+}
+
 type ChatToolsBody struct {
 	Model    string    `json:"model"`
 	Messages []RoleMsg `json:"messages"`
@@ -173,7 +185,7 @@ type DSCompletionReq struct {
 	TopP             float32 `json:"top_p"`
 }
 
-func NewDSCompletionReq(prompt, model string, temp float32, cfg *config.Config) DSCompletionReq {
+func NewDSCompletionReq(prompt, model string, temp float32, cfg *config.Config, stopSlice []string) DSCompletionReq {
 	return DSCompletionReq{
 		Model:            model,
 		Prompt:           prompt,
@@ -184,11 +196,7 @@ func NewDSCompletionReq(prompt, model string, temp float32, cfg *config.Config)
 		PresencePenalty:  0,
 		FrequencyPenalty: 0,
 		TopP:             1.0,
-		Stop: []string{
-			cfg.UserRole + ":\n", "<|im_end|>",
-			cfg.ToolRole + ":\n",
-			cfg.AssistantRole + ":\n",
-		},
+		Stop: stopSlice,
 	}
 }
 
@@ -326,7 +334,7 @@ type LlamaCPPReq struct {
 	// Samplers      string  `json:"samplers"`
 }
 
-func NewLCPReq(prompt string, cfg *config.Config, props map[string]float32) LlamaCPPReq {
+func NewLCPReq(prompt string, cfg *config.Config, props map[string]float32, stopStrings []string) LlamaCPPReq {
 	return LlamaCPPReq{
 		Stream:        true,
 		Prompt:        prompt,
@@ -336,11 +344,7 @@ func NewLCPReq(prompt string, cfg *config.Config, props map[string]float32) LlamaCPPReq {
 		DryMultiplier: props["dry_multiplier"],
 		MinP:          props["min_p"],
 		NPredict:      int32(props["n_predict"]),
-		Stop: []string{
-			cfg.UserRole + ":\n", "<|im_end|>",
-			cfg.ToolRole + ":\n",
-			cfg.AssistantRole + ":\n",
-		},
+		Stop: stopStrings,
 	}
 }
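
The gist of the commit: both completion request constructors now take their stop strings from the caller, and ChatBody.MakeStopSlice derives those strings from the roles actually present in the conversation, rather than from the fixed cfg.UserRole / cfg.ToolRole / cfg.AssistantRole trio, so chats with custom role names also terminate generation cleanly. Below is a minimal, self-contained sketch of the new method; RoleMsg is pared down to the one field the method reads (the real struct in models/models.go carries more), and main is only illustrative:

package main

import "fmt"

// Pared-down stand-ins for the real types in models/models.go.
type RoleMsg struct {
	Role string // the only field MakeStopSlice reads
}

type ChatBody struct {
	Messages []RoleMsg
}

// MakeStopSlice, as added in this commit: collect the distinct roles
// seen in the chat and emit each as a "role:\n" stop string, alongside
// the <|im_end|> ChatML terminator.
func (cb *ChatBody) MakeStopSlice() []string {
	namesMap := make(map[string]struct{})
	for _, m := range cb.Messages {
		namesMap[m.Role] = struct{}{}
	}
	ss := []string{"<|im_end|>"}
	for k := range namesMap {
		ss = append(ss, k+":\n")
	}
	return ss
}

func main() {
	cb := &ChatBody{Messages: []RoleMsg{
		{Role: "narrator"},
		{Role: "user"},
		{Role: "narrator"}, // duplicates collapse via the map
	}}
	fmt.Printf("%q\n", cb.MakeStopSlice())
	// Prints ["<|im_end|>" "narrator:\n" "user:\n"], though the two
	// role entries may appear in either order, since Go randomizes
	// map iteration order.
}

One consequence worth noting: because the roles pass through a map, only the leading "<|im_end|>" entry has a fixed position; the role-derived stop strings come out in arbitrary order. Stop-string matching does not depend on slice order, so this is harmless for both request payloads touched here.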