summaryrefslogtreecommitdiff
path: root/bot_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'bot_test.go')
-rw-r--r--bot_test.go134
1 files changed, 134 insertions, 0 deletions
diff --git a/bot_test.go b/bot_test.go
index 2d59c3c..d2956a9 100644
--- a/bot_test.go
+++ b/bot_test.go
@@ -152,4 +152,138 @@ func TestConsolidateConsecutiveAssistantMessages(t *testing.T) {
}
})
}
+}
+
+func TestUnmarshalFuncCall(t *testing.T) {
+ tests := []struct {
+ name string
+ jsonStr string
+ want *models.FuncCall
+ wantErr bool
+ }{
+ {
+ name: "simple websearch with numeric limit",
+ jsonStr: `{"name": "websearch", "args": {"query": "current weather in London", "limit": 3}}`,
+ want: &models.FuncCall{
+ Name: "websearch",
+ Args: map[string]string{"query": "current weather in London", "limit": "3"},
+ },
+ wantErr: false,
+ },
+ {
+ name: "string limit",
+ jsonStr: `{"name": "websearch", "args": {"query": "test", "limit": "5"}}`,
+ want: &models.FuncCall{
+ Name: "websearch",
+ Args: map[string]string{"query": "test", "limit": "5"},
+ },
+ wantErr: false,
+ },
+ {
+ name: "boolean arg",
+ jsonStr: `{"name": "test", "args": {"flag": true}}`,
+ want: &models.FuncCall{
+ Name: "test",
+ Args: map[string]string{"flag": "true"},
+ },
+ wantErr: false,
+ },
+ {
+ name: "null arg",
+ jsonStr: `{"name": "test", "args": {"opt": null}}`,
+ want: &models.FuncCall{
+ Name: "test",
+ Args: map[string]string{"opt": ""},
+ },
+ wantErr: false,
+ },
+ {
+ name: "float arg",
+ jsonStr: `{"name": "test", "args": {"ratio": 0.5}}`,
+ want: &models.FuncCall{
+ Name: "test",
+ Args: map[string]string{"ratio": "0.5"},
+ },
+ wantErr: false,
+ },
+ {
+ name: "invalid JSON",
+ jsonStr: `{invalid}`,
+ want: nil,
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := unmarshalFuncCall(tt.jsonStr)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("unmarshalFuncCall() error = %v, wantErr %v", err, tt.wantErr)
+ return
+ }
+ if tt.wantErr {
+ return
+ }
+ if got.Name != tt.want.Name {
+ t.Errorf("unmarshalFuncCall() name = %v, want %v", got.Name, tt.want.Name)
+ }
+ if len(got.Args) != len(tt.want.Args) {
+ t.Errorf("unmarshalFuncCall() args length = %v, want %v", len(got.Args), len(tt.want.Args))
+ }
+ for k, v := range tt.want.Args {
+ if got.Args[k] != v {
+ t.Errorf("unmarshalFuncCall() args[%v] = %v, want %v", k, got.Args[k], v)
+ }
+ }
+ })
+ }
+}
+
+func TestConvertJSONToMapStringString(t *testing.T) {
+ tests := []struct {
+ name string
+ jsonStr string
+ want map[string]string
+ wantErr bool
+ }{
+ {
+ name: "simple map",
+ jsonStr: `{"query": "weather", "limit": 5}`,
+ want: map[string]string{"query": "weather", "limit": "5"},
+ wantErr: false,
+ },
+ {
+ name: "boolean and null",
+ jsonStr: `{"flag": true, "opt": null}`,
+ want: map[string]string{"flag": "true", "opt": ""},
+ wantErr: false,
+ },
+ {
+ name: "invalid JSON",
+ jsonStr: `{invalid`,
+ want: nil,
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := convertJSONToMapStringString(tt.jsonStr)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("convertJSONToMapStringString() error = %v, wantErr %v", err, tt.wantErr)
+ return
+ }
+ if tt.wantErr {
+ return
+ }
+ if len(got) != len(tt.want) {
+ t.Errorf("convertJSONToMapStringString() length = %v, want %v", len(got), len(tt.want))
+ }
+ for k, v := range tt.want {
+ if got[k] != v {
+ t.Errorf("convertJSONToMapStringString()[%v] = %v, want %v", k, got[k], v)
+ }
+ }
+ })
+ }
} \ No newline at end of file