marshal json automatically for some template values (#5758)

This commit is contained in:
Michael Yang 2024-07-17 15:35:11 -07:00 committed by GitHub
parent b23424bb3c
commit b255445557
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 72 additions and 52 deletions

View file

@ -101,12 +101,19 @@ type ChatRequest struct {
KeepAlive *Duration `json:"keep_alive,omitempty"` KeepAlive *Duration `json:"keep_alive,omitempty"`
// Tools is an optional list of tools the model has access to. // Tools is an optional list of tools the model has access to.
Tools []Tool `json:"tools,omitempty"` Tools `json:"tools,omitempty"`
// Options lists model-specific options. // Options lists model-specific options.
Options map[string]interface{} `json:"options"` Options map[string]interface{} `json:"options"`
} }
type Tools []Tool
func (t Tools) String() string {
bts, _ := json.Marshal(t)
return string(bts)
}
// Message is a single message in a chat sequence. The message contains the // Message is a single message in a chat sequence. The message contains the
// role ("system", "user", or "assistant"), the content and an optional list // role ("system", "user", or "assistant"), the content and an optional list
// of images. // of images.
@ -117,30 +124,6 @@ type Message struct {
ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"`
} }
type ToolCall struct {
Function struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments"`
} `json:"function"`
}
type Tool struct {
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters struct {
Type string `json:"type"`
Required []string `json:"required"`
Properties map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
} `json:"properties"`
} `json:"parameters"`
} `json:"function"`
}
func (m *Message) UnmarshalJSON(b []byte) error { func (m *Message) UnmarshalJSON(b []byte) error {
type Alias Message type Alias Message
var a Alias var a Alias
@ -153,6 +136,46 @@ func (m *Message) UnmarshalJSON(b []byte) error {
return nil return nil
} }
type ToolCall struct {
Function ToolCallFunction `json:"function"`
}
type ToolCallFunction struct {
Name string `json:"name"`
Arguments ToolCallFunctionArguments `json:"arguments"`
}
type ToolCallFunctionArguments map[string]any
func (t *ToolCallFunctionArguments) String() string {
bts, _ := json.Marshal(t)
return string(bts)
}
type Tool struct {
Type string `json:"type"`
Function ToolFunction `json:"function"`
}
type ToolFunction struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters struct {
Type string `json:"type"`
Required []string `json:"required"`
Properties map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
} `json:"properties"`
} `json:"parameters"`
}
func (t *ToolFunction) String() string {
bts, _ := json.Marshal(t)
return string(bts)
}
// ChatResponse is the response returned by [Client.Chat]. Its fields are // ChatResponse is the response returned by [Client.Chat]. Its fields are
// similar to [GenerateResponse]. // similar to [GenerateResponse].
type ChatResponse struct { type ChatResponse struct {

View file

@ -311,12 +311,14 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
} }
var b bytes.Buffer var b bytes.Buffer
if err := tmpl.Execute(&b, map[string][]map[string]any{ if err := tmpl.Execute(&b, map[string][]api.ToolCall{
"ToolCalls": { "ToolCalls": {
{ {
"Function": map[string]any{ Function: api.ToolCallFunction{
"Name": "@@name@@", Name: "@@name@@",
"Arguments": "@@arguments@@", Arguments: api.ToolCallFunctionArguments{
"@@argument@@": 1,
},
}, },
}, },
}, },
@ -324,7 +326,7 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
return nil, false return nil, false
} }
var kv map[string]string var kv map[string]any
// execute the subtree with placeholders to identify the keys // execute the subtree with placeholders to identify the keys
// trim any commands that might exist in the template // trim any commands that might exist in the template
if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil { if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {
@ -334,10 +336,10 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
// find the keys that correspond to the name and arguments fields // find the keys that correspond to the name and arguments fields
var name, arguments string var name, arguments string
for k, v := range kv { for k, v := range kv {
switch v { switch v.(type) {
case "@@name@@": case string:
name = k name = k
case "@@arguments@@": case map[string]any:
arguments = k arguments = k
} }
} }

View file

@ -115,11 +115,6 @@ func TestExtractFromZipFile(t *testing.T) {
} }
} }
type function struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments"`
}
func readFile(t *testing.T, base, name string) *bytes.Buffer { func readFile(t *testing.T, base, name string) *bytes.Buffer {
t.Helper() t.Helper()
@ -185,18 +180,18 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
calls := []api.ToolCall{ calls := []api.ToolCall{
{ {
Function: function{ Function: api.ToolCallFunction{
Name: "get_current_weather", Name: "get_current_weather",
Arguments: map[string]any{ Arguments: api.ToolCallFunctionArguments{
"format": "fahrenheit", "format": "fahrenheit",
"location": "San Francisco, CA", "location": "San Francisco, CA",
}, },
}, },
}, },
{ {
Function: function{ Function: api.ToolCallFunction{
Name: "get_current_weather", Name: "get_current_weather",
Arguments: map[string]any{ Arguments: api.ToolCallFunctionArguments{
"format": "celsius", "format": "celsius",
"location": "Toronto, Canada", "location": "Toronto, Canada",
}, },

View file

@ -46,7 +46,7 @@ Action: ```json
{{- range .ToolCalls }} {{- range .ToolCalls }}
{ {
"tool_name": "{{ .Function.Name }}", "tool_name": "{{ .Function.Name }}",
"parameters": {{ json .Function.Arguments }} "parameters": {{ .Function.Arguments }}
} }
{{- end }} {{- end }}
]``` ]```

View file

@ -17,7 +17,7 @@ If you decide to call functions:
Available functions as JSON spec: Available functions as JSON spec:
{{- if .Tools }} {{- if .Tools }}
{{ json .Tools }} {{ .Tools }}
{{- end }}<|eot_id|> {{- end }}<|eot_id|>
{{- end }} {{- end }}
{{- range .Messages }}<|start_header_id|> {{- range .Messages }}<|start_header_id|>
@ -25,7 +25,7 @@ Available functions as JSON spec:
{{- end }}<|end_header_id|> {{- end }}<|end_header_id|>
{{- if .Content }}{{ .Content }} {{- if .Content }}{{ .Content }}
{{- else if .ToolCalls }} functools[ {{- else if .ToolCalls }} functools[
{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}{{ "}" }} {{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}{{ "}" }}
{{- end }}] {{- end }}]
{{- end }}<|eot_id|> {{- end }}<|eot_id|>
{{- end }}<|start_header_id|>assistant<|end_header_id|> {{- end }}<|start_header_id|>assistant<|end_header_id|>

View file

@ -9,7 +9,7 @@
Here are the available tools: Here are the available tools:
<tools> <tools>
{{- range .Tools }} {{ json .Function }} {{- range .Tools }} {{ .Function }}
{{- end }} </tools> {{- end }} </tools>
{{- end }} {{- end }}
{{- end }}<|eot_id|> {{- end }}<|eot_id|>
@ -20,7 +20,7 @@ Here are the available tools:
{{- else if eq .Role "assistant" }} {{- else if eq .Role "assistant" }}
{{- if .Content }}{{ .Content }} {{- if .Content }}{{ .Content }}
{{- else if .ToolCalls }}<tool_call> {{- else if .ToolCalls }}<tool_call>
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}} {{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{- end }} {{- end }}
</tool_call> </tool_call>
{{- end }} {{- end }}

View file

@ -1,13 +1,13 @@
{{- range $index, $_ := .Messages }} {{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }} {{- if eq .Role "user" }}
{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ json $.Tools }}[/AVAILABLE_TOOLS] {{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ $.Tools }}[/AVAILABLE_TOOLS]
{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }} {{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}
{{ end }}{{ .Content }}[/INST] {{ end }}{{ .Content }}[/INST]
{{- else if eq .Role "assistant" }} {{- else if eq .Role "assistant" }}
{{- if .Content }} {{ .Content }}</s> {{- if .Content }} {{ .Content }}</s>
{{- else if .ToolCalls }}[TOOL_CALLS] [ {{- else if .ToolCalls }}[TOOL_CALLS] [
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}} {{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{- end }}]</s> {{- end }}]</s>
{{- end }} {{- end }}
{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS] {{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]

View file

@ -150,9 +150,9 @@ func (t *Template) Vars() []string {
type Values struct { type Values struct {
Messages []api.Message Messages []api.Message
Tools []api.Tool api.Tools
Prompt string Prompt string
Suffix string Suffix string
// forceLegacy is a flag used to test compatibility with legacy templates // forceLegacy is a flag used to test compatibility with legacy templates
forceLegacy bool forceLegacy bool