diff --git a/server/model.go b/server/model.go index be318db9..9e22d63a 100644 --- a/server/model.go +++ b/server/model.go @@ -327,7 +327,8 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { var kv map[string]string // execute the subtree with placeholders to identify the keys - if err := json.Unmarshal(b.Bytes(), &kv); err != nil { + // trim any commands that might exist in the template + if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil { return nil, false } @@ -342,35 +343,26 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { } } - var sm []map[string]any - decoder := json.NewDecoder(strings.NewReader(s)) - for { - // incrementally decode the JSON into a list of JSON objects - // skipping over any invalid tokens - if err := decoder.Decode(&sm); err != nil { - if errors.Is(err, io.EOF) { - break - } - - if errors.As(err, new(*json.SyntaxError)) { - r := decoder.Buffered() - if _, err := r.Read(make([]byte, decoder.InputOffset()+1)); err != nil { - break - } - - decoder = json.NewDecoder(r) - continue - } - + var objs []map[string]any + for offset := 0; offset < len(s); { + if err := json.NewDecoder(strings.NewReader(s[offset:])).Decode(&objs); errors.Is(err, io.EOF) { + break + } else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) { + // skip over any syntax errors + offset += int(syntax.Offset) + } else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) { + // skip over any unmarshalable types + offset += int(unmarshalType.Offset) + } else if err != nil { return nil, false + } else { + // break when an object is decoded + break } - - // break as soon as a valid object is decoded - break } var toolCalls []api.ToolCall - for _, kv := range sm { + for _, kv := range objs { call := api.ToolCall{ ID: uuid.New().String(), Type: "function", @@ -388,9 +380,5 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { toolCalls = append(toolCalls, call) } - if len(toolCalls) > 0 { - return toolCalls, true - } - - return nil, false + return toolCalls, len(toolCalls) > 0 } diff --git a/server/model_test.go b/server/model_test.go index 02578192..d39f2891 100644 --- a/server/model_test.go +++ b/server/model_test.go @@ -136,11 +136,16 @@ func TestExecuteWithTools(t *testing.T) { cases := []struct { model string output string + ok bool }{ - {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`}, + {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] -The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`}, +The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true}, + {"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: + + [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, + {"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false}, {"command-r-plus", "Action: ```json" + ` [ { @@ -158,8 +163,10 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`} } } ] -` + "```"}, - {"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`}, +` + "```", true}, + {"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false}, + {"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, + {"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false}, } var tools []api.Tool @@ -216,17 +223,19 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`} t.Run("parse", func(t *testing.T) { m := &Model{Template: tmpl} actual, ok := m.parseToolCalls(tt.output) - if !ok { - t.Fatal("failed to parse tool calls") + if ok != tt.ok { + t.Fatalf("expected %t, got %t", tt.ok, ok) } - for i := range actual { - // ID is randomly generated so clear it for comparison - actual[i].ID = "" - } + if tt.ok { + for i := range actual { + // ID is randomly generated so clear it for comparison + actual[i].ID = "" + } - if diff := cmp.Diff(actual, calls); diff != "" { - t.Errorf("mismatch (-got +want):\n%s", diff) + if diff := cmp.Diff(actual, calls); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } } }) })