marshal json automatically for some template values (#5758)

This commit is contained in:
Michael Yang 2024-07-17 15:35:11 -07:00 committed by GitHub
parent b23424bb3c
commit b255445557
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 72 additions and 52 deletions

View file

@ -101,12 +101,19 @@ type ChatRequest struct {
KeepAlive *Duration `json:"keep_alive,omitempty"`
// Tools is an optional list of tools the model has access to.
Tools []Tool `json:"tools,omitempty"`
Tools `json:"tools,omitempty"`
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
}
type Tools []Tool
func (t Tools) String() string {
bts, _ := json.Marshal(t)
return string(bts)
}
// Message is a single message in a chat sequence. The message contains the
// role ("system", "user", or "assistant"), the content and an optional list
// of images.
@ -117,30 +124,6 @@ type Message struct {
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
type ToolCall struct {
Function struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments"`
} `json:"function"`
}
type Tool struct {
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters struct {
Type string `json:"type"`
Required []string `json:"required"`
Properties map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
} `json:"properties"`
} `json:"parameters"`
} `json:"function"`
}
func (m *Message) UnmarshalJSON(b []byte) error {
type Alias Message
var a Alias
@ -153,6 +136,46 @@ func (m *Message) UnmarshalJSON(b []byte) error {
return nil
}
type ToolCall struct {
Function ToolCallFunction `json:"function"`
}
type ToolCallFunction struct {
Name string `json:"name"`
Arguments ToolCallFunctionArguments `json:"arguments"`
}
type ToolCallFunctionArguments map[string]any
func (t *ToolCallFunctionArguments) String() string {
bts, _ := json.Marshal(t)
return string(bts)
}
type Tool struct {
Type string `json:"type"`
Function ToolFunction `json:"function"`
}
type ToolFunction struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters struct {
Type string `json:"type"`
Required []string `json:"required"`
Properties map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
} `json:"properties"`
} `json:"parameters"`
}
func (t *ToolFunction) String() string {
bts, _ := json.Marshal(t)
return string(bts)
}
// ChatResponse is the response returned by [Client.Chat]. Its fields are
// similar to [GenerateResponse].
type ChatResponse struct {

View file

@ -311,12 +311,14 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
}
var b bytes.Buffer
if err := tmpl.Execute(&b, map[string][]map[string]any{
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
"ToolCalls": {
{
"Function": map[string]any{
"Name": "@@name@@",
"Arguments": "@@arguments@@",
Function: api.ToolCallFunction{
Name: "@@name@@",
Arguments: api.ToolCallFunctionArguments{
"@@argument@@": 1,
},
},
},
},
@ -324,7 +326,7 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
return nil, false
}
var kv map[string]string
var kv map[string]any
// execute the subtree with placeholders to identify the keys
// trim any commands that might exist in the template
if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {
@ -334,10 +336,10 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
// find the keys that correspond to the name and arguments fields
var name, arguments string
for k, v := range kv {
switch v {
case "@@name@@":
switch v.(type) {
case string:
name = k
case "@@arguments@@":
case map[string]any:
arguments = k
}
}

View file

@ -115,11 +115,6 @@ func TestExtractFromZipFile(t *testing.T) {
}
}
type function struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments"`
}
func readFile(t *testing.T, base, name string) *bytes.Buffer {
t.Helper()
@ -185,18 +180,18 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
calls := []api.ToolCall{
{
Function: function{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]any{
Arguments: api.ToolCallFunctionArguments{
"format": "fahrenheit",
"location": "San Francisco, CA",
},
},
},
{
Function: function{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]any{
Arguments: api.ToolCallFunctionArguments{
"format": "celsius",
"location": "Toronto, Canada",
},

View file

@ -46,7 +46,7 @@ Action: ```json
{{- range .ToolCalls }}
{
"tool_name": "{{ .Function.Name }}",
"parameters": {{ json .Function.Arguments }}
"parameters": {{ .Function.Arguments }}
}
{{- end }}
]```

View file

@ -17,7 +17,7 @@ If you decide to call functions:
Available functions as JSON spec:
{{- if .Tools }}
{{ json .Tools }}
{{ .Tools }}
{{- end }}<|eot_id|>
{{- end }}
{{- range .Messages }}<|start_header_id|>
@ -25,7 +25,7 @@ Available functions as JSON spec:
{{- end }}<|end_header_id|>
{{- if .Content }}{{ .Content }}
{{- else if .ToolCalls }} functools[
{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}{{ "}" }}
{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}{{ "}" }}
{{- end }}]
{{- end }}<|eot_id|>
{{- end }}<|start_header_id|>assistant<|end_header_id|>

View file

@ -9,7 +9,7 @@
Here are the available tools:
<tools>
{{- range .Tools }} {{ json .Function }}
{{- range .Tools }} {{ .Function }}
{{- end }} </tools>
{{- end }}
{{- end }}<|eot_id|>
@ -20,7 +20,7 @@ Here are the available tools:
{{- else if eq .Role "assistant" }}
{{- if .Content }}{{ .Content }}
{{- else if .ToolCalls }}<tool_call>
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}}
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{- end }}
</tool_call>
{{- end }}

View file

@ -1,13 +1,13 @@
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}
{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ json $.Tools }}[/AVAILABLE_TOOLS]
{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ $.Tools }}[/AVAILABLE_TOOLS]
{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}
{{ end }}{{ .Content }}[/INST]
{{- else if eq .Role "assistant" }}
{{- if .Content }} {{ .Content }}</s>
{{- else if .ToolCalls }}[TOOL_CALLS] [
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}}
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{- end }}]</s>
{{- end }}
{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]

View file

@ -150,9 +150,9 @@ func (t *Template) Vars() []string {
type Values struct {
Messages []api.Message
Tools []api.Tool
Prompt string
Suffix string
api.Tools
Prompt string
Suffix string
// forceLegacy is a flag used to test compatibility with legacy templates
forceLegacy bool