diff --git a/api/types.go b/api/types.go
index e687b8a4..c7e9dce3 100644
--- a/api/types.go
+++ b/api/types.go
@@ -101,12 +101,19 @@ type ChatRequest struct {
KeepAlive *Duration `json:"keep_alive,omitempty"`
// Tools is an optional list of tools the model has access to.
- Tools []Tool `json:"tools,omitempty"`
+ Tools `json:"tools,omitempty"`
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
}
+type Tools []Tool
+
+func (t Tools) String() string {
+ bts, _ := json.Marshal(t)
+ return string(bts)
+}
+
// Message is a single message in a chat sequence. The message contains the
// role ("system", "user", or "assistant"), the content and an optional list
// of images.
@@ -117,30 +124,6 @@ type Message struct {
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
-type ToolCall struct {
- Function struct {
- Name string `json:"name"`
- Arguments map[string]any `json:"arguments"`
- } `json:"function"`
-}
-
-type Tool struct {
- Type string `json:"type"`
- Function struct {
- Name string `json:"name"`
- Description string `json:"description"`
- Parameters struct {
- Type string `json:"type"`
- Required []string `json:"required"`
- Properties map[string]struct {
- Type string `json:"type"`
- Description string `json:"description"`
- Enum []string `json:"enum,omitempty"`
- } `json:"properties"`
- } `json:"parameters"`
- } `json:"function"`
-}
-
func (m *Message) UnmarshalJSON(b []byte) error {
type Alias Message
var a Alias
@@ -153,6 +136,46 @@ func (m *Message) UnmarshalJSON(b []byte) error {
return nil
}
+type ToolCall struct {
+ Function ToolCallFunction `json:"function"`
+}
+
+type ToolCallFunction struct {
+ Name string `json:"name"`
+ Arguments ToolCallFunctionArguments `json:"arguments"`
+}
+
+type ToolCallFunctionArguments map[string]any
+
+func (t *ToolCallFunctionArguments) String() string {
+ bts, _ := json.Marshal(t)
+ return string(bts)
+}
+
+type Tool struct {
+ Type string `json:"type"`
+ Function ToolFunction `json:"function"`
+}
+
+type ToolFunction struct {
+ Name string `json:"name"`
+ Description string `json:"description"`
+ Parameters struct {
+ Type string `json:"type"`
+ Required []string `json:"required"`
+ Properties map[string]struct {
+ Type string `json:"type"`
+ Description string `json:"description"`
+ Enum []string `json:"enum,omitempty"`
+ } `json:"properties"`
+ } `json:"parameters"`
+}
+
+func (t *ToolFunction) String() string {
+ bts, _ := json.Marshal(t)
+ return string(bts)
+}
+
// ChatResponse is the response returned by [Client.Chat]. Its fields are
// similar to [GenerateResponse].
type ChatResponse struct {
diff --git a/server/model.go b/server/model.go
index e5d6179b..65231ab1 100644
--- a/server/model.go
+++ b/server/model.go
@@ -311,12 +311,14 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
}
var b bytes.Buffer
- if err := tmpl.Execute(&b, map[string][]map[string]any{
+ if err := tmpl.Execute(&b, map[string][]api.ToolCall{
"ToolCalls": {
{
- "Function": map[string]any{
- "Name": "@@name@@",
- "Arguments": "@@arguments@@",
+ Function: api.ToolCallFunction{
+ Name: "@@name@@",
+ Arguments: api.ToolCallFunctionArguments{
+ "@@argument@@": 1,
+ },
},
},
},
@@ -324,7 +326,7 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
return nil, false
}
- var kv map[string]string
+ var kv map[string]any
// execute the subtree with placeholders to identify the keys
// trim any commands that might exist in the template
if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {
@@ -334,10 +336,10 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
// find the keys that correspond to the name and arguments fields
var name, arguments string
for k, v := range kv {
- switch v {
- case "@@name@@":
+ switch v.(type) {
+ case string:
name = k
- case "@@arguments@@":
+ case map[string]any:
arguments = k
}
}
diff --git a/server/model_test.go b/server/model_test.go
index f0382843..7c826b06 100644
--- a/server/model_test.go
+++ b/server/model_test.go
@@ -115,11 +115,6 @@ func TestExtractFromZipFile(t *testing.T) {
}
}
-type function struct {
- Name string `json:"name"`
- Arguments map[string]any `json:"arguments"`
-}
-
func readFile(t *testing.T, base, name string) *bytes.Buffer {
t.Helper()
@@ -185,18 +180,18 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
calls := []api.ToolCall{
{
- Function: function{
+ Function: api.ToolCallFunction{
Name: "get_current_weather",
- Arguments: map[string]any{
+ Arguments: api.ToolCallFunctionArguments{
"format": "fahrenheit",
"location": "San Francisco, CA",
},
},
},
{
- Function: function{
+ Function: api.ToolCallFunction{
Name: "get_current_weather",
- Arguments: map[string]any{
+ Arguments: api.ToolCallFunctionArguments{
"format": "celsius",
"location": "Toronto, Canada",
},
diff --git a/server/testdata/tools/command-r-plus.gotmpl b/server/testdata/tools/command-r-plus.gotmpl
index 088a4f0e..f30124e3 100644
--- a/server/testdata/tools/command-r-plus.gotmpl
+++ b/server/testdata/tools/command-r-plus.gotmpl
@@ -46,7 +46,7 @@ Action: ```json
{{- range .ToolCalls }}
{
"tool_name": "{{ .Function.Name }}",
- "parameters": {{ json .Function.Arguments }}
+ "parameters": {{ .Function.Arguments }}
}
{{- end }}
]```
diff --git a/server/testdata/tools/firefunction.gotmpl b/server/testdata/tools/firefunction.gotmpl
index bca88b3b..312be205 100644
--- a/server/testdata/tools/firefunction.gotmpl
+++ b/server/testdata/tools/firefunction.gotmpl
@@ -17,7 +17,7 @@ If you decide to call functions:
Available functions as JSON spec:
{{- if .Tools }}
-{{ json .Tools }}
+{{ .Tools }}
{{- end }}<|eot_id|>
{{- end }}
{{- range .Messages }}<|start_header_id|>
@@ -25,7 +25,7 @@ Available functions as JSON spec:
{{- end }}<|end_header_id|>
{{- if .Content }}{{ .Content }}
{{- else if .ToolCalls }} functools[
-{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}{{ "}" }}
+{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}{{ "}" }}
{{- end }}]
{{- end }}<|eot_id|>
{{- end }}<|start_header_id|>assistant<|end_header_id|>
\ No newline at end of file
diff --git a/server/testdata/tools/llama3-groq-tool-use.gotmpl b/server/testdata/tools/llama3-groq-tool-use.gotmpl
index e174f8a5..45e9b462 100644
--- a/server/testdata/tools/llama3-groq-tool-use.gotmpl
+++ b/server/testdata/tools/llama3-groq-tool-use.gotmpl
@@ -9,7 +9,7 @@
Here are the available tools:
-{{- range .Tools }} {{ json .Function }}
+{{- range .Tools }} {{ .Function }}
{{- end }}
{{- end }}
{{- end }}<|eot_id|>
@@ -20,7 +20,7 @@ Here are the available tools:
{{- else if eq .Role "assistant" }}
{{- if .Content }}{{ .Content }}
{{- else if .ToolCalls }}
-{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}}
+{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{- end }}
{{- end }}
diff --git a/server/testdata/tools/mistral.gotmpl b/server/testdata/tools/mistral.gotmpl
index a98bc7ad..b08d6c2c 100644
--- a/server/testdata/tools/mistral.gotmpl
+++ b/server/testdata/tools/mistral.gotmpl
@@ -1,13 +1,13 @@
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}
-{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ json $.Tools }}[/AVAILABLE_TOOLS]
+{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ $.Tools }}[/AVAILABLE_TOOLS]
{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}
{{ end }}{{ .Content }}[/INST]
{{- else if eq .Role "assistant" }}
{{- if .Content }} {{ .Content }}
{{- else if .ToolCalls }}[TOOL_CALLS] [
-{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}}
+{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{- end }}]
{{- end }}
{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]
diff --git a/template/template.go b/template/template.go
index 85b4d21a..b5bfb16c 100644
--- a/template/template.go
+++ b/template/template.go
@@ -150,9 +150,9 @@ func (t *Template) Vars() []string {
type Values struct {
Messages []api.Message
- Tools []api.Tool
- Prompt string
- Suffix string
+ api.Tools
+ Prompt string
+ Suffix string
// forceLegacy is a flag used to test compatibility with legacy templates
forceLegacy bool