This commit is contained in:
Michael Yang 2024-06-20 13:45:47 -07:00
parent 9e35d9bbee
commit d02bbebb11
7 changed files with 263 additions and 53 deletions

View file

@ -97,6 +97,9 @@ type ChatRequest struct {
// followin the request. // followin the request.
KeepAlive *Duration `json:"keep_alive,omitempty"` KeepAlive *Duration `json:"keep_alive,omitempty"`
// Tools is an optional list of tools the model has access to.
Tools []Tool `json:"tools,omitempty"`
// Options lists model-specific options. // Options lists model-specific options.
Options map[string]interface{} `json:"options"` Options map[string]interface{} `json:"options"`
} }
@ -106,8 +109,35 @@ type ChatRequest struct {
// of images. // of images.
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` Content string `json:"content,omitempty"`
Images []ImageData `json:"images,omitempty"` Images []ImageData `json:"images,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
type ToolCall struct {
ID string `json:"id"`
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments"`
} `json:"function"`
}
type Tool struct {
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters struct {
Type string `json:"type"`
Required []string `json:"required"`
Properties map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
} `json:"properties"`
} `json:"parameters"`
} `json:"function"`
} }
func (m *Message) UnmarshalJSON(b []byte) error { func (m *Message) UnmarshalJSON(b []byte) error {
@ -374,6 +404,9 @@ type GenerateResponse struct {
// Response is the textual response itself. // Response is the textual response itself.
Response string `json:"response"` Response string `json:"response"`
// ToolCalls is the list of tools the model wants to call
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
// Done specifies if the response is complete. // Done specifies if the response is complete.
Done bool `json:"done"` Done bool `json:"done"`

View file

@ -38,7 +38,10 @@ var errCapabilityCompletion = errors.New("completion")
type Capability string type Capability string
const CapabilityCompletion = Capability("completion") const (
CapabilityCompletion = Capability("completion")
CapabilityTools = Capability("tools")
)
type registryOptions struct { type registryOptions struct {
Insecure bool Insecure bool
@ -88,6 +91,10 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok { if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
errs = append(errs, errCapabilityCompletion) errs = append(errs, errCapabilityCompletion)
} }
case CapabilityTools:
if !slices.Contains(m.Template.Vars(), "tools") {
errs = append(errs, errors.New("tools"))
}
default: default:
slog.Error("unknown capability", "capability", cap) slog.Error("unknown capability", "capability", cap)
return fmt.Errorf("unknown capability: %s", cap) return fmt.Errorf("unknown capability: %s", cap)
@ -95,7 +102,7 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
} }
if err := errors.Join(errs...); err != nil { if err := errors.Join(errs...); err != nil {
return fmt.Errorf("missing capabilities: %w", errors.Join(errs...)) return fmt.Errorf("does not support %w", errors.Join(errs...))
} }
return nil return nil

View file

@ -4,6 +4,7 @@ import (
"archive/zip" "archive/zip"
"bytes" "bytes"
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -11,7 +12,11 @@ import (
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"slices"
"strings"
"text/template/parse"
"github.com/google/uuid"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert" "github.com/ollama/ollama/convert"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
@ -289,3 +294,103 @@ func detectContentType(r io.Reader) (string, error) {
return "unknown", nil return "unknown", nil
} }
// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls.
// mxyng: this only really works if the input contains tool calls in some JSON format
func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
// create a subtree from the node that ranges over .ToolCalls
tmpl := m.Template.Subtree(func(n parse.Node) bool {
if t, ok := n.(*parse.RangeNode); ok {
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
}
return false
})
if tmpl == nil {
return nil, false
}
var b bytes.Buffer
if err := tmpl.Execute(&b, map[string][]map[string]any{
"ToolCalls": {
{
"Function": map[string]any{
"Name": "@@name@@",
"Arguments": "@@arguments@@",
},
},
},
}); err != nil {
return nil, false
}
var kv map[string]string
// execute the subtree with placeholders to identify the keys
if err := json.Unmarshal(b.Bytes(), &kv); err != nil {
return nil, false
}
// find the keys that correspond to the name and arguments fields
var name, arguments string
for k, v := range kv {
switch v {
case "@@name@@":
name = k
case "@@arguments@@":
arguments = k
}
}
var sm []map[string]any
decoder := json.NewDecoder(strings.NewReader(s))
for {
// incrementally decode the JSON into a list of JSON objects
// skipping over any invalid tokens
if err := decoder.Decode(&sm); err != nil {
if errors.Is(err, io.EOF) {
break
}
if errors.As(err, new(*json.SyntaxError)) {
r := decoder.Buffered()
if _, err := r.Read(make([]byte, decoder.InputOffset()+1)); err != nil {
break
}
decoder = json.NewDecoder(r)
continue
}
return nil, false
}
// break as soon as a valid object is decoded
break
}
var toolCalls []api.ToolCall
for _, kv := range sm {
call := api.ToolCall{
ID: uuid.New().String(),
Type: "function",
}
for k, v := range kv {
switch k {
case name:
call.Function.Name = v.(string)
case arguments:
call.Function.Arguments = v.(map[string]any)
}
}
toolCalls = append(toolCalls, call)
}
if len(toolCalls) > 0 {
return toolCalls, true
}
return nil, false
}

View file

@ -15,7 +15,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn. // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages // latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) { func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
var system []api.Message var system []api.Message
// always include the last message // always include the last message
n := len(msgs) - 1 n := len(msgs) - 1
@ -29,7 +29,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
} }
var b bytes.Buffer var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil { if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
return "", nil, err return "", nil, err
} }
@ -57,7 +57,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
// truncate any messages that do not fit into the context window // truncate any messages that do not fit into the context window
var b bytes.Buffer var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil { if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...), Tools: tools}); err != nil {
return "", nil, err return "", nil, err
} }

View file

@ -192,7 +192,7 @@ func TestChatPrompt(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}} model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}} opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs) prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -265,6 +265,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
} }
r.Response = sb.String() r.Response = sb.String()
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
r.ToolCalls = toolCalls
r.Response = ""
}
c.JSON(http.StatusOK, r) c.JSON(http.StatusOK, r)
return return
} }
@ -1279,6 +1284,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
} }
caps := []Capability{CapabilityCompletion} caps := []Capability{CapabilityCompletion}
if req.Tools != nil {
caps = append(caps, CapabilityTools)
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) { if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
@ -1305,7 +1314,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...) req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...)
} }
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages) prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages, req.Tools)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
@ -1348,13 +1357,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
}() }()
if req.Stream != nil && !*req.Stream { if req.Stream != nil && !*req.Stream {
var r api.ChatResponse var resp api.ChatResponse
var sb strings.Builder var sb strings.Builder
for rr := range ch { for rr := range ch {
switch t := rr.(type) { switch t := rr.(type) {
case api.ChatResponse: case api.ChatResponse:
sb.WriteString(t.Message.Content) sb.WriteString(t.Message.Content)
r = t resp = t
case gin.H: case gin.H:
msg, ok := t["error"].(string) msg, ok := t["error"].(string)
if !ok { if !ok {
@ -1369,8 +1378,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
} }
} }
r.Message.Content = sb.String() resp.Message.Content = sb.String()
c.JSON(http.StatusOK, r) if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
resp.Message.ToolCalls = toolCalls
resp.Message.Content = ""
}
c.JSON(http.StatusOK, resp)
return return
} }

View file

@ -13,6 +13,7 @@ import (
"sync" "sync"
"text/template" "text/template"
"text/template/parse" "text/template/parse"
"time"
"github.com/agnivade/levenshtein" "github.com/agnivade/levenshtein"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
@ -102,8 +103,18 @@ var response = parse.ActionNode{
}, },
} }
var funcs = template.FuncMap{
"json": func(v any) string {
b, _ := json.Marshal(v)
return string(b)
},
"now": func() string {
return time.Now().Format("2006-01-02 15:04:05")
},
}
func Parse(s string) (*Template, error) { func Parse(s string) (*Template, error) {
tmpl := template.New("").Option("missingkey=zero") tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
tmpl, err := tmpl.Parse(s) tmpl, err := tmpl.Parse(s)
if err != nil { if err != nil {
@ -127,7 +138,7 @@ func (t *Template) Vars() []string {
var vars []string var vars []string
for _, tt := range t.Templates() { for _, tt := range t.Templates() {
for _, n := range tt.Root.Nodes { for _, n := range tt.Root.Nodes {
vars = append(vars, parseNode(n)...) vars = append(vars, Identifiers(n)...)
} }
} }
@ -143,17 +154,65 @@ func (t *Template) Vars() []string {
type Values struct { type Values struct {
Messages []api.Message Messages []api.Message
Tools []api.Tool
// forceLegacy is a flag used to test compatibility with legacy templates // forceLegacy is a flag used to test compatibility with legacy templates
forceLegacy bool forceLegacy bool
} }
func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
var walk func(parse.Node) parse.Node
walk = func(n parse.Node) parse.Node {
if fn(n) {
return n
}
switch t := n.(type) {
case *parse.ListNode:
for _, c := range t.Nodes {
if n := walk(c); n != nil {
return n
}
}
case *parse.BranchNode:
for _, n := range []*parse.ListNode{t.List, t.ElseList} {
if n != nil {
if n := walk(n); n != nil {
return n
}
}
}
case *parse.IfNode:
return walk(&t.BranchNode)
case *parse.WithNode:
return walk(&t.BranchNode)
case *parse.RangeNode:
return walk(&t.BranchNode)
}
return nil
}
if n := walk(t.Tree.Root); n != nil {
return (&template.Template{
Tree: &parse.Tree{
Root: &parse.ListNode{
Nodes: []parse.Node{n},
},
},
}).Funcs(funcs)
}
return nil
}
func (t *Template) Execute(w io.Writer, v Values) error { func (t *Template) Execute(w io.Writer, v Values) error {
system, messages := collate(v.Messages) system, messages := collate(v.Messages)
if !v.forceLegacy && slices.Contains(t.Vars(), "messages") { if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
return t.Template.Execute(w, map[string]any{ return t.Template.Execute(w, map[string]any{
"System": system, "System": system,
"Messages": messages, "Messages": messages,
"Tools": v.Tools,
}) })
} }
@ -161,7 +220,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
var b bytes.Buffer var b bytes.Buffer
var prompt, response string var prompt, response string
for _, m := range messages { for _, m := range messages {
execute := func () error { execute := func() error {
if err := t.Template.Execute(&b, map[string]any{ if err := t.Template.Execute(&b, map[string]any{
"System": system, "System": system,
"Prompt": prompt, "Prompt": prompt,
@ -198,13 +257,9 @@ func (t *Template) Execute(w io.Writer, v Values) error {
var cut bool var cut bool
nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool { nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
switch t := n.(type) { if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
case *parse.ActionNode:
case *parse.FieldNode:
if slices.Contains(t.Ident, "Response") {
cut = true cut = true
} }
}
return cut return cut
}) })
@ -255,50 +310,46 @@ func collate(msgs []api.Message) (string, []*api.Message) {
return strings.Join(system, "\n\n"), collated return strings.Join(system, "\n\n"), collated
} }
func parseNode(n parse.Node) []string { // Identifiers walks the node tree returning any identifiers it finds along the way
func Identifiers(n parse.Node) []string {
switch n := n.(type) { switch n := n.(type) {
case *parse.ListNode:
var names []string
for _, n := range n.Nodes {
names = append(names, Identifiers(n)...)
}
return names
case *parse.TemplateNode:
return Identifiers(n.Pipe)
case *parse.ActionNode: case *parse.ActionNode:
return parseNode(n.Pipe) return Identifiers(n.Pipe)
case *parse.BranchNode:
names := Identifiers(n.Pipe)
for _, n := range []*parse.ListNode{n.List, n.ElseList} {
if n != nil {
names = append(names, Identifiers(n)...)
}
}
return names
case *parse.IfNode: case *parse.IfNode:
names := parseNode(n.Pipe) return Identifiers(&n.BranchNode)
names = append(names, parseNode(n.List)...)
if n.ElseList != nil {
names = append(names, parseNode(n.ElseList)...)
}
return names
case *parse.RangeNode: case *parse.RangeNode:
names := parseNode(n.Pipe) return Identifiers(&n.BranchNode)
names = append(names, parseNode(n.List)...)
if n.ElseList != nil {
names = append(names, parseNode(n.ElseList)...)
}
return names
case *parse.WithNode: case *parse.WithNode:
names := parseNode(n.Pipe) return Identifiers(&n.BranchNode)
names = append(names, parseNode(n.List)...)
if n.ElseList != nil {
names = append(names, parseNode(n.ElseList)...)
}
return names
case *parse.PipeNode: case *parse.PipeNode:
var names []string var names []string
for _, c := range n.Cmds { for _, c := range n.Cmds {
for _, a := range c.Args { for _, a := range c.Args {
names = append(names, parseNode(a)...) names = append(names, Identifiers(a)...)
} }
} }
return names
case *parse.ListNode:
var names []string
for _, n := range n.Nodes {
names = append(names, parseNode(n)...)
}
return names return names
case *parse.FieldNode: case *parse.FieldNode:
return n.Ident return n.Ident
case *parse.TemplateNode: case *parse.VariableNode:
return parseNode(n.Pipe) return n.Ident
} }
return nil return nil