Merge pull request #5726 from ollama/mxyng/tools-templates

fix unmarshal type errors
This commit is contained in:
Michael Yang 2024-07-16 12:12:10 -07:00 committed by GitHub
commit a8388beb94
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 39 additions and 42 deletions

View file

@ -327,7 +327,8 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
var kv map[string]string var kv map[string]string
// execute the subtree with placeholders to identify the keys // execute the subtree with placeholders to identify the keys
if err := json.Unmarshal(b.Bytes(), &kv); err != nil { // trim any commands that might exist in the template
if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {
return nil, false return nil, false
} }
@ -342,35 +343,26 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
} }
} }
var sm []map[string]any var objs []map[string]any
decoder := json.NewDecoder(strings.NewReader(s)) for offset := 0; offset < len(s); {
for { if err := json.NewDecoder(strings.NewReader(s[offset:])).Decode(&objs); errors.Is(err, io.EOF) {
// incrementally decode the JSON into a list of JSON objects
// skipping over any invalid tokens
if err := decoder.Decode(&sm); err != nil {
if errors.Is(err, io.EOF) {
break break
} } else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
// skip over any syntax errors
if errors.As(err, new(*json.SyntaxError)) { offset += int(syntax.Offset)
r := decoder.Buffered() } else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) {
if _, err := r.Read(make([]byte, decoder.InputOffset()+1)); err != nil { // skip over any unmarshalable types
break offset += int(unmarshalType.Offset)
} } else if err != nil {
decoder = json.NewDecoder(r)
continue
}
return nil, false return nil, false
} } else {
// break when an object is decoded
// break as soon as a valid object is decoded
break break
} }
}
var toolCalls []api.ToolCall var toolCalls []api.ToolCall
for _, kv := range sm { for _, kv := range objs {
call := api.ToolCall{ call := api.ToolCall{
ID: uuid.New().String(), ID: uuid.New().String(),
Type: "function", Type: "function",
@ -388,9 +380,5 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
toolCalls = append(toolCalls, call) toolCalls = append(toolCalls, call)
} }
if len(toolCalls) > 0 { return toolCalls, len(toolCalls) > 0
return toolCalls, true
}
return nil, false
} }

View file

@ -136,11 +136,16 @@ func TestExecuteWithTools(t *testing.T) {
cases := []struct { cases := []struct {
model string model string
output string output string
ok bool
}{ }{
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`}, {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`}, The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
{"command-r-plus", "Action: ```json" + ` {"command-r-plus", "Action: ```json" + `
[ [
{ {
@ -158,8 +163,10 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`}
} }
} }
] ]
` + "```"}, ` + "```", true},
{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`}, {"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
} }
var tools []api.Tool var tools []api.Tool
@ -216,10 +223,11 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`}
t.Run("parse", func(t *testing.T) { t.Run("parse", func(t *testing.T) {
m := &Model{Template: tmpl} m := &Model{Template: tmpl}
actual, ok := m.parseToolCalls(tt.output) actual, ok := m.parseToolCalls(tt.output)
if !ok { if ok != tt.ok {
t.Fatal("failed to parse tool calls") t.Fatalf("expected %t, got %t", tt.ok, ok)
} }
if tt.ok {
for i := range actual { for i := range actual {
// ID is randomly generated so clear it for comparison // ID is randomly generated so clear it for comparison
actual[i].ID = "" actual[i].ID = ""
@ -228,6 +236,7 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`}
if diff := cmp.Diff(actual, calls); diff != "" { if diff := cmp.Diff(actual, calls); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
}
}) })
}) })
} }