From b2554455572b28c0e18423d6fe6896cf7137dbd6 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 17 Jul 2024 15:35:11 -0700 Subject: [PATCH] marshal json automatically for some template values (#5758) --- api/types.go | 73 ++++++++++++------- server/model.go | 18 +++-- server/model_test.go | 13 +--- server/testdata/tools/command-r-plus.gotmpl | 2 +- server/testdata/tools/firefunction.gotmpl | 4 +- .../tools/llama3-groq-tool-use.gotmpl | 4 +- server/testdata/tools/mistral.gotmpl | 4 +- template/template.go | 6 +- 8 files changed, 72 insertions(+), 52 deletions(-) diff --git a/api/types.go b/api/types.go index e687b8a4..c7e9dce3 100644 --- a/api/types.go +++ b/api/types.go @@ -101,12 +101,19 @@ type ChatRequest struct { KeepAlive *Duration `json:"keep_alive,omitempty"` // Tools is an optional list of tools the model has access to. - Tools []Tool `json:"tools,omitempty"` + Tools `json:"tools,omitempty"` // Options lists model-specific options. Options map[string]interface{} `json:"options"` } +type Tools []Tool + +func (t Tools) String() string { + bts, _ := json.Marshal(t) + return string(bts) +} + // Message is a single message in a chat sequence. The message contains the // role ("system", "user", or "assistant"), the content and an optional list // of images. @@ -117,30 +124,6 @@ type Message struct { ToolCalls []ToolCall `json:"tool_calls,omitempty"` } -type ToolCall struct { - Function struct { - Name string `json:"name"` - Arguments map[string]any `json:"arguments"` - } `json:"function"` -} - -type Tool struct { - Type string `json:"type"` - Function struct { - Name string `json:"name"` - Description string `json:"description"` - Parameters struct { - Type string `json:"type"` - Required []string `json:"required"` - Properties map[string]struct { - Type string `json:"type"` - Description string `json:"description"` - Enum []string `json:"enum,omitempty"` - } `json:"properties"` - } `json:"parameters"` - } `json:"function"` -} - func (m *Message) UnmarshalJSON(b []byte) error { type Alias Message var a Alias @@ -153,6 +136,46 @@ func (m *Message) UnmarshalJSON(b []byte) error { return nil } +type ToolCall struct { + Function ToolCallFunction `json:"function"` +} + +type ToolCallFunction struct { + Name string `json:"name"` + Arguments ToolCallFunctionArguments `json:"arguments"` +} + +type ToolCallFunctionArguments map[string]any + +func (t *ToolCallFunctionArguments) String() string { + bts, _ := json.Marshal(t) + return string(bts) +} + +type Tool struct { + Type string `json:"type"` + Function ToolFunction `json:"function"` +} + +type ToolFunction struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters struct { + Type string `json:"type"` + Required []string `json:"required"` + Properties map[string]struct { + Type string `json:"type"` + Description string `json:"description"` + Enum []string `json:"enum,omitempty"` + } `json:"properties"` + } `json:"parameters"` +} + +func (t *ToolFunction) String() string { + bts, _ := json.Marshal(t) + return string(bts) +} + // ChatResponse is the response returned by [Client.Chat]. Its fields are // similar to [GenerateResponse]. type ChatResponse struct { diff --git a/server/model.go b/server/model.go index e5d6179b..65231ab1 100644 --- a/server/model.go +++ b/server/model.go @@ -311,12 +311,14 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { } var b bytes.Buffer - if err := tmpl.Execute(&b, map[string][]map[string]any{ + if err := tmpl.Execute(&b, map[string][]api.ToolCall{ "ToolCalls": { { - "Function": map[string]any{ - "Name": "@@name@@", - "Arguments": "@@arguments@@", + Function: api.ToolCallFunction{ + Name: "@@name@@", + Arguments: api.ToolCallFunctionArguments{ + "@@argument@@": 1, + }, }, }, }, @@ -324,7 +326,7 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { return nil, false } - var kv map[string]string + var kv map[string]any // execute the subtree with placeholders to identify the keys // trim any commands that might exist in the template if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil { @@ -334,10 +336,10 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { // find the keys that correspond to the name and arguments fields var name, arguments string for k, v := range kv { - switch v { - case "@@name@@": + switch v.(type) { + case string: name = k - case "@@arguments@@": + case map[string]any: arguments = k } } diff --git a/server/model_test.go b/server/model_test.go index f0382843..7c826b06 100644 --- a/server/model_test.go +++ b/server/model_test.go @@ -115,11 +115,6 @@ func TestExtractFromZipFile(t *testing.T) { } } -type function struct { - Name string `json:"name"` - Arguments map[string]any `json:"arguments"` -} - func readFile(t *testing.T, base, name string) *bytes.Buffer { t.Helper() @@ -185,18 +180,18 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, calls := []api.ToolCall{ { - Function: function{ + Function: api.ToolCallFunction{ Name: "get_current_weather", - Arguments: map[string]any{ + Arguments: api.ToolCallFunctionArguments{ "format": "fahrenheit", "location": "San Francisco, CA", }, }, }, { - Function: function{ + Function: api.ToolCallFunction{ Name: "get_current_weather", - Arguments: map[string]any{ + Arguments: api.ToolCallFunctionArguments{ "format": "celsius", "location": "Toronto, Canada", }, diff --git a/server/testdata/tools/command-r-plus.gotmpl b/server/testdata/tools/command-r-plus.gotmpl index 088a4f0e..f30124e3 100644 --- a/server/testdata/tools/command-r-plus.gotmpl +++ b/server/testdata/tools/command-r-plus.gotmpl @@ -46,7 +46,7 @@ Action: ```json {{- range .ToolCalls }} { "tool_name": "{{ .Function.Name }}", - "parameters": {{ json .Function.Arguments }} + "parameters": {{ .Function.Arguments }} } {{- end }} ]``` diff --git a/server/testdata/tools/firefunction.gotmpl b/server/testdata/tools/firefunction.gotmpl index bca88b3b..312be205 100644 --- a/server/testdata/tools/firefunction.gotmpl +++ b/server/testdata/tools/firefunction.gotmpl @@ -17,7 +17,7 @@ If you decide to call functions: Available functions as JSON spec: {{- if .Tools }} -{{ json .Tools }} +{{ .Tools }} {{- end }}<|eot_id|> {{- end }} {{- range .Messages }}<|start_header_id|> @@ -25,7 +25,7 @@ Available functions as JSON spec: {{- end }}<|end_header_id|> {{- if .Content }}{{ .Content }} {{- else if .ToolCalls }} functools[ -{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}{{ "}" }} +{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}{{ "}" }} {{- end }}] {{- end }}<|eot_id|> {{- end }}<|start_header_id|>assistant<|end_header_id|> \ No newline at end of file diff --git a/server/testdata/tools/llama3-groq-tool-use.gotmpl b/server/testdata/tools/llama3-groq-tool-use.gotmpl index e174f8a5..45e9b462 100644 --- a/server/testdata/tools/llama3-groq-tool-use.gotmpl +++ b/server/testdata/tools/llama3-groq-tool-use.gotmpl @@ -9,7 +9,7 @@ Here are the available tools: -{{- range .Tools }} {{ json .Function }} +{{- range .Tools }} {{ .Function }} {{- end }} {{- end }} {{- end }}<|eot_id|> @@ -20,7 +20,7 @@ Here are the available tools: {{- else if eq .Role "assistant" }} {{- if .Content }}{{ .Content }} {{- else if .ToolCalls }} -{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}} +{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} {{- end }} {{- end }} diff --git a/server/testdata/tools/mistral.gotmpl b/server/testdata/tools/mistral.gotmpl index a98bc7ad..b08d6c2c 100644 --- a/server/testdata/tools/mistral.gotmpl +++ b/server/testdata/tools/mistral.gotmpl @@ -1,13 +1,13 @@ {{- range $index, $_ := .Messages }} {{- if eq .Role "user" }} -{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ json $.Tools }}[/AVAILABLE_TOOLS] +{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ $.Tools }}[/AVAILABLE_TOOLS] {{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }} {{ end }}{{ .Content }}[/INST] {{- else if eq .Role "assistant" }} {{- if .Content }} {{ .Content }} {{- else if .ToolCalls }}[TOOL_CALLS] [ -{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}} +{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} {{- end }}] {{- end }} {{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS] diff --git a/template/template.go b/template/template.go index 85b4d21a..b5bfb16c 100644 --- a/template/template.go +++ b/template/template.go @@ -150,9 +150,9 @@ func (t *Template) Vars() []string { type Values struct { Messages []api.Message - Tools []api.Tool - Prompt string - Suffix string + api.Tools + Prompt string + Suffix string // forceLegacy is a flag used to test compatibility with legacy templates forceLegacy bool