diff options
Diffstat (limited to 'models/models.go')
-rw-r--r-- | models/models.go | 79 |
1 files changed, 63 insertions, 16 deletions
diff --git a/models/models.go b/models/models.go index 67cff0c..526d056 100644 --- a/models/models.go +++ b/models/models.go @@ -57,28 +57,33 @@ type RoleMsg struct { } func (m RoleMsg) ToText(i int, cfg *config.Config) string { - icon := "" - switch m.Role { - case "assistant": - icon = fmt.Sprintf("(%d) %s", i, cfg.AssistantIcon) - case "user": - icon = fmt.Sprintf("(%d) %s", i, cfg.UserIcon) - case "system": - icon = fmt.Sprintf("(%d) <system>: ", i) - case "tool": - icon = fmt.Sprintf("(%d) %s", i, cfg.ToolIcon) - default: - icon = fmt.Sprintf("(%d) <%s>: ", i, m.Role) + icon := fmt.Sprintf("(%d)", i) + if !strings.HasPrefix(m.Content, cfg.UserRole+":") && !strings.HasPrefix(m.Content, cfg.AssistantRole+":") { + switch m.Role { + case "assistant": + icon = fmt.Sprintf("(%d) %s", i, cfg.AssistantIcon) + case "user": + icon = fmt.Sprintf("(%d) %s", i, cfg.UserIcon) + case "system": + icon = fmt.Sprintf("(%d) <system>: ", i) + case "tool": + icon = fmt.Sprintf("(%d) %s", i, cfg.ToolIcon) + default: + icon = fmt.Sprintf("(%d) <%s>: ", i, m.Role) + } } textMsg := fmt.Sprintf("[-:-:b]%s[-:-:-]\n%s\n", icon, m.Content) return strings.ReplaceAll(textMsg, "\n\n", "\n") } +func (m RoleMsg) ToPrompt() string { + return strings.ReplaceAll(fmt.Sprintf("%s:\n%s", m.Role, m.Content), "\n\n", "\n") +} + type ChatBody struct { - Model string `json:"model"` - Stream bool `json:"stream"` - Messages []RoleMsg `json:"messages"` - DRYMultiplier float32 `json:"frequency_penalty"` + Model string `json:"model"` + Stream bool `json:"stream"` + Messages []RoleMsg `json:"messages"` } type ChatToolsBody struct { @@ -144,3 +149,45 @@ type LLMModels struct { } `json:"meta"` } `json:"data"` } + +type LlamaCPPReq struct { + Stream bool `json:"stream"` + // Messages []RoleMsg `json:"messages"` + Prompt string `json:"prompt"` + Temperature float32 `json:"temperature"` + DryMultiplier float32 `json:"dry_multiplier"` + Stop []string `json:"stop"` + // MaxTokens int `json:"max_tokens"` + // DryBase float64 `json:"dry_base"` + // DryAllowedLength int `json:"dry_allowed_length"` + // DryPenaltyLastN int `json:"dry_penalty_last_n"` + // CachePrompt bool `json:"cache_prompt"` + // DynatempRange int `json:"dynatemp_range"` + // DynatempExponent int `json:"dynatemp_exponent"` + // TopK int `json:"top_k"` + // TopP float32 `json:"top_p"` + // MinP float32 `json:"min_p"` + // TypicalP int `json:"typical_p"` + // XtcProbability int `json:"xtc_probability"` + // XtcThreshold float32 `json:"xtc_threshold"` + // RepeatLastN int `json:"repeat_last_n"` + // RepeatPenalty int `json:"repeat_penalty"` + // PresencePenalty int `json:"presence_penalty"` + // FrequencyPenalty int `json:"frequency_penalty"` + // Samplers string `json:"samplers"` +} + +func NewLCPReq(prompt, role string) LlamaCPPReq { + return LlamaCPPReq{ + Stream: true, + Prompt: prompt, + Temperature: 0.8, + DryMultiplier: 0.5, + Stop: []string{role + ":\n", "<|im_end|>"}, + } +} + +type LlamaCPPResp struct { + Content string `json:"content"` + Stop bool `json:"stop"` +} |