Merge pull request #5726 from ollama/mxyng/tools-templates
fix unmarshal type errors
This commit is contained in:
commit
a8388beb94
2 changed files with 39 additions and 42 deletions
|
@ -327,7 +327,8 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
||||||
|
|
||||||
var kv map[string]string
|
var kv map[string]string
|
||||||
// execute the subtree with placeholders to identify the keys
|
// execute the subtree with placeholders to identify the keys
|
||||||
if err := json.Unmarshal(b.Bytes(), &kv); err != nil {
|
// trim any commands that might exist in the template
|
||||||
|
if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -342,35 +343,26 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var sm []map[string]any
|
var objs []map[string]any
|
||||||
decoder := json.NewDecoder(strings.NewReader(s))
|
for offset := 0; offset < len(s); {
|
||||||
for {
|
if err := json.NewDecoder(strings.NewReader(s[offset:])).Decode(&objs); errors.Is(err, io.EOF) {
|
||||||
// incrementally decode the JSON into a list of JSON objects
|
break
|
||||||
// skipping over any invalid tokens
|
} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
|
||||||
if err := decoder.Decode(&sm); err != nil {
|
// skip over any syntax errors
|
||||||
if errors.Is(err, io.EOF) {
|
offset += int(syntax.Offset)
|
||||||
break
|
} else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) {
|
||||||
}
|
// skip over any unmarshalable types
|
||||||
|
offset += int(unmarshalType.Offset)
|
||||||
if errors.As(err, new(*json.SyntaxError)) {
|
} else if err != nil {
|
||||||
r := decoder.Buffered()
|
|
||||||
if _, err := r.Read(make([]byte, decoder.InputOffset()+1)); err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
decoder = json.NewDecoder(r)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, false
|
return nil, false
|
||||||
|
} else {
|
||||||
|
// break when an object is decoded
|
||||||
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
// break as soon as a valid object is decoded
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var toolCalls []api.ToolCall
|
var toolCalls []api.ToolCall
|
||||||
for _, kv := range sm {
|
for _, kv := range objs {
|
||||||
call := api.ToolCall{
|
call := api.ToolCall{
|
||||||
ID: uuid.New().String(),
|
ID: uuid.New().String(),
|
||||||
Type: "function",
|
Type: "function",
|
||||||
|
@ -388,9 +380,5 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
||||||
toolCalls = append(toolCalls, call)
|
toolCalls = append(toolCalls, call)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(toolCalls) > 0 {
|
return toolCalls, len(toolCalls) > 0
|
||||||
return toolCalls, true
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, false
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -136,11 +136,16 @@ func TestExecuteWithTools(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
model string
|
model string
|
||||||
output string
|
output string
|
||||||
|
ok bool
|
||||||
}{
|
}{
|
||||||
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`},
|
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
|
||||||
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
|
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
|
||||||
|
|
||||||
The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`},
|
The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
|
||||||
|
{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
|
||||||
|
|
||||||
|
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
|
||||||
|
{"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
|
||||||
{"command-r-plus", "Action: ```json" + `
|
{"command-r-plus", "Action: ```json" + `
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
|
@ -158,8 +163,10 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
` + "```"},
|
` + "```", true},
|
||||||
{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`},
|
{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
|
||||||
|
{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
|
||||||
|
{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
|
||||||
}
|
}
|
||||||
|
|
||||||
var tools []api.Tool
|
var tools []api.Tool
|
||||||
|
@ -216,17 +223,19 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`}
|
||||||
t.Run("parse", func(t *testing.T) {
|
t.Run("parse", func(t *testing.T) {
|
||||||
m := &Model{Template: tmpl}
|
m := &Model{Template: tmpl}
|
||||||
actual, ok := m.parseToolCalls(tt.output)
|
actual, ok := m.parseToolCalls(tt.output)
|
||||||
if !ok {
|
if ok != tt.ok {
|
||||||
t.Fatal("failed to parse tool calls")
|
t.Fatalf("expected %t, got %t", tt.ok, ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range actual {
|
if tt.ok {
|
||||||
// ID is randomly generated so clear it for comparison
|
for i := range actual {
|
||||||
actual[i].ID = ""
|
// ID is randomly generated so clear it for comparison
|
||||||
}
|
actual[i].ID = ""
|
||||||
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(actual, calls); diff != "" {
|
if diff := cmp.Diff(actual, calls); diff != "" {
|
||||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
Loading…
Reference in a new issue