marshal json automatically for some template values (#5758)
This commit is contained in:
parent
b23424bb3c
commit
b255445557
8 changed files with 72 additions and 52 deletions
73
api/types.go
73
api/types.go
|
@ -101,12 +101,19 @@ type ChatRequest struct {
|
||||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||||
|
|
||||||
// Tools is an optional list of tools the model has access to.
|
// Tools is an optional list of tools the model has access to.
|
||||||
Tools []Tool `json:"tools,omitempty"`
|
Tools `json:"tools,omitempty"`
|
||||||
|
|
||||||
// Options lists model-specific options.
|
// Options lists model-specific options.
|
||||||
Options map[string]interface{} `json:"options"`
|
Options map[string]interface{} `json:"options"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Tools []Tool
|
||||||
|
|
||||||
|
func (t Tools) String() string {
|
||||||
|
bts, _ := json.Marshal(t)
|
||||||
|
return string(bts)
|
||||||
|
}
|
||||||
|
|
||||||
// Message is a single message in a chat sequence. The message contains the
|
// Message is a single message in a chat sequence. The message contains the
|
||||||
// role ("system", "user", or "assistant"), the content and an optional list
|
// role ("system", "user", or "assistant"), the content and an optional list
|
||||||
// of images.
|
// of images.
|
||||||
|
@ -117,30 +124,6 @@ type Message struct {
|
||||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ToolCall struct {
|
|
||||||
Function struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Arguments map[string]any `json:"arguments"`
|
|
||||||
} `json:"function"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Tool struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Function struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Parameters struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Required []string `json:"required"`
|
|
||||||
Properties map[string]struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Enum []string `json:"enum,omitempty"`
|
|
||||||
} `json:"properties"`
|
|
||||||
} `json:"parameters"`
|
|
||||||
} `json:"function"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Message) UnmarshalJSON(b []byte) error {
|
func (m *Message) UnmarshalJSON(b []byte) error {
|
||||||
type Alias Message
|
type Alias Message
|
||||||
var a Alias
|
var a Alias
|
||||||
|
@ -153,6 +136,46 @@ func (m *Message) UnmarshalJSON(b []byte) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ToolCall struct {
|
||||||
|
Function ToolCallFunction `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolCallFunction struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments ToolCallFunctionArguments `json:"arguments"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolCallFunctionArguments map[string]any
|
||||||
|
|
||||||
|
func (t *ToolCallFunctionArguments) String() string {
|
||||||
|
bts, _ := json.Marshal(t)
|
||||||
|
return string(bts)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tool struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Function ToolFunction `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolFunction struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Parameters struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Required []string `json:"required"`
|
||||||
|
Properties map[string]struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Enum []string `json:"enum,omitempty"`
|
||||||
|
} `json:"properties"`
|
||||||
|
} `json:"parameters"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ToolFunction) String() string {
|
||||||
|
bts, _ := json.Marshal(t)
|
||||||
|
return string(bts)
|
||||||
|
}
|
||||||
|
|
||||||
// ChatResponse is the response returned by [Client.Chat]. Its fields are
|
// ChatResponse is the response returned by [Client.Chat]. Its fields are
|
||||||
// similar to [GenerateResponse].
|
// similar to [GenerateResponse].
|
||||||
type ChatResponse struct {
|
type ChatResponse struct {
|
||||||
|
|
|
@ -311,12 +311,14 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
if err := tmpl.Execute(&b, map[string][]map[string]any{
|
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
|
||||||
"ToolCalls": {
|
"ToolCalls": {
|
||||||
{
|
{
|
||||||
"Function": map[string]any{
|
Function: api.ToolCallFunction{
|
||||||
"Name": "@@name@@",
|
Name: "@@name@@",
|
||||||
"Arguments": "@@arguments@@",
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
|
"@@argument@@": 1,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -324,7 +326,7 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
var kv map[string]string
|
var kv map[string]any
|
||||||
// execute the subtree with placeholders to identify the keys
|
// execute the subtree with placeholders to identify the keys
|
||||||
// trim any commands that might exist in the template
|
// trim any commands that might exist in the template
|
||||||
if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {
|
if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {
|
||||||
|
@ -334,10 +336,10 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
||||||
// find the keys that correspond to the name and arguments fields
|
// find the keys that correspond to the name and arguments fields
|
||||||
var name, arguments string
|
var name, arguments string
|
||||||
for k, v := range kv {
|
for k, v := range kv {
|
||||||
switch v {
|
switch v.(type) {
|
||||||
case "@@name@@":
|
case string:
|
||||||
name = k
|
name = k
|
||||||
case "@@arguments@@":
|
case map[string]any:
|
||||||
arguments = k
|
arguments = k
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -115,11 +115,6 @@ func TestExtractFromZipFile(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type function struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Arguments map[string]any `json:"arguments"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func readFile(t *testing.T, base, name string) *bytes.Buffer {
|
func readFile(t *testing.T, base, name string) *bytes.Buffer {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
@ -185,18 +180,18 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
|
||||||
|
|
||||||
calls := []api.ToolCall{
|
calls := []api.ToolCall{
|
||||||
{
|
{
|
||||||
Function: function{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_current_weather",
|
Name: "get_current_weather",
|
||||||
Arguments: map[string]any{
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
"format": "fahrenheit",
|
"format": "fahrenheit",
|
||||||
"location": "San Francisco, CA",
|
"location": "San Francisco, CA",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Function: function{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_current_weather",
|
Name: "get_current_weather",
|
||||||
Arguments: map[string]any{
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
"format": "celsius",
|
"format": "celsius",
|
||||||
"location": "Toronto, Canada",
|
"location": "Toronto, Canada",
|
||||||
},
|
},
|
||||||
|
|
2
server/testdata/tools/command-r-plus.gotmpl
vendored
2
server/testdata/tools/command-r-plus.gotmpl
vendored
|
@ -46,7 +46,7 @@ Action: ```json
|
||||||
{{- range .ToolCalls }}
|
{{- range .ToolCalls }}
|
||||||
{
|
{
|
||||||
"tool_name": "{{ .Function.Name }}",
|
"tool_name": "{{ .Function.Name }}",
|
||||||
"parameters": {{ json .Function.Arguments }}
|
"parameters": {{ .Function.Arguments }}
|
||||||
}
|
}
|
||||||
{{- end }}
|
{{- end }}
|
||||||
]```
|
]```
|
||||||
|
|
4
server/testdata/tools/firefunction.gotmpl
vendored
4
server/testdata/tools/firefunction.gotmpl
vendored
|
@ -17,7 +17,7 @@ If you decide to call functions:
|
||||||
|
|
||||||
Available functions as JSON spec:
|
Available functions as JSON spec:
|
||||||
{{- if .Tools }}
|
{{- if .Tools }}
|
||||||
{{ json .Tools }}
|
{{ .Tools }}
|
||||||
{{- end }}<|eot_id|>
|
{{- end }}<|eot_id|>
|
||||||
{{- end }}
|
{{- end }}
|
||||||
{{- range .Messages }}<|start_header_id|>
|
{{- range .Messages }}<|start_header_id|>
|
||||||
|
@ -25,7 +25,7 @@ Available functions as JSON spec:
|
||||||
{{- end }}<|end_header_id|>
|
{{- end }}<|end_header_id|>
|
||||||
{{- if .Content }}{{ .Content }}
|
{{- if .Content }}{{ .Content }}
|
||||||
{{- else if .ToolCalls }} functools[
|
{{- else if .ToolCalls }} functools[
|
||||||
{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}{{ "}" }}
|
{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}{{ "}" }}
|
||||||
{{- end }}]
|
{{- end }}]
|
||||||
{{- end }}<|eot_id|>
|
{{- end }}<|eot_id|>
|
||||||
{{- end }}<|start_header_id|>assistant<|end_header_id|>
|
{{- end }}<|start_header_id|>assistant<|end_header_id|>
|
|
@ -9,7 +9,7 @@
|
||||||
|
|
||||||
Here are the available tools:
|
Here are the available tools:
|
||||||
<tools>
|
<tools>
|
||||||
{{- range .Tools }} {{ json .Function }}
|
{{- range .Tools }} {{ .Function }}
|
||||||
{{- end }} </tools>
|
{{- end }} </tools>
|
||||||
{{- end }}
|
{{- end }}
|
||||||
{{- end }}<|eot_id|>
|
{{- end }}<|eot_id|>
|
||||||
|
@ -20,7 +20,7 @@ Here are the available tools:
|
||||||
{{- else if eq .Role "assistant" }}
|
{{- else if eq .Role "assistant" }}
|
||||||
{{- if .Content }}{{ .Content }}
|
{{- if .Content }}{{ .Content }}
|
||||||
{{- else if .ToolCalls }}<tool_call>
|
{{- else if .ToolCalls }}<tool_call>
|
||||||
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}}
|
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||||
{{- end }}
|
{{- end }}
|
||||||
</tool_call>
|
</tool_call>
|
||||||
{{- end }}
|
{{- end }}
|
||||||
|
|
4
server/testdata/tools/mistral.gotmpl
vendored
4
server/testdata/tools/mistral.gotmpl
vendored
|
@ -1,13 +1,13 @@
|
||||||
{{- range $index, $_ := .Messages }}
|
{{- range $index, $_ := .Messages }}
|
||||||
{{- if eq .Role "user" }}
|
{{- if eq .Role "user" }}
|
||||||
{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ json $.Tools }}[/AVAILABLE_TOOLS]
|
{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ $.Tools }}[/AVAILABLE_TOOLS]
|
||||||
{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}
|
{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}
|
||||||
|
|
||||||
{{ end }}{{ .Content }}[/INST]
|
{{ end }}{{ .Content }}[/INST]
|
||||||
{{- else if eq .Role "assistant" }}
|
{{- else if eq .Role "assistant" }}
|
||||||
{{- if .Content }} {{ .Content }}</s>
|
{{- if .Content }} {{ .Content }}</s>
|
||||||
{{- else if .ToolCalls }}[TOOL_CALLS] [
|
{{- else if .ToolCalls }}[TOOL_CALLS] [
|
||||||
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}}
|
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||||
{{- end }}]</s>
|
{{- end }}]</s>
|
||||||
{{- end }}
|
{{- end }}
|
||||||
{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]
|
{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]
|
||||||
|
|
|
@ -150,9 +150,9 @@ func (t *Template) Vars() []string {
|
||||||
|
|
||||||
type Values struct {
|
type Values struct {
|
||||||
Messages []api.Message
|
Messages []api.Message
|
||||||
Tools []api.Tool
|
api.Tools
|
||||||
Prompt string
|
Prompt string
|
||||||
Suffix string
|
Suffix string
|
||||||
|
|
||||||
// forceLegacy is a flag used to test compatibility with legacy templates
|
// forceLegacy is a flag used to test compatibility with legacy templates
|
||||||
forceLegacy bool
|
forceLegacy bool
|
||||||
|
|
Loading…
Reference in a new issue