summaryrefslogtreecommitdiff
path: root/models/models.go
diff options
context:
space:
mode:
Diffstat (limited to 'models/models.go')
-rw-r--r--models/models.go28
1 files changed, 16 insertions, 12 deletions
diff --git a/models/models.go b/models/models.go
index 918e35e..9514741 100644
--- a/models/models.go
+++ b/models/models.go
@@ -97,6 +97,18 @@ func (cb *ChatBody) ListRoles() []string {
return resp
}
+func (cb *ChatBody) MakeStopSlice() []string {
+ namesMap := make(map[string]struct{})
+ for _, m := range cb.Messages {
+ namesMap[m.Role] = struct{}{}
+ }
+ ss := []string{"<|im_end|>"}
+ for k := range namesMap {
+ ss = append(ss, k+":\n")
+ }
+ return ss
+}
+
type ChatToolsBody struct {
Model string `json:"model"`
Messages []RoleMsg `json:"messages"`
@@ -173,7 +185,7 @@ type DSCompletionReq struct {
TopP float32 `json:"top_p"`
}
-func NewDSCompletionReq(prompt, model string, temp float32, cfg *config.Config) DSCompletionReq {
+func NewDSCompletionReq(prompt, model string, temp float32, cfg *config.Config, stopSlice []string) DSCompletionReq {
return DSCompletionReq{
Model: model,
Prompt: prompt,
@@ -184,11 +196,7 @@ func NewDSCompletionReq(prompt, model string, temp float32, cfg *config.Config)
PresencePenalty: 0,
FrequencyPenalty: 0,
TopP: 1.0,
- Stop: []string{
- cfg.UserRole + ":\n", "<|im_end|>",
- cfg.ToolRole + ":\n",
- cfg.AssistantRole + ":\n",
- },
+ Stop: stopSlice,
}
}
@@ -326,7 +334,7 @@ type LlamaCPPReq struct {
// Samplers string `json:"samplers"`
}
-func NewLCPReq(prompt string, cfg *config.Config, props map[string]float32) LlamaCPPReq {
+func NewLCPReq(prompt string, cfg *config.Config, props map[string]float32, stopStrings []string) LlamaCPPReq {
return LlamaCPPReq{
Stream: true,
Prompt: prompt,
@@ -336,11 +344,7 @@ func NewLCPReq(prompt string, cfg *config.Config, props map[string]float32) Llam
DryMultiplier: props["dry_multiplier"],
MinP: props["min_p"],
NPredict: int32(props["n_predict"]),
- Stop: []string{
- cfg.UserRole + ":\n", "<|im_end|>",
- cfg.ToolRole + ":\n",
- cfg.AssistantRole + ":\n",
- },
+ Stop: stopStrings,
}
}