summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--models/card.go18
-rw-r--r--models/db.go2
-rw-r--r--models/models.go146
-rw-r--r--models/openrouter.go11
-rw-r--r--models/openrouter_test.go97
5 files changed, 207 insertions, 67 deletions
diff --git a/models/card.go b/models/card.go
index adfb030..9bf6665 100644
--- a/models/card.go
+++ b/models/card.go
@@ -31,18 +31,20 @@ func (c *CharCardSpec) Simplify(userName, fpath string) *CharCard {
fm := strings.ReplaceAll(strings.ReplaceAll(c.FirstMes, "{{char}}", c.Name), "{{user}}", userName)
sysPr := strings.ReplaceAll(strings.ReplaceAll(c.Description, "{{char}}", c.Name), "{{user}}", userName)
return &CharCard{
- SysPrompt: sysPr,
- FirstMsg: fm,
- Role: c.Name,
- FilePath: fpath,
+ SysPrompt: sysPr,
+ FirstMsg: fm,
+ Role: c.Name,
+ FilePath: fpath,
+ Characters: []string{c.Name, userName},
}
}
type CharCard struct {
- SysPrompt string `json:"sys_prompt"`
- FirstMsg string `json:"first_msg"`
- Role string `json:"role"`
- FilePath string `json:"filepath"`
+ SysPrompt string `json:"sys_prompt"`
+ FirstMsg string `json:"first_msg"`
+ Role string `json:"role"`
+ Characters []string `json:"chars"`
+ FilePath string `json:"filepath"`
}
func (cc *CharCard) ToSpec(userName string) *CharCardSpec {
diff --git a/models/db.go b/models/db.go
index 090f46d..73a0b53 100644
--- a/models/db.go
+++ b/models/db.go
@@ -14,7 +14,7 @@ type Chat struct {
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
-func (c Chat) ToHistory() ([]RoleMsg, error) {
+func (c *Chat) ToHistory() ([]RoleMsg, error) {
resp := []RoleMsg{}
if err := json.Unmarshal([]byte(c.Msgs), &resp); err != nil {
return nil, err
diff --git a/models/models.go b/models/models.go
index 912f72b..d15e0d1 100644
--- a/models/models.go
+++ b/models/models.go
@@ -89,37 +89,42 @@ type ImageContentPart struct {
// RoleMsg represents a message with content that can be either a simple string or structured content parts
type RoleMsg struct {
- Role string `json:"role"`
- Content string `json:"-"`
- ContentParts []interface{} `json:"-"`
- ToolCallID string `json:"tool_call_id,omitempty"` // For tool response messages
- hasContentParts bool // Flag to indicate which content type to marshal
+ Role string `json:"role"`
+ Content string `json:"-"`
+ ContentParts []any `json:"-"`
+ ToolCallID string `json:"tool_call_id,omitempty"` // For tool response messages
+ KnownTo []string `json:"known_to,omitempty"`
+ hasContentParts bool // Flag to indicate which content type to marshal
}
// MarshalJSON implements custom JSON marshaling for RoleMsg
-func (m RoleMsg) MarshalJSON() ([]byte, error) {
+func (m *RoleMsg) MarshalJSON() ([]byte, error) {
if m.hasContentParts {
// Use structured content format
aux := struct {
- Role string `json:"role"`
- Content []interface{} `json:"content"`
- ToolCallID string `json:"tool_call_id,omitempty"`
+ Role string `json:"role"`
+ Content []any `json:"content"`
+ ToolCallID string `json:"tool_call_id,omitempty"`
+ KnownTo []string `json:"known_to,omitempty"`
}{
Role: m.Role,
Content: m.ContentParts,
ToolCallID: m.ToolCallID,
+ KnownTo: m.KnownTo,
}
return json.Marshal(aux)
} else {
// Use simple content format
aux := struct {
- Role string `json:"role"`
- Content string `json:"content"`
- ToolCallID string `json:"tool_call_id,omitempty"`
+ Role string `json:"role"`
+ Content string `json:"content"`
+ ToolCallID string `json:"tool_call_id,omitempty"`
+ KnownTo []string `json:"known_to,omitempty"`
}{
Role: m.Role,
Content: m.Content,
ToolCallID: m.ToolCallID,
+ KnownTo: m.KnownTo,
}
return json.Marshal(aux)
}
@@ -129,23 +134,26 @@ func (m RoleMsg) MarshalJSON() ([]byte, error) {
func (m *RoleMsg) UnmarshalJSON(data []byte) error {
// First, try to unmarshal as structured content format
var structured struct {
- Role string `json:"role"`
- Content []interface{} `json:"content"`
- ToolCallID string `json:"tool_call_id,omitempty"`
+ Role string `json:"role"`
+ Content []any `json:"content"`
+ ToolCallID string `json:"tool_call_id,omitempty"`
+ KnownTo []string `json:"known_to,omitempty"`
}
if err := json.Unmarshal(data, &structured); err == nil && len(structured.Content) > 0 {
m.Role = structured.Role
m.ContentParts = structured.Content
m.ToolCallID = structured.ToolCallID
+ m.KnownTo = structured.KnownTo
m.hasContentParts = true
return nil
}
// Otherwise, unmarshal as simple content format
var simple struct {
- Role string `json:"role"`
- Content string `json:"content"`
- ToolCallID string `json:"tool_call_id,omitempty"`
+ Role string `json:"role"`
+ Content string `json:"content"`
+ ToolCallID string `json:"tool_call_id,omitempty"`
+ KnownTo []string `json:"known_to,omitempty"`
}
if err := json.Unmarshal(data, &simple); err != nil {
return err
@@ -153,22 +161,21 @@ func (m *RoleMsg) UnmarshalJSON(data []byte) error {
m.Role = simple.Role
m.Content = simple.Content
m.ToolCallID = simple.ToolCallID
+ m.KnownTo = simple.KnownTo
m.hasContentParts = false
return nil
}
-func (m RoleMsg) ToText(i int) string {
- icon := fmt.Sprintf("(%d)", i)
-
+func (m *RoleMsg) ToText(i int) string {
// Convert content to string representation
- contentStr := ""
+ var contentStr string
if !m.hasContentParts {
contentStr = m.Content
} else {
// For structured content, just take the text parts
var textParts []string
for _, part := range m.ContentParts {
- if partMap, ok := part.(map[string]interface{}); ok {
+ if partMap, ok := part.(map[string]any); ok {
if partType, exists := partMap["type"]; exists && partType == "text" {
if textVal, textExists := partMap["text"]; textExists {
if textStr, isStr := textVal.(string); isStr {
@@ -180,24 +187,26 @@ func (m RoleMsg) ToText(i int) string {
}
contentStr = strings.Join(textParts, " ") + " "
}
-
// check if already has role annotation (/completion makes them)
- if !strings.HasPrefix(contentStr, m.Role+":") {
- icon = fmt.Sprintf("(%d) <%s>: ", i, m.Role)
- }
+ // in that case remove it, and then add to icon
+ // since icon and content are separated by \n
+ contentStr, _ = strings.CutPrefix(contentStr, m.Role+":")
+ // if !strings.HasPrefix(contentStr, m.Role+":") {
+ icon := fmt.Sprintf("(%d) <%s>: ", i, m.Role)
+ // }
textMsg := fmt.Sprintf("[-:-:b]%s[-:-:-]\n%s\n", icon, contentStr)
return strings.ReplaceAll(textMsg, "\n\n", "\n")
}
-func (m RoleMsg) ToPrompt() string {
- contentStr := ""
+func (m *RoleMsg) ToPrompt() string {
+ var contentStr string
if !m.hasContentParts {
contentStr = m.Content
} else {
// For structured content, just take the text parts
var textParts []string
for _, part := range m.ContentParts {
- if partMap, ok := part.(map[string]interface{}); ok {
+ if partMap, ok := part.(map[string]any); ok {
if partType, exists := partMap["type"]; exists && partType == "text" {
if textVal, textExists := partMap["text"]; textExists {
if textStr, isStr := textVal.(string); isStr {
@@ -222,7 +231,7 @@ func NewRoleMsg(role, content string) RoleMsg {
}
// NewMultimodalMsg creates a RoleMsg with structured content parts (text and images)
-func NewMultimodalMsg(role string, contentParts []interface{}) RoleMsg {
+func NewMultimodalMsg(role string, contentParts []any) RoleMsg {
return RoleMsg{
Role: role,
ContentParts: contentParts,
@@ -231,7 +240,7 @@ func NewMultimodalMsg(role string, contentParts []interface{}) RoleMsg {
}
// HasContent returns true if the message has either string content or structured content parts
-func (m RoleMsg) HasContent() bool {
+func (m *RoleMsg) HasContent() bool {
if m.Content != "" {
return true
}
@@ -242,22 +251,23 @@ func (m RoleMsg) HasContent() bool {
}
// IsContentParts returns true if the message uses structured content parts
-func (m RoleMsg) IsContentParts() bool {
+func (m *RoleMsg) IsContentParts() bool {
return m.hasContentParts
}
// GetContentParts returns the content parts of the message
-func (m RoleMsg) GetContentParts() []interface{} {
+func (m *RoleMsg) GetContentParts() []any {
return m.ContentParts
}
// Copy creates a copy of the RoleMsg with all fields
-func (m RoleMsg) Copy() RoleMsg {
+func (m *RoleMsg) Copy() RoleMsg {
return RoleMsg{
Role: m.Role,
Content: m.Content,
ContentParts: m.ContentParts,
ToolCallID: m.ToolCallID,
+ KnownTo: m.KnownTo,
hasContentParts: m.hasContentParts,
}
}
@@ -267,9 +277,9 @@ func (m *RoleMsg) AddTextPart(text string) {
if !m.hasContentParts {
// Convert to content parts format
if m.Content != "" {
- m.ContentParts = []interface{}{TextContentPart{Type: "text", Text: m.Content}}
+ m.ContentParts = []any{TextContentPart{Type: "text", Text: m.Content}}
} else {
- m.ContentParts = []interface{}{}
+ m.ContentParts = []any{}
}
m.hasContentParts = true
}
@@ -283,9 +293,9 @@ func (m *RoleMsg) AddImagePart(imageURL string) {
if !m.hasContentParts {
// Convert to content parts format
if m.Content != "" {
- m.ContentParts = []interface{}{TextContentPart{Type: "text", Text: m.Content}}
+ m.ContentParts = []any{TextContentPart{Type: "text", Text: m.Content}}
} else {
- m.ContentParts = []interface{}{}
+ m.ContentParts = []any{}
}
m.hasContentParts = true
}
@@ -359,13 +369,27 @@ func (cb *ChatBody) ListRoles() []string {
}
func (cb *ChatBody) MakeStopSlice() []string {
- namesMap := make(map[string]struct{})
- for _, m := range cb.Messages {
- namesMap[m.Role] = struct{}{}
- }
- ss := []string{"<|im_end|>"}
- for k := range namesMap {
- ss = append(ss, k+":\n")
+ return cb.MakeStopSliceExcluding("", cb.ListRoles())
+}
+
+func (cb *ChatBody) MakeStopSliceExcluding(
+ excludeRole string, roleList []string,
+) []string {
+ ss := []string{}
+ for _, role := range roleList {
+ // Skip the excluded role (typically the current speaker)
+ if role == excludeRole {
+ continue
+ }
+ // Add multiple variations to catch different formatting
+ ss = append(ss,
+ role+":\n", // Most common: role with newline
+ role+":", // Role with colon but no newline
+ role+": ", // Role with colon and single space
+ role+": ", // Role with colon and double space (common tokenization)
+ role+": \n", // Role with colon and double space (common tokenization)
+ role+": ", // Role with colon and triple space
+ )
}
return ss
}
@@ -443,12 +467,12 @@ type LlamaCPPReq struct {
Stream bool `json:"stream"`
// For multimodal requests, prompt should be an object with prompt_string and multimodal_data
// For regular requests, prompt is a string
- Prompt interface{} `json:"prompt"` // Can be string or object with prompt_string and multimodal_data
- Temperature float32 `json:"temperature"`
- DryMultiplier float32 `json:"dry_multiplier"`
- Stop []string `json:"stop"`
- MinP float32 `json:"min_p"`
- NPredict int32 `json:"n_predict"`
+ Prompt any `json:"prompt"` // Can be string or object with prompt_string and multimodal_data
+ Temperature float32 `json:"temperature"`
+ DryMultiplier float32 `json:"dry_multiplier"`
+ Stop []string `json:"stop"`
+ MinP float32 `json:"min_p"`
+ NPredict int32 `json:"n_predict"`
// MaxTokens int `json:"max_tokens"`
// DryBase float64 `json:"dry_base"`
// DryAllowedLength int `json:"dry_allowed_length"`
@@ -476,7 +500,7 @@ type PromptObject struct {
}
func NewLCPReq(prompt, model string, multimodalData []string, props map[string]float32, stopStrings []string) LlamaCPPReq {
- var finalPrompt interface{}
+ var finalPrompt any
if len(multimodalData) > 0 {
// When multimodal data is present, use the object format as per Python example:
// { "prompt": { "prompt_string": "...", "multimodal_data": [...] } }
@@ -523,9 +547,23 @@ type LCPModels struct {
}
func (lcp *LCPModels) ListModels() []string {
- resp := []string{}
+ resp := make([]string, 0, len(lcp.Data))
for _, model := range lcp.Data {
resp = append(resp, model.ID)
}
return resp
}
+
+type ChatRoundReq struct {
+ UserMsg string
+ Role string
+ Regen bool
+ Resume bool
+}
+
+type APIType int
+
+const (
+ APITypeChat APIType = iota
+ APITypeCompletion
+)
diff --git a/models/openrouter.go b/models/openrouter.go
index 50f26b6..6196498 100644
--- a/models/openrouter.go
+++ b/models/openrouter.go
@@ -143,11 +143,14 @@ type ORModels struct {
func (orm *ORModels) ListModels(free bool) []string {
resp := []string{}
- for _, model := range orm.Data {
+ for i := range orm.Data {
+ model := &orm.Data[i] // Take address of element to avoid copying
if free {
- if model.Pricing.Prompt == "0" && model.Pricing.Request == "0" &&
- model.Pricing.Completion == "0" {
- resp = append(resp, model.ID)
+ if model.Pricing.Prompt == "0" && model.Pricing.Completion == "0" {
+ // treat missing request as free
+ if model.Pricing.Request == "" || model.Pricing.Request == "0" {
+ resp = append(resp, model.ID)
+ }
}
} else {
resp = append(resp, model.ID)
diff --git a/models/openrouter_test.go b/models/openrouter_test.go
new file mode 100644
index 0000000..dd38d23
--- /dev/null
+++ b/models/openrouter_test.go
@@ -0,0 +1,97 @@
+package models
+
+import (
+ "encoding/json"
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+func TestORModelsListModels(t *testing.T) {
+ t.Run("unit test with hardcoded data", func(t *testing.T) {
+ jsonData := `{
+ "data": [
+ {
+ "id": "model/free",
+ "pricing": {
+ "prompt": "0",
+ "completion": "0"
+ }
+ },
+ {
+ "id": "model/paid",
+ "pricing": {
+ "prompt": "0.001",
+ "completion": "0.002"
+ }
+ },
+ {
+ "id": "model/request-zero",
+ "pricing": {
+ "prompt": "0",
+ "completion": "0",
+ "request": "0"
+ }
+ },
+ {
+ "id": "model/request-nonzero",
+ "pricing": {
+ "prompt": "0",
+ "completion": "0",
+ "request": "0.5"
+ }
+ }
+ ]
+ }`
+ var models ORModels
+ if err := json.Unmarshal([]byte(jsonData), &models); err != nil {
+ t.Fatalf("failed to unmarshal test data: %v", err)
+ }
+ freeModels := models.ListModels(true)
+ if len(freeModels) != 2 {
+ t.Errorf("expected 2 free models, got %d: %v", len(freeModels), freeModels)
+ }
+ expectedFree := map[string]bool{"model/free": true, "model/request-zero": true}
+ for _, id := range freeModels {
+ if !expectedFree[id] {
+ t.Errorf("unexpected free model ID: %s", id)
+ }
+ }
+ allModels := models.ListModels(false)
+ if len(allModels) != 4 {
+ t.Errorf("expected 4 total models, got %d", len(allModels))
+ }
+ })
+
+ t.Run("integration with or_models.json", func(t *testing.T) {
+ // Attempt to load the real data file from the project root
+ path := filepath.Join("..", "or_models.json")
+ data, err := os.ReadFile(path)
+ if err != nil {
+ t.Skip("or_models.json not found, skipping integration test")
+ }
+ var models ORModels
+ if err := json.Unmarshal(data, &models); err != nil {
+ t.Fatalf("failed to unmarshal %s: %v", path, err)
+ }
+ freeModels := models.ListModels(true)
+ if len(freeModels) == 0 {
+ t.Error("expected at least one free model, got none")
+ }
+ allModels := models.ListModels(false)
+ if len(allModels) == 0 {
+ t.Error("expected at least one model")
+ }
+ // Ensure free models are subset of all models
+ freeSet := make(map[string]bool)
+ for _, id := range freeModels {
+ freeSet[id] = true
+ }
+ for _, id := range freeModels {
+ if !freeSet[id] {
+ t.Errorf("free model %s not found in all models", id)
+ }
+ }
+ t.Logf("found %d free models out of %d total models", len(freeModels), len(allModels))
+ })
+} \ No newline at end of file