-rw-r--r--  llm.go           |  8 ++++----
-rw-r--r--  models/models.go | 28 ++++++++++++++++------------
 2 files changed, 20 insertions(+), 16 deletions(-)
diff --git a/llm.go b/llm.go
index 3307467..046d28d 100644
--- a/llm.go
+++ b/llm.go
@@ -2,8 +2,8 @@ package main

import (
"bytes"
- "gf-lt/models"
"encoding/json"
+ "gf-lt/models"
"io"
"strings"
)
@@ -88,10 +88,10 @@ func (lcp LlamaCPPeer) FormMsg(msg, role string, resume bool) (io.Reader, error)
logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse,
"msg", msg, "resume", resume, "prompt", prompt)
var payload any
- payload = models.NewLCPReq(prompt, cfg, defaultLCPProps)
+ payload = models.NewLCPReq(prompt, cfg, defaultLCPProps, chatBody.MakeStopSlice())
if strings.Contains(chatBody.Model, "deepseek") {
payload = models.NewDSCompletionReq(prompt, chatBody.Model,
- defaultLCPProps["temp"], cfg)
+ defaultLCPProps["temp"], cfg, chatBody.MakeStopSlice())
}
data, err := json.Marshal(payload)
if err != nil {
@@ -213,7 +213,7 @@ func (ds DeepSeekerCompletion) FormMsg(msg, role string, resume bool) (io.Reader
logger.Debug("checking prompt for /completion", "tool_use", cfg.ToolUse,
"msg", msg, "resume", resume, "prompt", prompt)
payload := models.NewDSCompletionReq(prompt, chatBody.Model,
- defaultLCPProps["temp"], cfg)
+ defaultLCPProps["temp"], cfg, chatBody.MakeStopSlice())
data, err := json.Marshal(payload)
if err != nil {
logger.Error("failed to form a msg", "error", err)
diff --git a/models/models.go b/models/models.go
index 918e35e..9514741 100644
--- a/models/models.go
+++ b/models/models.go
@@ -97,6 +97,18 @@ func (cb *ChatBody) ListRoles() []string {
return resp
}
+func (cb *ChatBody) MakeStopSlice() []string {
+ namesMap := make(map[string]struct{})
+ for _, m := range cb.Messages {
+ namesMap[m.Role] = struct{}{}
+ }
+ ss := []string{"<|im_end|>"}
+ for k := range namesMap {
+ ss = append(ss, k+":\n")
+ }
+ return ss
+}
+

type ChatToolsBody struct {
Model string `json:"model"`
Messages []RoleMsg `json:"messages"`
@@ -173,7 +185,7 @@ type DSCompletionReq struct {
TopP float32 `json:"top_p"`
}

-func NewDSCompletionReq(prompt, model string, temp float32, cfg *config.Config) DSCompletionReq {
+func NewDSCompletionReq(prompt, model string, temp float32, cfg *config.Config, stopSlice []string) DSCompletionReq {
return DSCompletionReq{
Model: model,
Prompt: prompt,
@@ -184,11 +196,7 @@ func NewDSCompletionReq(prompt, model string, temp float32, cfg *config.Config)
PresencePenalty: 0,
FrequencyPenalty: 0,
TopP: 1.0,
- Stop: []string{
- cfg.UserRole + ":\n", "<|im_end|>",
- cfg.ToolRole + ":\n",
- cfg.AssistantRole + ":\n",
- },
+ Stop: stopSlice,
}
}

@@ -326,7 +334,7 @@ type LlamaCPPReq struct {
// Samplers string `json:"samplers"`
}

-func NewLCPReq(prompt string, cfg *config.Config, props map[string]float32) LlamaCPPReq {
+func NewLCPReq(prompt string, cfg *config.Config, props map[string]float32, stopStrings []string) LlamaCPPReq {
return LlamaCPPReq{
Stream: true,
Prompt: prompt,
@@ -336,11 +344,7 @@ func NewLCPReq(prompt string, cfg *config.Config, props map[string]float32) Llam
DryMultiplier: props["dry_multiplier"],
MinP: props["min_p"],
NPredict: int32(props["n_predict"]),
- Stop: []string{
- cfg.UserRole + ":\n", "<|im_end|>",
- cfg.ToolRole + ":\n",
- cfg.AssistantRole + ":\n",
- },
+ Stop: stopStrings,
}
}
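
For reference, a minimal standalone sketch of the new MakeStopSlice helper and the stop slice it yields. RoleMsg and ChatBody are trimmed here to only the fields the method touches; the real structs in models/models.go carry more fields and JSON tags.

package main

import "fmt"

// RoleMsg and ChatBody are cut down to the fields MakeStopSlice reads.
type RoleMsg struct {
	Role    string
	Content string
}

type ChatBody struct {
	Messages []RoleMsg
}

// MakeStopSlice gathers the distinct roles seen in the chat history and
// turns each into a "role:\n" stop string, alongside the fixed
// <|im_end|> end-of-turn token.
func (cb *ChatBody) MakeStopSlice() []string {
	namesMap := make(map[string]struct{})
	for _, m := range cb.Messages {
		namesMap[m.Role] = struct{}{}
	}
	ss := []string{"<|im_end|>"}
	for k := range namesMap {
		ss = append(ss, k+":\n")
	}
	return ss
}

func main() {
	cb := &ChatBody{Messages: []RoleMsg{
		{Role: "user", Content: "hi"},
		{Role: "assistant", Content: "hello"},
		{Role: "user", Content: "and you are?"},
	}}
	// Roles are deduplicated through the map, so iteration order is
	// unspecified; one possible output:
	// [<|im_end|> user:
	//  assistant:
	// ]
	fmt.Println(cb.MakeStopSlice())
}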
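
And a hedged sketch of where that slice ends up: the stop strings are marshaled into the "stop" field of the request body sent to /completion. LlamaCPPReq is reduced to three fields here; the full struct in models/models.go also carries the sampler settings visible in the diff.

package main

import (
	"encoding/json"
	"fmt"
)

// LlamaCPPReq is trimmed to the fields needed to show the wire format.
type LlamaCPPReq struct {
	Stream bool     `json:"stream"`
	Prompt string   `json:"prompt"`
	Stop   []string `json:"stop"`
}

func main() {
	req := LlamaCPPReq{
		Stream: true,
		Prompt: "user:\nhi\nassistant:\n",
		Stop:   []string{"<|im_end|>", "user:\n", "assistant:\n"},
	}
	data, err := json.Marshal(req)
	if err != nil {
		panic(err)
	}
	// Prints:
	// {"stream":true,"prompt":"user:\nhi\nassistant:\n","stop":["<|im_end|>","user:\n","assistant:\n"]}
	fmt.Println(string(data))
}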