summaryrefslogtreecommitdiff
path: root/bot_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'bot_test.go')
-rw-r--r--bot_test.go289
1 files changed, 289 insertions, 0 deletions
diff --git a/bot_test.go b/bot_test.go
new file mode 100644
index 0000000..d2956a9
--- /dev/null
+++ b/bot_test.go
@@ -0,0 +1,289 @@
+package main
+
+import (
+ "gf-lt/config"
+ "gf-lt/models"
+ "reflect"
+ "testing"
+)
+
+func TestConsolidateConsecutiveAssistantMessages(t *testing.T) {
+ // Mock config for testing
+ testCfg := &config.Config{
+ AssistantRole: "assistant",
+ WriteNextMsgAsCompletionAgent: "",
+ }
+ cfg = testCfg
+
+ tests := []struct {
+ name string
+ input []models.RoleMsg
+ expected []models.RoleMsg
+ }{
+ {
+ name: "no consecutive assistant messages",
+ input: []models.RoleMsg{
+ {Role: "user", Content: "Hello"},
+ {Role: "assistant", Content: "Hi there"},
+ {Role: "user", Content: "How are you?"},
+ },
+ expected: []models.RoleMsg{
+ {Role: "user", Content: "Hello"},
+ {Role: "assistant", Content: "Hi there"},
+ {Role: "user", Content: "How are you?"},
+ },
+ },
+ {
+ name: "consecutive assistant messages should be consolidated",
+ input: []models.RoleMsg{
+ {Role: "user", Content: "Hello"},
+ {Role: "assistant", Content: "First part"},
+ {Role: "assistant", Content: "Second part"},
+ {Role: "user", Content: "Thanks"},
+ },
+ expected: []models.RoleMsg{
+ {Role: "user", Content: "Hello"},
+ {Role: "assistant", Content: "First part\nSecond part"},
+ {Role: "user", Content: "Thanks"},
+ },
+ },
+ {
+ name: "multiple sets of consecutive assistant messages",
+ input: []models.RoleMsg{
+ {Role: "user", Content: "First question"},
+ {Role: "assistant", Content: "First answer part 1"},
+ {Role: "assistant", Content: "First answer part 2"},
+ {Role: "user", Content: "Second question"},
+ {Role: "assistant", Content: "Second answer part 1"},
+ {Role: "assistant", Content: "Second answer part 2"},
+ {Role: "assistant", Content: "Second answer part 3"},
+ },
+ expected: []models.RoleMsg{
+ {Role: "user", Content: "First question"},
+ {Role: "assistant", Content: "First answer part 1\nFirst answer part 2"},
+ {Role: "user", Content: "Second question"},
+ {Role: "assistant", Content: "Second answer part 1\nSecond answer part 2\nSecond answer part 3"},
+ },
+ },
+ {
+ name: "single assistant message (no consolidation needed)",
+ input: []models.RoleMsg{
+ {Role: "user", Content: "Hello"},
+ {Role: "assistant", Content: "Hi there"},
+ },
+ expected: []models.RoleMsg{
+ {Role: "user", Content: "Hello"},
+ {Role: "assistant", Content: "Hi there"},
+ },
+ },
+ {
+ name: "only assistant messages",
+ input: []models.RoleMsg{
+ {Role: "assistant", Content: "First"},
+ {Role: "assistant", Content: "Second"},
+ {Role: "assistant", Content: "Third"},
+ },
+ expected: []models.RoleMsg{
+ {Role: "assistant", Content: "First\nSecond\nThird"},
+ },
+ },
+ {
+ name: "user messages at the end are preserved",
+ input: []models.RoleMsg{
+ {Role: "assistant", Content: "First"},
+ {Role: "assistant", Content: "Second"},
+ {Role: "user", Content: "Final user message"},
+ },
+ expected: []models.RoleMsg{
+ {Role: "assistant", Content: "First\nSecond"},
+ {Role: "user", Content: "Final user message"},
+ },
+ },
+ {
+ name: "tool call ids preserved in consolidation",
+ input: []models.RoleMsg{
+ {Role: "user", Content: "Hello"},
+ {Role: "assistant", Content: "First part", ToolCallID: "call_123"},
+ {Role: "assistant", Content: "Second part", ToolCallID: "call_123"}, // Same ID
+ {Role: "user", Content: "Thanks"},
+ },
+ expected: []models.RoleMsg{
+ {Role: "user", Content: "Hello"},
+ {Role: "assistant", Content: "First part\nSecond part", ToolCallID: "call_123"},
+ {Role: "user", Content: "Thanks"},
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := consolidateConsecutiveAssistantMessages(tt.input)
+
+ if len(result) != len(tt.expected) {
+ t.Errorf("Expected %d messages, got %d", len(tt.expected), len(result))
+ t.Logf("Result: %+v", result)
+ t.Logf("Expected: %+v", tt.expected)
+ return
+ }
+
+ for i, expectedMsg := range tt.expected {
+ if i >= len(result) {
+ t.Errorf("Result has fewer messages than expected at index %d", i)
+ continue
+ }
+
+ actualMsg := result[i]
+ if actualMsg.Role != expectedMsg.Role {
+ t.Errorf("Message %d: expected role '%s', got '%s'", i, expectedMsg.Role, actualMsg.Role)
+ }
+
+ if actualMsg.Content != expectedMsg.Content {
+ t.Errorf("Message %d: expected content '%s', got '%s'", i, expectedMsg.Content, actualMsg.Content)
+ }
+
+ if actualMsg.ToolCallID != expectedMsg.ToolCallID {
+ t.Errorf("Message %d: expected ToolCallID '%s', got '%s'", i, expectedMsg.ToolCallID, actualMsg.ToolCallID)
+ }
+ }
+
+ // Additional check: ensure no messages were lost
+ if !reflect.DeepEqual(result, tt.expected) {
+ t.Errorf("Result does not match expected:\nResult: %+v\nExpected: %+v", result, tt.expected)
+ }
+ })
+ }
+}
+
+func TestUnmarshalFuncCall(t *testing.T) {
+ tests := []struct {
+ name string
+ jsonStr string
+ want *models.FuncCall
+ wantErr bool
+ }{
+ {
+ name: "simple websearch with numeric limit",
+ jsonStr: `{"name": "websearch", "args": {"query": "current weather in London", "limit": 3}}`,
+ want: &models.FuncCall{
+ Name: "websearch",
+ Args: map[string]string{"query": "current weather in London", "limit": "3"},
+ },
+ wantErr: false,
+ },
+ {
+ name: "string limit",
+ jsonStr: `{"name": "websearch", "args": {"query": "test", "limit": "5"}}`,
+ want: &models.FuncCall{
+ Name: "websearch",
+ Args: map[string]string{"query": "test", "limit": "5"},
+ },
+ wantErr: false,
+ },
+ {
+ name: "boolean arg",
+ jsonStr: `{"name": "test", "args": {"flag": true}}`,
+ want: &models.FuncCall{
+ Name: "test",
+ Args: map[string]string{"flag": "true"},
+ },
+ wantErr: false,
+ },
+ {
+ name: "null arg",
+ jsonStr: `{"name": "test", "args": {"opt": null}}`,
+ want: &models.FuncCall{
+ Name: "test",
+ Args: map[string]string{"opt": ""},
+ },
+ wantErr: false,
+ },
+ {
+ name: "float arg",
+ jsonStr: `{"name": "test", "args": {"ratio": 0.5}}`,
+ want: &models.FuncCall{
+ Name: "test",
+ Args: map[string]string{"ratio": "0.5"},
+ },
+ wantErr: false,
+ },
+ {
+ name: "invalid JSON",
+ jsonStr: `{invalid}`,
+ want: nil,
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := unmarshalFuncCall(tt.jsonStr)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("unmarshalFuncCall() error = %v, wantErr %v", err, tt.wantErr)
+ return
+ }
+ if tt.wantErr {
+ return
+ }
+ if got.Name != tt.want.Name {
+ t.Errorf("unmarshalFuncCall() name = %v, want %v", got.Name, tt.want.Name)
+ }
+ if len(got.Args) != len(tt.want.Args) {
+ t.Errorf("unmarshalFuncCall() args length = %v, want %v", len(got.Args), len(tt.want.Args))
+ }
+ for k, v := range tt.want.Args {
+ if got.Args[k] != v {
+ t.Errorf("unmarshalFuncCall() args[%v] = %v, want %v", k, got.Args[k], v)
+ }
+ }
+ })
+ }
+}
+
+func TestConvertJSONToMapStringString(t *testing.T) {
+ tests := []struct {
+ name string
+ jsonStr string
+ want map[string]string
+ wantErr bool
+ }{
+ {
+ name: "simple map",
+ jsonStr: `{"query": "weather", "limit": 5}`,
+ want: map[string]string{"query": "weather", "limit": "5"},
+ wantErr: false,
+ },
+ {
+ name: "boolean and null",
+ jsonStr: `{"flag": true, "opt": null}`,
+ want: map[string]string{"flag": "true", "opt": ""},
+ wantErr: false,
+ },
+ {
+ name: "invalid JSON",
+ jsonStr: `{invalid`,
+ want: nil,
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := convertJSONToMapStringString(tt.jsonStr)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("convertJSONToMapStringString() error = %v, wantErr %v", err, tt.wantErr)
+ return
+ }
+ if tt.wantErr {
+ return
+ }
+ if len(got) != len(tt.want) {
+ t.Errorf("convertJSONToMapStringString() length = %v, want %v", len(got), len(tt.want))
+ }
+ for k, v := range tt.want {
+ if got[k] != v {
+ t.Errorf("convertJSONToMapStringString()[%v] = %v, want %v", k, got[k], v)
+ }
+ }
+ })
+ }
+} \ No newline at end of file