marshal json automatically for some template values (#5758)
This commit is contained in:
parent
b23424bb3c
commit
b255445557
8 changed files with 72 additions and 52 deletions
73
api/types.go
73
api/types.go
|
@ -101,12 +101,19 @@ type ChatRequest struct {
|
|||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||
|
||||
// Tools is an optional list of tools the model has access to.
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
Tools `json:"tools,omitempty"`
|
||||
|
||||
// Options lists model-specific options.
|
||||
Options map[string]interface{} `json:"options"`
|
||||
}
|
||||
|
||||
type Tools []Tool
|
||||
|
||||
func (t Tools) String() string {
|
||||
bts, _ := json.Marshal(t)
|
||||
return string(bts)
|
||||
}
|
||||
|
||||
// Message is a single message in a chat sequence. The message contains the
|
||||
// role ("system", "user", or "assistant"), the content and an optional list
|
||||
// of images.
|
||||
|
@ -117,30 +124,6 @@ type Message struct {
|
|||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
type ToolCall struct {
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments"`
|
||||
} `json:"function"`
|
||||
}
|
||||
|
||||
type Tool struct {
|
||||
Type string `json:"type"`
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters struct {
|
||||
Type string `json:"type"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
} `json:"properties"`
|
||||
} `json:"parameters"`
|
||||
} `json:"function"`
|
||||
}
|
||||
|
||||
func (m *Message) UnmarshalJSON(b []byte) error {
|
||||
type Alias Message
|
||||
var a Alias
|
||||
|
@ -153,6 +136,46 @@ func (m *Message) UnmarshalJSON(b []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
type ToolCall struct {
|
||||
Function ToolCallFunction `json:"function"`
|
||||
}
|
||||
|
||||
type ToolCallFunction struct {
|
||||
Name string `json:"name"`
|
||||
Arguments ToolCallFunctionArguments `json:"arguments"`
|
||||
}
|
||||
|
||||
type ToolCallFunctionArguments map[string]any
|
||||
|
||||
func (t *ToolCallFunctionArguments) String() string {
|
||||
bts, _ := json.Marshal(t)
|
||||
return string(bts)
|
||||
}
|
||||
|
||||
type Tool struct {
|
||||
Type string `json:"type"`
|
||||
Function ToolFunction `json:"function"`
|
||||
}
|
||||
|
||||
type ToolFunction struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters struct {
|
||||
Type string `json:"type"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
} `json:"properties"`
|
||||
} `json:"parameters"`
|
||||
}
|
||||
|
||||
func (t *ToolFunction) String() string {
|
||||
bts, _ := json.Marshal(t)
|
||||
return string(bts)
|
||||
}
|
||||
|
||||
// ChatResponse is the response returned by [Client.Chat]. Its fields are
|
||||
// similar to [GenerateResponse].
|
||||
type ChatResponse struct {
|
||||
|
|
|
@ -311,12 +311,14 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
|||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := tmpl.Execute(&b, map[string][]map[string]any{
|
||||
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
|
||||
"ToolCalls": {
|
||||
{
|
||||
"Function": map[string]any{
|
||||
"Name": "@@name@@",
|
||||
"Arguments": "@@arguments@@",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "@@name@@",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"@@argument@@": 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -324,7 +326,7 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
|||
return nil, false
|
||||
}
|
||||
|
||||
var kv map[string]string
|
||||
var kv map[string]any
|
||||
// execute the subtree with placeholders to identify the keys
|
||||
// trim any commands that might exist in the template
|
||||
if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {
|
||||
|
@ -334,10 +336,10 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
|||
// find the keys that correspond to the name and arguments fields
|
||||
var name, arguments string
|
||||
for k, v := range kv {
|
||||
switch v {
|
||||
case "@@name@@":
|
||||
switch v.(type) {
|
||||
case string:
|
||||
name = k
|
||||
case "@@arguments@@":
|
||||
case map[string]any:
|
||||
arguments = k
|
||||
}
|
||||
}
|
||||
|
|
|
@ -115,11 +115,6 @@ func TestExtractFromZipFile(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
type function struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments"`
|
||||
}
|
||||
|
||||
func readFile(t *testing.T, base, name string) *bytes.Buffer {
|
||||
t.Helper()
|
||||
|
||||
|
@ -185,18 +180,18 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
|
|||
|
||||
calls := []api.ToolCall{
|
||||
{
|
||||
Function: function{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"format": "fahrenheit",
|
||||
"location": "San Francisco, CA",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: function{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: map[string]any{
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"format": "celsius",
|
||||
"location": "Toronto, Canada",
|
||||
},
|
||||
|
|
2
server/testdata/tools/command-r-plus.gotmpl
vendored
2
server/testdata/tools/command-r-plus.gotmpl
vendored
|
@ -46,7 +46,7 @@ Action: ```json
|
|||
{{- range .ToolCalls }}
|
||||
{
|
||||
"tool_name": "{{ .Function.Name }}",
|
||||
"parameters": {{ json .Function.Arguments }}
|
||||
"parameters": {{ .Function.Arguments }}
|
||||
}
|
||||
{{- end }}
|
||||
]```
|
||||
|
|
4
server/testdata/tools/firefunction.gotmpl
vendored
4
server/testdata/tools/firefunction.gotmpl
vendored
|
@ -17,7 +17,7 @@ If you decide to call functions:
|
|||
|
||||
Available functions as JSON spec:
|
||||
{{- if .Tools }}
|
||||
{{ json .Tools }}
|
||||
{{ .Tools }}
|
||||
{{- end }}<|eot_id|>
|
||||
{{- end }}
|
||||
{{- range .Messages }}<|start_header_id|>
|
||||
|
@ -25,7 +25,7 @@ Available functions as JSON spec:
|
|||
{{- end }}<|end_header_id|>
|
||||
{{- if .Content }}{{ .Content }}
|
||||
{{- else if .ToolCalls }} functools[
|
||||
{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}{{ "}" }}
|
||||
{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}{{ "}" }}
|
||||
{{- end }}]
|
||||
{{- end }}<|eot_id|>
|
||||
{{- end }}<|start_header_id|>assistant<|end_header_id|>
|
|
@ -9,7 +9,7 @@
|
|||
|
||||
Here are the available tools:
|
||||
<tools>
|
||||
{{- range .Tools }} {{ json .Function }}
|
||||
{{- range .Tools }} {{ .Function }}
|
||||
{{- end }} </tools>
|
||||
{{- end }}
|
||||
{{- end }}<|eot_id|>
|
||||
|
@ -20,7 +20,7 @@ Here are the available tools:
|
|||
{{- else if eq .Role "assistant" }}
|
||||
{{- if .Content }}{{ .Content }}
|
||||
{{- else if .ToolCalls }}<tool_call>
|
||||
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}}
|
||||
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||
{{- end }}
|
||||
</tool_call>
|
||||
{{- end }}
|
||||
|
|
4
server/testdata/tools/mistral.gotmpl
vendored
4
server/testdata/tools/mistral.gotmpl
vendored
|
@ -1,13 +1,13 @@
|
|||
{{- range $index, $_ := .Messages }}
|
||||
{{- if eq .Role "user" }}
|
||||
{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ json $.Tools }}[/AVAILABLE_TOOLS]
|
||||
{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ $.Tools }}[/AVAILABLE_TOOLS]
|
||||
{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}
|
||||
|
||||
{{ end }}{{ .Content }}[/INST]
|
||||
{{- else if eq .Role "assistant" }}
|
||||
{{- if .Content }} {{ .Content }}</s>
|
||||
{{- else if .ToolCalls }}[TOOL_CALLS] [
|
||||
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}}
|
||||
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||
{{- end }}]</s>
|
||||
{{- end }}
|
||||
{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]
|
||||
|
|
|
@ -150,9 +150,9 @@ func (t *Template) Vars() []string {
|
|||
|
||||
type Values struct {
|
||||
Messages []api.Message
|
||||
Tools []api.Tool
|
||||
Prompt string
|
||||
Suffix string
|
||||
api.Tools
|
||||
Prompt string
|
||||
Suffix string
|
||||
|
||||
// forceLegacy is a flag used to test compatibility with legacy templates
|
||||
forceLegacy bool
|
||||
|
|
Loading…
Add table
Reference in a new issue