Diffstat (limited to 'models/models.go')
-rw-r--r--  models/models.go  79
1 file changed, 63 insertions(+), 16 deletions(-)
diff --git a/models/models.go b/models/models.go
index 67cff0c..526d056 100644
--- a/models/models.go
+++ b/models/models.go
@@ -57,28 +57,33 @@ type RoleMsg struct {
 }
 
 func (m RoleMsg) ToText(i int, cfg *config.Config) string {
-	icon := ""
-	switch m.Role {
-	case "assistant":
-		icon = fmt.Sprintf("(%d) %s", i, cfg.AssistantIcon)
-	case "user":
-		icon = fmt.Sprintf("(%d) %s", i, cfg.UserIcon)
-	case "system":
-		icon = fmt.Sprintf("(%d) <system>: ", i)
-	case "tool":
-		icon = fmt.Sprintf("(%d) %s", i, cfg.ToolIcon)
-	default:
-		icon = fmt.Sprintf("(%d) <%s>: ", i, m.Role)
+	icon := fmt.Sprintf("(%d)", i)
+	if !strings.HasPrefix(m.Content, cfg.UserRole+":") && !strings.HasPrefix(m.Content, cfg.AssistantRole+":") {
+		switch m.Role {
+		case "assistant":
+			icon = fmt.Sprintf("(%d) %s", i, cfg.AssistantIcon)
+		case "user":
+			icon = fmt.Sprintf("(%d) %s", i, cfg.UserIcon)
+		case "system":
+			icon = fmt.Sprintf("(%d) <system>: ", i)
+		case "tool":
+			icon = fmt.Sprintf("(%d) %s", i, cfg.ToolIcon)
+		default:
+			icon = fmt.Sprintf("(%d) <%s>: ", i, m.Role)
+		}
 	}
 	textMsg := fmt.Sprintf("[-:-:b]%s[-:-:-]\n%s\n", icon, m.Content)
 	return strings.ReplaceAll(textMsg, "\n\n", "\n")
 }
 
+func (m RoleMsg) ToPrompt() string {
+	return strings.ReplaceAll(fmt.Sprintf("%s:\n%s", m.Role, m.Content), "\n\n", "\n")
+}
+
 type ChatBody struct {
-	Model         string    `json:"model"`
-	Stream        bool      `json:"stream"`
-	Messages      []RoleMsg `json:"messages"`
-	DRYMultiplier float32   `json:"frequency_penalty"`
+	Model    string    `json:"model"`
+	Stream   bool      `json:"stream"`
+	Messages []RoleMsg `json:"messages"`
 }
 
 type ChatToolsBody struct {
@@ -144,3 +149,45 @@ type LLMModels struct {
 		} `json:"meta"`
 	} `json:"data"`
 }
+
+type LlamaCPPReq struct {
+	Stream bool `json:"stream"`
+	// Messages []RoleMsg `json:"messages"`
+	Prompt        string   `json:"prompt"`
+	Temperature   float32  `json:"temperature"`
+	DryMultiplier float32  `json:"dry_multiplier"`
+	Stop          []string `json:"stop"`
+	// MaxTokens int `json:"max_tokens"`
+	// DryBase float64 `json:"dry_base"`
+	// DryAllowedLength int `json:"dry_allowed_length"`
+	// DryPenaltyLastN int `json:"dry_penalty_last_n"`
+	// CachePrompt bool `json:"cache_prompt"`
+	// DynatempRange int `json:"dynatemp_range"`
+	// DynatempExponent int `json:"dynatemp_exponent"`
+	// TopK int `json:"top_k"`
+	// TopP float32 `json:"top_p"`
+	// MinP float32 `json:"min_p"`
+	// TypicalP int `json:"typical_p"`
+	// XtcProbability int `json:"xtc_probability"`
+	// XtcThreshold float32 `json:"xtc_threshold"`
+	// RepeatLastN int `json:"repeat_last_n"`
+	// RepeatPenalty int `json:"repeat_penalty"`
+	// PresencePenalty int `json:"presence_penalty"`
+	// FrequencyPenalty int `json:"frequency_penalty"`
+	// Samplers string `json:"samplers"`
+}
+
+func NewLCPReq(prompt, role string) LlamaCPPReq {
+	return LlamaCPPReq{
+		Stream: true,
+		Prompt: prompt,
+		Temperature: 0.8,
+		DryMultiplier: 0.5,
+		Stop: []string{role + ":\n", "<|im_end|>"},
+	}
+}
+
+type LlamaCPPResp struct {
+	Content string `json:"content"`
+	Stop    bool   `json:"stop"`
+}