server: add tool parsing support for nemotron-mini (#6849)

This commit is contained in:
Jeffrey Morgan 2024-09-17 18:06:16 -07:00 committed by GitHub
parent 72962c6e08
commit d05da29912
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 144 additions and 39 deletions

View file

@ -272,6 +272,30 @@ func detectContentType(r io.Reader) (string, error) {
return "unknown", nil return "unknown", nil
} }
func parseObjects(s string) []map[string]any {
var objs []map[string]any
for offset := 0; offset < len(s); {
var obj map[string]any
decoder := json.NewDecoder(strings.NewReader(s[offset:]))
if err := decoder.Decode(&obj); errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
break
} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
// skip over any syntax errors
offset += int(syntax.Offset)
} else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) {
// skip over any unmarshalable types
offset += int(unmarshalType.Offset)
} else if err != nil {
return nil
} else {
offset += int(decoder.InputOffset())
objs = append(objs, obj)
}
}
return objs
}
// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls. // parseToolCalls attempts to parse a JSON string into a slice of ToolCalls.
// mxyng: this only really works if the input contains tool calls in some JSON format // mxyng: this only really works if the input contains tool calls in some JSON format
func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
@ -304,16 +328,14 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
return nil, false return nil, false
} }
var kv map[string]any templateObjects := parseObjects(b.String())
// execute the subtree with placeholders to identify the keys if len(templateObjects) == 0 {
// trim any commands that might exist in the template
if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {
return nil, false return nil, false
} }
// find the keys that correspond to the name and arguments fields // find the keys that correspond to the name and arguments fields
var name, arguments string var name, arguments string
for k, v := range kv { for k, v := range templateObjects[0] {
switch v.(type) { switch v.(type) {
case string: case string:
name = k name = k
@ -326,23 +348,10 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
return nil, false return nil, false
} }
var objs []map[string]any responseObjects := parseObjects(s)
for offset := 0; offset < len(s); { if len(responseObjects) == 0 {
var obj map[string]any
decoder := json.NewDecoder(strings.NewReader(s[offset:]))
if err := decoder.Decode(&obj); errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
break
} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
// skip over any syntax errors
offset += int(syntax.Offset)
} else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) {
// skip over any unmarshalable types
offset += int(unmarshalType.Offset)
} else if err != nil {
slog.Error("parseToolCalls", "error", err)
return nil, false return nil, false
} else { }
offset += int(decoder.InputOffset())
// collect all nested objects // collect all nested objects
var collect func(any) []map[string]any var collect func(any) []map[string]any
@ -361,8 +370,10 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
return all return all
} }
objs = append(objs, collect(obj)...)
} var objs []map[string]any
for _, p := range responseObjects {
objs = append(objs, collect(p)...)
} }
var toolCalls []api.ToolCall var toolCalls []api.ToolCall

View file

@ -69,6 +69,7 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}
</tool_call>`, true}, </tool_call>`, true},
{"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true}, {"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true},
{"nemotron", `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]} </toolcall>`, true},
} }
var tools []api.Tool var tools []api.Tool
@ -217,3 +218,45 @@ func TestParseLayerFromCopy(t *testing.T) {
t.Fatalf("got %d != want 5", len(layers)) t.Fatalf("got %d != want 5", len(layers))
} }
} }
func TestParseObjects(t *testing.T) {
tests := []struct {
input string
want []map[string]any
}{
{
input: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
want: []map[string]any{
{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
{"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, Canada"}},
},
},
{
input: `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </toolcall>`,
want: []map[string]any{
{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
},
},
{
input: `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </toolcall> <toolcall>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, ON"}} </toolcall>`,
want: []map[string]any{
{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
{"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, ON"}},
},
},
{
input: `{"name": "get_current_weather", "arguments": `,
want: nil,
},
}
for _, tc := range tests {
t.Run(tc.input, func(t *testing.T) {
got := parseObjects(tc.input)
if diff := cmp.Diff(got, tc.want); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
}

33
server/testdata/tools/nemotron.gotmpl vendored Normal file
View file

@ -0,0 +1,33 @@
{{- if (or .Tools .System) }}<extra_id_0>System
{{ if .System }}{{ .System }}
{{ end }}
{{- if .Tools }}
{{- range .Tools }}<tool> {{ . }} </tool>{{ end }}
{{ end }}
{{- end }}
{{- range $i, $m := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
{{- if eq .Role "user" }}<extra_id_1>User
{{ .Content }}
{{- if $last }}
<extra_id_1>Assistant
{{- end }}
{{ else if eq .Role "tool" }}<extra_id_1>Tool
{{ .Content }}
{{- if $last }}
<extra_id_1>Assistant
{{- end }}
{{ else if eq .Role "assistant" }}<extra_id_1>Assistant
{{- if .ToolCalls }}
{{ range .ToolCalls }}<toolcall> {"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} </toolcall> {{ end }}
{{ else }}
{{ .Content }}
{{- if not $last }}
{{ end }}
{{- end }}
{{- end }}
{{- end }}

18
server/testdata/tools/nemotron.out vendored Normal file
View file

@ -0,0 +1,18 @@
<extra_id_0>System
You are a knowledgable assistant. You can answer questions and perform tasks.
<tool> {"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}} </tool>
<extra_id_1>User
What's the weather like today in Paris?
<extra_id_1>Assistant
<toolcall> {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}} </toolcall>
<extra_id_1>Tool
22
<extra_id_1>Assistant
The current temperature in Paris, France is 22 degrees Celsius.
<extra_id_1>User
What's the weather like today in San Francisco and Toronto?
<extra_id_1>Assistant