tools
This commit is contained in:
parent
9e35d9bbee
commit
d02bbebb11
7 changed files with 263 additions and 53 deletions
39
api/types.go
39
api/types.go
|
@ -97,6 +97,9 @@ type ChatRequest struct {
|
|||
// following the request.
|
||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||
|
||||
// Tools is an optional list of tools the model has access to.
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
|
||||
// Options lists model-specific options.
|
||||
Options map[string]interface{} `json:"options"`
|
||||
}
|
||||
|
@ -105,9 +108,36 @@ type ChatRequest struct {
|
|||
// role ("system", "user", or "assistant"), the content and an optional list
|
||||
// of images.
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Images []ImageData `json:"images,omitempty"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Images []ImageData `json:"images,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// ToolCall is a single tool invocation. It is carried in the tool_calls
// field of chat messages and generate responses.
type ToolCall struct {
	// ID identifies this call.
	ID string `json:"id"`

	// Type is the kind of tool being called, e.g. "function".
	Type string `json:"type"`

	// Function names the function to call and carries its arguments.
	Function struct {
		// Name is the name of the function to call.
		Name string `json:"name"`

		// Arguments maps argument names to their JSON-decoded values.
		Arguments map[string]any `json:"arguments"`
	} `json:"function"`
}
|
||||
|
||||
// Tool describes a tool that a request makes available to the model.
type Tool struct {
	// Type is the kind of tool, e.g. "function".
	Type string `json:"type"`

	// Function describes the callable function and its parameter schema.
	Function struct {
		// Name is the name of the function.
		Name string `json:"name"`

		// Description explains what the function does.
		Description string `json:"description"`

		// Parameters describes the arguments the function accepts.
		Parameters struct {
			// Type is the schema type of the parameter set.
			Type string `json:"type"`

			// Required lists the names of the mandatory properties.
			Required []string `json:"required"`

			// Properties maps each parameter name to its description.
			Properties map[string]struct {
				// Type is the schema type of the parameter.
				Type string `json:"type"`

				// Description explains the parameter.
				Description string `json:"description"`

				// Enum optionally restricts the parameter to a set of values.
				Enum []string `json:"enum,omitempty"`
			} `json:"properties"`
		} `json:"parameters"`
	} `json:"function"`
}
|
||||
|
||||
func (m *Message) UnmarshalJSON(b []byte) error {
|
||||
|
@ -374,6 +404,9 @@ type GenerateResponse struct {
|
|||
// Response is the textual response itself.
|
||||
Response string `json:"response"`
|
||||
|
||||
// ToolCalls is the list of tools the model wants to call
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
|
||||
// Done specifies if the response is complete.
|
||||
Done bool `json:"done"`
|
||||
|
||||
|
|
|
@ -38,7 +38,10 @@ var errCapabilityCompletion = errors.New("completion")
|
|||
|
||||
type Capability string
|
||||
|
||||
const CapabilityCompletion = Capability("completion")
|
||||
const (
|
||||
CapabilityCompletion = Capability("completion")
|
||||
CapabilityTools = Capability("tools")
|
||||
)
|
||||
|
||||
type registryOptions struct {
|
||||
Insecure bool
|
||||
|
@ -88,6 +91,10 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
|
|||
if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
|
||||
errs = append(errs, errCapabilityCompletion)
|
||||
}
|
||||
case CapabilityTools:
|
||||
if !slices.Contains(m.Template.Vars(), "tools") {
|
||||
errs = append(errs, errors.New("tools"))
|
||||
}
|
||||
default:
|
||||
slog.Error("unknown capability", "capability", cap)
|
||||
return fmt.Errorf("unknown capability: %s", cap)
|
||||
|
@ -95,7 +102,7 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
|
|||
}
|
||||
|
||||
if err := errors.Join(errs...); err != nil {
|
||||
return fmt.Errorf("missing capabilities: %w", errors.Join(errs...))
|
||||
return fmt.Errorf("does not support %w", errors.Join(errs...))
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
105
server/model.go
105
server/model.go
|
@ -4,6 +4,7 @@ import (
|
|||
"archive/zip"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -11,7 +12,11 @@ import (
|
|||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"text/template/parse"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/convert"
|
||||
"github.com/ollama/ollama/llm"
|
||||
|
@ -289,3 +294,103 @@ func detectContentType(r io.Reader) (string, error) {
|
|||
|
||||
return "unknown", nil
|
||||
}
|
||||
|
||||
// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
||||
// mxyng: this only really works if the input contains tool calls in some JSON format
|
||||
func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
||||
// create a subtree from the node that ranges over .ToolCalls
|
||||
tmpl := m.Template.Subtree(func(n parse.Node) bool {
|
||||
if t, ok := n.(*parse.RangeNode); ok {
|
||||
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
|
||||
}
|
||||
|
||||
return false
|
||||
})
|
||||
|
||||
if tmpl == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := tmpl.Execute(&b, map[string][]map[string]any{
|
||||
"ToolCalls": {
|
||||
{
|
||||
"Function": map[string]any{
|
||||
"Name": "@@name@@",
|
||||
"Arguments": "@@arguments@@",
|
||||
},
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
var kv map[string]string
|
||||
// execute the subtree with placeholders to identify the keys
|
||||
if err := json.Unmarshal(b.Bytes(), &kv); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// find the keys that correspond to the name and arguments fields
|
||||
var name, arguments string
|
||||
for k, v := range kv {
|
||||
switch v {
|
||||
case "@@name@@":
|
||||
name = k
|
||||
case "@@arguments@@":
|
||||
arguments = k
|
||||
}
|
||||
}
|
||||
|
||||
var sm []map[string]any
|
||||
decoder := json.NewDecoder(strings.NewReader(s))
|
||||
for {
|
||||
// incrementally decode the JSON into a list of JSON objects
|
||||
// skipping over any invalid tokens
|
||||
if err := decoder.Decode(&sm); err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
|
||||
if errors.As(err, new(*json.SyntaxError)) {
|
||||
r := decoder.Buffered()
|
||||
if _, err := r.Read(make([]byte, decoder.InputOffset()+1)); err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
decoder = json.NewDecoder(r)
|
||||
continue
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// break as soon as a valid object is decoded
|
||||
break
|
||||
}
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
for _, kv := range sm {
|
||||
call := api.ToolCall{
|
||||
ID: uuid.New().String(),
|
||||
Type: "function",
|
||||
}
|
||||
|
||||
for k, v := range kv {
|
||||
switch k {
|
||||
case name:
|
||||
call.Function.Name = v.(string)
|
||||
case arguments:
|
||||
call.Function.Arguments = v.(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, call)
|
||||
}
|
||||
|
||||
if len(toolCalls) > 0 {
|
||||
return toolCalls, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
|
|||
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
|
||||
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
|
||||
// latest message and 2) system messages
|
||||
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
|
||||
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
|
||||
var system []api.Message
|
||||
// always include the last message
|
||||
n := len(msgs) - 1
|
||||
|
@ -29,7 +29,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
|
||||
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
|
@ -57,7 +57,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||
|
||||
// truncate any messages that do not fit into the context window
|
||||
var b bytes.Buffer
|
||||
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
|
||||
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...), Tools: tools}); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
|
|
|
@ -192,7 +192,7 @@ func TestChatPrompt(t *testing.T) {
|
|||
t.Run(tt.name, func(t *testing.T) {
|
||||
model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
|
||||
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
|
||||
prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs)
|
||||
prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -265,6 +265,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||
}
|
||||
|
||||
r.Response = sb.String()
|
||||
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
||||
r.ToolCalls = toolCalls
|
||||
r.Response = ""
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, r)
|
||||
return
|
||||
}
|
||||
|
@ -1279,6 +1284,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
}
|
||||
|
||||
caps := []Capability{CapabilityCompletion}
|
||||
if req.Tools != nil {
|
||||
caps = append(caps, CapabilityTools)
|
||||
}
|
||||
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
|
||||
if errors.Is(err, errCapabilityCompletion) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
|
||||
|
@ -1305,7 +1314,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...)
|
||||
}
|
||||
|
||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
|
||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages, req.Tools)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
|
@ -1348,13 +1357,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
}()
|
||||
|
||||
if req.Stream != nil && !*req.Stream {
|
||||
var r api.ChatResponse
|
||||
var resp api.ChatResponse
|
||||
var sb strings.Builder
|
||||
for rr := range ch {
|
||||
switch t := rr.(type) {
|
||||
case api.ChatResponse:
|
||||
sb.WriteString(t.Message.Content)
|
||||
r = t
|
||||
resp = t
|
||||
case gin.H:
|
||||
msg, ok := t["error"].(string)
|
||||
if !ok {
|
||||
|
@ -1369,8 +1378,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
}
|
||||
}
|
||||
|
||||
r.Message.Content = sb.String()
|
||||
c.JSON(http.StatusOK, r)
|
||||
resp.Message.Content = sb.String()
|
||||
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
||||
resp.Message.ToolCalls = toolCalls
|
||||
resp.Message.Content = ""
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"sync"
|
||||
"text/template"
|
||||
"text/template/parse"
|
||||
"time"
|
||||
|
||||
"github.com/agnivade/levenshtein"
|
||||
"github.com/ollama/ollama/api"
|
||||
|
@ -102,8 +103,18 @@ var response = parse.ActionNode{
|
|||
},
|
||||
}
|
||||
|
||||
// funcs holds the helper functions registered on templates.
var funcs = template.FuncMap{
	// json renders v as JSON. A marshalling failure is returned as an error,
	// which aborts template execution instead of silently emitting an empty
	// string into the output.
	"json": func(v any) (string, error) {
		b, err := json.Marshal(v)
		if err != nil {
			return "", err
		}

		return string(b), nil
	},
	// now returns the current time formatted as "2006-01-02 15:04:05".
	"now": func() string {
		return time.Now().Format("2006-01-02 15:04:05")
	},
}
|
||||
|
||||
func Parse(s string) (*Template, error) {
|
||||
tmpl := template.New("").Option("missingkey=zero")
|
||||
tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
|
||||
|
||||
tmpl, err := tmpl.Parse(s)
|
||||
if err != nil {
|
||||
|
@ -127,7 +138,7 @@ func (t *Template) Vars() []string {
|
|||
var vars []string
|
||||
for _, tt := range t.Templates() {
|
||||
for _, n := range tt.Root.Nodes {
|
||||
vars = append(vars, parseNode(n)...)
|
||||
vars = append(vars, Identifiers(n)...)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -143,17 +154,65 @@ func (t *Template) Vars() []string {
|
|||
|
||||
type Values struct {
|
||||
Messages []api.Message
|
||||
Tools []api.Tool
|
||||
|
||||
// forceLegacy is a flag used to test compatibility with legacy templates
|
||||
forceLegacy bool
|
||||
}
|
||||
|
||||
func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
|
||||
var walk func(parse.Node) parse.Node
|
||||
walk = func(n parse.Node) parse.Node {
|
||||
if fn(n) {
|
||||
return n
|
||||
}
|
||||
|
||||
switch t := n.(type) {
|
||||
case *parse.ListNode:
|
||||
for _, c := range t.Nodes {
|
||||
if n := walk(c); n != nil {
|
||||
return n
|
||||
}
|
||||
}
|
||||
case *parse.BranchNode:
|
||||
for _, n := range []*parse.ListNode{t.List, t.ElseList} {
|
||||
if n != nil {
|
||||
if n := walk(n); n != nil {
|
||||
return n
|
||||
}
|
||||
}
|
||||
}
|
||||
case *parse.IfNode:
|
||||
return walk(&t.BranchNode)
|
||||
case *parse.WithNode:
|
||||
return walk(&t.BranchNode)
|
||||
case *parse.RangeNode:
|
||||
return walk(&t.BranchNode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if n := walk(t.Tree.Root); n != nil {
|
||||
return (&template.Template{
|
||||
Tree: &parse.Tree{
|
||||
Root: &parse.ListNode{
|
||||
Nodes: []parse.Node{n},
|
||||
},
|
||||
},
|
||||
}).Funcs(funcs)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Template) Execute(w io.Writer, v Values) error {
|
||||
system, messages := collate(v.Messages)
|
||||
if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
|
||||
return t.Template.Execute(w, map[string]any{
|
||||
"System": system,
|
||||
"Messages": messages,
|
||||
"Tools": v.Tools,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -161,7 +220,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
|||
var b bytes.Buffer
|
||||
var prompt, response string
|
||||
for _, m := range messages {
|
||||
execute := func () error {
|
||||
execute := func() error {
|
||||
if err := t.Template.Execute(&b, map[string]any{
|
||||
"System": system,
|
||||
"Prompt": prompt,
|
||||
|
@ -198,12 +257,8 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
|||
|
||||
var cut bool
|
||||
nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
|
||||
switch t := n.(type) {
|
||||
case *parse.ActionNode:
|
||||
case *parse.FieldNode:
|
||||
if slices.Contains(t.Ident, "Response") {
|
||||
cut = true
|
||||
}
|
||||
if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
|
||||
cut = true
|
||||
}
|
||||
|
||||
return cut
|
||||
|
@ -255,50 +310,46 @@ func collate(msgs []api.Message) (string, []*api.Message) {
|
|||
return strings.Join(system, "\n\n"), collated
|
||||
}
|
||||
|
||||
func parseNode(n parse.Node) []string {
|
||||
// Identifiers walks the node tree returning any identifiers it finds along the way
|
||||
func Identifiers(n parse.Node) []string {
|
||||
switch n := n.(type) {
|
||||
case *parse.ListNode:
|
||||
var names []string
|
||||
for _, n := range n.Nodes {
|
||||
names = append(names, Identifiers(n)...)
|
||||
}
|
||||
|
||||
return names
|
||||
case *parse.TemplateNode:
|
||||
return Identifiers(n.Pipe)
|
||||
case *parse.ActionNode:
|
||||
return parseNode(n.Pipe)
|
||||
return Identifiers(n.Pipe)
|
||||
case *parse.BranchNode:
|
||||
names := Identifiers(n.Pipe)
|
||||
for _, n := range []*parse.ListNode{n.List, n.ElseList} {
|
||||
if n != nil {
|
||||
names = append(names, Identifiers(n)...)
|
||||
}
|
||||
}
|
||||
return names
|
||||
case *parse.IfNode:
|
||||
names := parseNode(n.Pipe)
|
||||
names = append(names, parseNode(n.List)...)
|
||||
if n.ElseList != nil {
|
||||
names = append(names, parseNode(n.ElseList)...)
|
||||
}
|
||||
return names
|
||||
return Identifiers(&n.BranchNode)
|
||||
case *parse.RangeNode:
|
||||
names := parseNode(n.Pipe)
|
||||
names = append(names, parseNode(n.List)...)
|
||||
if n.ElseList != nil {
|
||||
names = append(names, parseNode(n.ElseList)...)
|
||||
}
|
||||
return names
|
||||
return Identifiers(&n.BranchNode)
|
||||
case *parse.WithNode:
|
||||
names := parseNode(n.Pipe)
|
||||
names = append(names, parseNode(n.List)...)
|
||||
if n.ElseList != nil {
|
||||
names = append(names, parseNode(n.ElseList)...)
|
||||
}
|
||||
return names
|
||||
return Identifiers(&n.BranchNode)
|
||||
case *parse.PipeNode:
|
||||
var names []string
|
||||
for _, c := range n.Cmds {
|
||||
for _, a := range c.Args {
|
||||
names = append(names, parseNode(a)...)
|
||||
names = append(names, Identifiers(a)...)
|
||||
}
|
||||
}
|
||||
return names
|
||||
case *parse.ListNode:
|
||||
var names []string
|
||||
for _, n := range n.Nodes {
|
||||
names = append(names, parseNode(n)...)
|
||||
}
|
||||
|
||||
return names
|
||||
case *parse.FieldNode:
|
||||
return n.Ident
|
||||
case *parse.TemplateNode:
|
||||
return parseNode(n.Pipe)
|
||||
case *parse.VariableNode:
|
||||
return n.Ident
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
Loading…
Add table
Reference in a new issue