From 269ed6e6a2cea822ab137d40d5c70c8bf09470f8 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Mon, 17 Jun 2024 10:38:55 -0700 Subject: [PATCH 1/4] update message processing --- server/images.go | 17 +- server/prompt.go | 241 ++++-------------- server/prompt_test.go | 317 ++++++++++++------------ server/routes.go | 508 ++++++++++++-------------------------- template/template.go | 169 ++++++++++++- template/template_test.go | 153 +++++++++++- 6 files changed, 685 insertions(+), 720 deletions(-) diff --git a/server/images.go b/server/images.go index a62991f1..688d5dca 100644 --- a/server/images.go +++ b/server/images.go @@ -34,6 +34,8 @@ import ( "github.com/ollama/ollama/version" ) +var errCapabilityCompletion = errors.New("completion") + type Capability string const CapabilityCompletion = Capability("completion") @@ -62,7 +64,10 @@ type Model struct { Template *template.Template } -func (m *Model) Has(caps ...Capability) bool { +// CheckCapabilities checks if the model has the specified capabilities returning an error describing +// any missing or unknown capabilities +func (m *Model) CheckCapabilities(caps ...Capability) error { + var errs []error for _, cap := range caps { switch cap { case CapabilityCompletion: @@ -81,15 +86,19 @@ func (m *Model) Has(caps ...Capability) bool { } if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok { - return false + errs = append(errs, errCapabilityCompletion) } default: slog.Error("unknown capability", "capability", cap) - return false + return fmt.Errorf("unknown capability: %s", cap) } } - return true + if err := errors.Join(errs...); err != nil { + return fmt.Errorf("missing capabilities: %w", errors.Join(errs...)) + } + + return nil } func (m *Model) String() string { diff --git a/server/prompt.go b/server/prompt.go index bfc319a5..5016fbe1 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -1,217 +1,74 @@ package server import ( - "fmt" + "bytes" + "context" "log/slog" - "strings" - - "text/template/parse" + "slices" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/llm" "github.com/ollama/ollama/template" ) -// isResponseNode checks if the node contains .Response -func isResponseNode(node *parse.ActionNode) bool { - for _, cmd := range node.Pipe.Cmds { - for _, arg := range cmd.Args { - if fieldNode, ok := arg.(*parse.FieldNode); ok && len(fieldNode.Ident) > 0 { - if fieldNode.Ident[0] == "Response" { - return true - } - } +func chatPrompt(ctx context.Context, r *runnerRef, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) { + // extract system messages which should always be included + var system []api.Message + msgs = slices.DeleteFunc(msgs, func(m api.Message) bool { + if m.Role == "system" { + system = append(system, m) + return true } - } - return false -} -// formatTemplateForResponse formats the template AST to: -// 1. remove all nodes after the first .Response (if generate=true) -// 2. add a .Response node to the end if it doesn't exist -// TODO(jmorganca): this should recursively cut the template before the first .Response -func formatTemplateForResponse(tmpl *template.Template, generate bool) { - var found bool - for i, node := range tmpl.Tree.Root.Nodes { - if actionNode, ok := node.(*parse.ActionNode); ok { - if isResponseNode(actionNode) { - found = true - if generate { - tmpl.Tree.Root.Nodes = tmpl.Tree.Root.Nodes[:i+1] - break - } - } + return false + }) + + if len(system) == 0 && r.model.System != "" { + // add model system prompt since it wasn't provided + system = append(system, api.Message{Role: "system", Content: r.model.System}) + } + + n := len(msgs) - 1 + for i := n - 1; i >= 0; i-- { + var b bytes.Buffer + if err := r.model.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil { + return "", nil, err } - } - if !found { - // add the response node if it doesn't exist - responseFieldNode := &parse.FieldNode{NodeType: parse.NodeField, Ident: []string{"Response"}} - responsePipeNode := &parse.PipeNode{NodeType: parse.NodePipe, Cmds: []*parse.CommandNode{{NodeType: parse.NodeCommand, Args: []parse.Node{responseFieldNode}}}} - responseActionNode := &parse.ActionNode{NodeType: parse.NodeAction, Pipe: responsePipeNode} - tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, responseActionNode) - } -} - -// Prompt renders a prompt from a template. If generate is set to true, -// the response and parts of the template following it are not rendered -func Prompt(tmpl *template.Template, system, prompt, response string, generate bool) (string, error) { - formatTemplateForResponse(tmpl, generate) - - vars := map[string]any{ - "System": system, - "Prompt": prompt, - "Response": response, - } - - var sb strings.Builder - if err := tmpl.Execute(&sb, vars); err != nil { - return "", err - } - - return sb.String(), nil -} - -func countTokens(tmpl *template.Template, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) { - rendered, err := Prompt(tmpl, system, prompt, response, false) - if err != nil { - return 0, err - } - - tokens, err := encode(rendered) - if err != nil { - slog.Error("failed to encode prompt", "err", err) - return 0, err - } - - return len(tokens), err -} - -// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size -func ChatPrompt(tmpl *template.Template, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) { - type prompt struct { - System string - Prompt string - Response string - - images []int - tokens int - } - - var p prompt - - // iterate through messages to build up {system,user,response} prompts - var imgId int - var prompts []prompt - for _, msg := range messages { - switch strings.ToLower(msg.Role) { - case "system": - if p.System != "" || p.Prompt != "" || p.Response != "" { - prompts = append(prompts, p) - p = prompt{} - } - - p.System = msg.Content - case "user": - if p.Prompt != "" || p.Response != "" { - prompts = append(prompts, p) - p = prompt{} - } - - var sb strings.Builder - for range msg.Images { - fmt.Fprintf(&sb, "[img-%d] ", imgId) - p.images = append(p.images, imgId) - imgId += 1 - } - - sb.WriteString(msg.Content) - p.Prompt = sb.String() - case "assistant": - if p.Response != "" { - prompts = append(prompts, p) - p = prompt{} - } - - p.Response = msg.Content - default: - return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role) - } - } - - // add final prompt - if p.System != "" || p.Prompt != "" || p.Response != "" { - prompts = append(prompts, p) - } - - // calculate token lengths for each prompt, estimating 768 tokens per images - for i, p := range prompts { - tokens, err := countTokens(tmpl, p.System, p.Prompt, p.Response, encode) + s, err := r.llama.Tokenize(ctx, b.String()) if err != nil { - return "", err + return "", nil, err } - prompts[i].tokens = tokens + len(prompts[i].images)*768 - } - - // truncate images and prompts starting from the beginning of the list - // until either one prompt remains or the total tokens fits the context window - // TODO (jmorganca): this doesn't account for the context window room required for the response - for { - var required int - for _, p := range prompts { - required += p.tokens + c := len(s) + if r.model.ProjectorPaths != nil { + for _, m := range msgs[i:] { + // TODO: get image embedding length from project metadata + c += 768 * len(m.Images) + } } - required += 1 // for bos token - - if required <= window { - slog.Debug("prompt now fits in context window", "required", required, "window", window) + if c > r.NumCtx { + slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:])) break + } else { + n = i } - - prompt := &prompts[0] - - if len(prompt.images) > 1 { - img := prompt.images[0] - slog.Debug("prompt longer than context window, removing image", "id", img, "required", required, "window", window) - prompt.images = prompt.images[1:] - prompt.Prompt = strings.Replace(prompt.Prompt, fmt.Sprintf(" [img-%d]", img), "", 1) - prompt.tokens -= 768 - continue - } - - if len(prompts) > 1 { - slog.Debug("required tokens longer than context window, removing first prompt", "prompt", prompts[0].tokens, "required", required, "window", window) - system := prompt.System - prompts = prompts[1:] - - if system != "" && prompts[0].System == "" { - prompts[0].System = system - - tokens, err := countTokens(tmpl, prompts[0].System, prompts[0].Prompt, prompts[0].Response, encode) - if err != nil { - return "", err - } - - prompts[0].tokens = tokens + len(prompts[0].images)*768 - } - - continue - } - - // stop truncating if there's only one prompt left - break } - var sb strings.Builder - for i, p := range prompts { - // last prompt should leave the response unrendered (for completion) - rendered, err := Prompt(tmpl, p.System, p.Prompt, p.Response, i == len(prompts)-1) - if err != nil { - return "", err - } - sb.WriteString(rendered) + var b bytes.Buffer + if err := r.model.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil { + return "", nil, err } - return sb.String(), nil + for _, m := range msgs[n:] { + for _, i := range m.Images { + images = append(images, llm.ImageData{ + ID: len(images), + Data: i, + }) + } + } + + return b.String(), images, nil } diff --git a/server/prompt_test.go b/server/prompt_test.go index 7df58d0b..59288b46 100644 --- a/server/prompt_test.go +++ b/server/prompt_test.go @@ -1,215 +1,214 @@ package server import ( + "bytes" + "context" "strings" "testing" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/llm" "github.com/ollama/ollama/template" ) -func TestPrompt(t *testing.T) { - tests := []struct { - name string - template string - system string - prompt string - response string - generate bool - want string - }{ - { - name: "simple prompt", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - system: "You are a Wizard.", - prompt: "What are the potion ingredients?", - want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]", - }, - { - name: "implicit response", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - system: "You are a Wizard.", - prompt: "What are the potion ingredients?", - response: "I don't know.", - want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]I don't know.", - }, - { - name: "response", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}", - system: "You are a Wizard.", - prompt: "What are the potion ingredients?", - response: "I don't know.", - want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.", - }, - { - name: "cut", - template: "{{ .System }}{{ .Prompt }}{{ .Response }}", - system: "You are a Wizard.", - prompt: "What are the potion ingredients?", - response: "I don't know.", - generate: true, - want: "You are a Wizard.What are the potion ingredients?I don't know.", - }, - { - name: "nocut", - template: "{{ .System }}{{ .Prompt }}{{ .Response }}", - system: "You are a Wizard.", - prompt: "What are the potion ingredients?", - response: "I don't know.", - want: "You are a Wizard.What are the potion ingredients?I don't know.", - }, +type mock struct { + llm.LlamaServer +} + +func (m mock) Tokenize(_ context.Context, s string) (tokens []int, err error) { + for range strings.Fields(s) { + tokens = append(tokens, len(tokens)) } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - tmpl, err := template.Parse(tc.template) - if err != nil { - t.Fatal(err) - } - - got, err := Prompt(tmpl, tc.system, tc.prompt, tc.response, tc.generate) - if err != nil { - t.Errorf("error = %v", err) - } - - if got != tc.want { - t.Errorf("got = %v, want %v", got, tc.want) - } - }) - } + return } func TestChatPrompt(t *testing.T) { - tests := []struct { - name string - template string - messages []api.Message - window int - want string + type expect struct { + prompt string + images [][]byte + } + + cases := []struct { + name string + limit int + msgs []api.Message + expect }{ { - name: "simple prompt", - template: "[INST] {{ .Prompt }} [/INST]", - messages: []api.Message{ - {Role: "user", Content: "Hello"}, + name: "messages", + limit: 64, + msgs: []api.Message{ + {Role: "user", Content: "You're a test, Harry!"}, + {Role: "assistant", Content: "I-I'm a what?"}, + {Role: "user", Content: "A test. And a thumping good one at that, I'd wager."}, + }, + expect: expect{ + prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ", }, - window: 1024, - want: "[INST] Hello [/INST]", }, { - name: "with system message", - template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST]", - messages: []api.Message{ - {Role: "system", Content: "You are a Wizard."}, - {Role: "user", Content: "Hello"}, + name: "truncate messages", + limit: 1, + msgs: []api.Message{ + {Role: "user", Content: "You're a test, Harry!"}, + {Role: "assistant", Content: "I-I'm a what?"}, + {Role: "user", Content: "A test. And a thumping good one at that, I'd wager."}, + }, + expect: expect{ + prompt: "A test. And a thumping good one at that, I'd wager. ", }, - window: 1024, - want: "[INST] <>You are a Wizard.<> Hello [/INST]", }, { - name: "with response", - template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST] {{ .Response }}", - messages: []api.Message{ - {Role: "system", Content: "You are a Wizard."}, - {Role: "user", Content: "Hello"}, - {Role: "assistant", Content: "I am?"}, + name: "truncate messages with image", + limit: 64, + msgs: []api.Message{ + {Role: "user", Content: "You're a test, Harry!"}, + {Role: "assistant", Content: "I-I'm a what?"}, + {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("something")}}, + }, + expect: expect{ + prompt: "[img-0] A test. And a thumping good one at that, I'd wager. ", + images: [][]byte{ + []byte("something"), + }, }, - window: 1024, - want: "[INST] <>You are a Wizard.<> Hello [/INST] I am?", }, { - name: "with implicit response", - template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST]", - messages: []api.Message{ - {Role: "system", Content: "You are a Wizard."}, - {Role: "user", Content: "Hello"}, - {Role: "assistant", Content: "I am?"}, + name: "truncate messages with images", + limit: 64, + msgs: []api.Message{ + {Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}}, + {Role: "assistant", Content: "I-I'm a what?"}, + {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}}, + }, + expect: expect{ + prompt: "[img-0] A test. And a thumping good one at that, I'd wager. ", + images: [][]byte{ + []byte("somethingelse"), + }, }, - window: 1024, - want: "[INST] <>You are a Wizard.<> Hello [/INST]I am?", }, { - name: "with conversation", - template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ", - messages: []api.Message{ - {Role: "system", Content: "You are a Wizard."}, - {Role: "user", Content: "What are the potion ingredients?"}, - {Role: "assistant", Content: "sugar"}, - {Role: "user", Content: "Anything else?"}, + name: "messages with images", + limit: 2048, + msgs: []api.Message{ + {Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}}, + {Role: "assistant", Content: "I-I'm a what?"}, + {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}}, + }, + expect: expect{ + prompt: "[img-0] You're a test, Harry! I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. ", + images: [][]byte{ + []byte("something"), + []byte("somethingelse"), + }, }, - window: 1024, - want: "[INST] <>You are a Wizard.<> What are the potion ingredients? [/INST] sugar [INST] Anything else? [/INST] ", }, { - name: "with truncation", - template: "{{ .System }} {{ .Prompt }} {{ .Response }} ", - messages: []api.Message{ - {Role: "system", Content: "You are a Wizard."}, - {Role: "user", Content: "Hello"}, - {Role: "assistant", Content: "I am?"}, - {Role: "user", Content: "Why is the sky blue?"}, - {Role: "assistant", Content: "The sky is blue from rayleigh scattering"}, + name: "message with image tag", + limit: 2048, + msgs: []api.Message{ + {Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}}, + {Role: "assistant", Content: "I-I'm a what?"}, + {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}}, + }, + expect: expect{ + prompt: "You're a test, Harry! [img-0] I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. ", + images: [][]byte{ + []byte("something"), + []byte("somethingelse"), + }, }, - window: 10, - want: "You are a Wizard. Why is the sky blue? The sky is blue from rayleigh scattering", }, { - name: "images", - template: "{{ .System }} {{ .Prompt }}", - messages: []api.Message{ - {Role: "system", Content: "You are a Wizard."}, - {Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}}, + name: "messages with interleaved images", + limit: 2048, + msgs: []api.Message{ + {Role: "user", Content: "You're a test, Harry!"}, + {Role: "user", Images: []api.ImageData{[]byte("something")}}, + {Role: "user", Images: []api.ImageData{[]byte("somethingelse")}}, + {Role: "assistant", Content: "I-I'm a what?"}, + {Role: "user", Content: "A test. And a thumping good one at that, I'd wager."}, + }, + expect: expect{ + prompt: "You're a test, Harry!\n\n[img-0]\n\n[img-1] I-I'm a what? A test. And a thumping good one at that, I'd wager. ", + images: [][]byte{ + []byte("something"), + []byte("somethingelse"), + }, }, - window: 1024, - want: "You are a Wizard. [img-0] Hello", }, { - name: "images truncated", - template: "{{ .System }} {{ .Prompt }}", - messages: []api.Message{ - {Role: "system", Content: "You are a Wizard."}, - {Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}}, + name: "truncate message with interleaved images", + limit: 1024, + msgs: []api.Message{ + {Role: "user", Content: "You're a test, Harry!"}, + {Role: "user", Images: []api.ImageData{[]byte("something")}}, + {Role: "user", Images: []api.ImageData{[]byte("somethingelse")}}, + {Role: "assistant", Content: "I-I'm a what?"}, + {Role: "user", Content: "A test. And a thumping good one at that, I'd wager."}, + }, + expect: expect{ + prompt: "[img-0] I-I'm a what? A test. And a thumping good one at that, I'd wager. ", + images: [][]byte{ + []byte("somethingelse"), + }, }, - window: 1024, - want: "You are a Wizard. [img-0] [img-1] Hello", }, { - name: "empty list", - template: "{{ .System }} {{ .Prompt }}", - messages: []api.Message{}, - window: 1024, - want: "", - }, - { - name: "empty prompt", - template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ", - messages: []api.Message{ - {Role: "user", Content: ""}, + name: "message with system prompt", + limit: 2048, + msgs: []api.Message{ + {Role: "system", Content: "You are the Test Who Lived."}, + {Role: "user", Content: "You're a test, Harry!"}, + {Role: "assistant", Content: "I-I'm a what?"}, + {Role: "user", Content: "A test. And a thumping good one at that, I'd wager."}, + }, + expect: expect{ + prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ", }, - window: 1024, - want: "", }, } - encode := func(s string) ([]int, error) { - words := strings.Fields(s) - return make([]int, len(words)), nil + tmpl, err := template.Parse(` +{{- if .System }}{{ .System }} {{ end }} +{{- if .Prompt }}{{ .Prompt }} {{ end }} +{{- if .Response }}{{ .Response }} {{ end }}`) + if err != nil { + t.Fatal(err) } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - tmpl, err := template.Parse(tc.template) + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + r := runnerRef{ + llama: mock{}, + model: &Model{Template: tmpl, ProjectorPaths: []string{"vision"}}, + Options: &api.Options{}, + } + + r.NumCtx = tt.limit + prompt, images, err := chatPrompt(context.TODO(), &r, tt.msgs) if err != nil { t.Fatal(err) } - got, err := ChatPrompt(tmpl, tc.messages, tc.window, encode) - if err != nil { - t.Errorf("error = %v", err) + if tt.prompt != prompt { + t.Errorf("expected %q, got %q", tt.prompt, prompt) } - if got != tc.want { - t.Errorf("got: %q, want: %q", got, tc.want) + if len(images) != len(tt.images) { + t.Fatalf("expected %d images, got %d", len(tt.images), len(images)) + } + + for i := range images { + if images[i].ID != i { + t.Errorf("expected ID %d, got %d", i, images[i].ID) + } + + if !bytes.Equal(images[i].Data, tt.images[i]) { + t.Errorf("expected %q, got %q", tt.images[i], images[i]) + } } }) } diff --git a/server/routes.go b/server/routes.go index ac6b713a..35e64511 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1,13 +1,13 @@ package server import ( + "bytes" "cmp" "context" "encoding/json" "errors" "fmt" "io" - "io/fs" "log/slog" "net" "net/http" @@ -67,163 +67,140 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options return opts, nil } -func isSupportedImageType(image []byte) bool { - contentType := http.DetectContentType(image) - allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"} - return slices.Contains(allowedTypes, contentType) +func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (*runnerRef, error) { + if name == "" { + return nil, errors.New("model is required") + } + + model, err := GetModel(name) + if err != nil { + return nil, err + } + + if err := model.CheckCapabilities(caps...); err != nil { + return nil, fmt.Errorf("%s %w", name, err) + } + + opts, err := modelOptions(model, requestOpts) + if err != nil { + return nil, err + } + + runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive) + var runner *runnerRef + select { + case runner = <-runnerCh: + case err = <-errCh: + return nil, err + } + + return runner, nil } func (s *Server) GenerateHandler(c *gin.Context) { - checkpointStart := time.Now() var req api.GenerateRequest - err := c.ShouldBindJSON(&req) - - switch { - case errors.Is(err, io.EOF): + if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) return - case err != nil: + } else if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - // validate the request - switch { - case req.Model == "": - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"}) + if req.Format != "" && req.Format != "json" { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""}) return - case len(req.Format) > 0 && req.Format != "json": - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"}) - return - case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0): + } else if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"}) return } - for _, img := range req.Images { - if !isSupportedImageType(img) { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"}) - return - } - } - - model, err := GetModel(req.Model) - if err != nil { - var pErr *fs.PathError - if errors.As(err, &pErr) { - c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)}) - return - } - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + caps := []Capability{CapabilityCompletion} + r, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) + if errors.Is(err, errCapabilityCompletion) { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)}) + return + } else if err != nil { + handleScheduleError(c, err) return } - if !model.Has(CapabilityCompletion) { - c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support generate", req.Model)}) - return + images := make([]llm.ImageData, len(req.Images)) + for i := range req.Images { + images[i] = llm.ImageData{ID: i, Data: req.Images[i]} } - opts, err := modelOptions(model, req.Options) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive) - var runner *runnerRef - select { - case runner = <-rCh: - case err = <-eCh: - handleErrorResponse(c, err) - return - } - - // an empty request loads the model - // note: for a short while template was used in lieu - // of `raw` mode so we need to check for it too - if req.Prompt == "" && req.Template == "" && req.System == "" { - c.JSON(http.StatusOK, api.GenerateResponse{ - CreatedAt: time.Now().UTC(), - Model: req.Model, - Done: true, - DoneReason: "load", - }) - return - } - - tmpl, err := template.Parse(req.Template) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - checkpointLoaded := time.Now() - - var prompt string - switch { - case req.Raw: - prompt = req.Prompt - case req.Prompt != "": - if req.Template == "" { - tmpl = model.Template + prompt := req.Prompt + if !req.Raw { + var msgs []api.Message + if req.System != "" { + msgs = append(msgs, api.Message{Role: "system", Content: req.System}) + } else if r.model.System != "" { + msgs = append(msgs, api.Message{Role: "system", Content: r.model.System}) } - if req.System == "" { - req.System = model.System + if req.Prompt != "" { + for _, i := range images { + msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)}) + } + + msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt}) } - slog.Debug("generate handler", "prompt", req.Prompt) - slog.Debug("generate handler", "template", req.Template) - slog.Debug("generate handler", "system", req.System) - - var sb strings.Builder - for i := range req.Images { - fmt.Fprintf(&sb, "[img-%d] ", i) - } - - sb.WriteString(req.Prompt) - - p, err := Prompt(tmpl, req.System, sb.String(), "", true) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + if len(msgs) == 0 { + c.JSON(http.StatusOK, api.GenerateResponse{ + Model: req.Model, + CreatedAt: time.Now().UTC(), + Done: true, + DoneReason: "load", + }) return } - sb.Reset() + tmpl := r.model.Template + if req.Template != "" { + tmpl, err = template.Parse(req.Template) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } + + var b bytes.Buffer if req.Context != nil { - prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context) + s, err := r.llama.Detokenize(c.Request.Context(), req.Context) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - sb.WriteString(prev) + b.WriteString(s) } - sb.WriteString(p) + if err := tmpl.Execute(&b, template.Values{Messages: msgs}); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } - prompt = sb.String() + prompt = b.String() } - slog.Debug("generate handler", "prompt", prompt) + slog.Debug("generate request", "prompt", prompt, "images", images) ch := make(chan any) - var generated strings.Builder go func() { defer close(ch) - - fn := func(r llm.CompletionResponse) { - // Build up the full response - if _, err := generated.WriteString(r.Content); err != nil { - ch <- gin.H{"error": err.Error()} - return - } - - resp := api.GenerateResponse{ + if err := r.llama.Completion(c.Request.Context(), llm.CompletionRequest{ + Prompt: prompt, + Images: images, + Format: req.Format, + Options: *r.Options, + }, func(r llm.CompletionResponse) { + ch <- api.GenerateResponse{ Model: req.Model, CreatedAt: time.Now().UTC(), - Done: r.Done, Response: r.Content, + Done: r.Done, DoneReason: r.DoneReason, Metrics: api.Metrics{ PromptEvalCount: r.PromptEvalCount, @@ -232,77 +209,35 @@ func (s *Server) GenerateHandler(c *gin.Context) { EvalDuration: r.EvalDuration, }, } - - if r.Done { - resp.TotalDuration = time.Since(checkpointStart) - resp.LoadDuration = checkpointLoaded.Sub(checkpointStart) - - if !req.Raw { - p, err := Prompt(tmpl, req.System, req.Prompt, generated.String(), false) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - // TODO (jmorganca): encode() should not strip special tokens - tokens, err := runner.llama.Tokenize(c.Request.Context(), p) - if err != nil { - ch <- gin.H{"error": err.Error()} - return - } - - resp.Context = append(req.Context, tokens...) - } - } - - ch <- resp - } - - var images []llm.ImageData - for i := range req.Images { - images = append(images, llm.ImageData{ - ID: i, - Data: req.Images[i], - }) - } - - // Start prediction - req := llm.CompletionRequest{ - Prompt: prompt, - Format: req.Format, - Images: images, - Options: opts, - } - if err := runner.llama.Completion(c.Request.Context(), req, fn); err != nil { + }); err != nil { ch <- gin.H{"error": err.Error()} } }() if req.Stream != nil && !*req.Stream { - // Accumulate responses into the final response - var final api.GenerateResponse + var r api.GenerateResponse var sb strings.Builder - for resp := range ch { - switch r := resp.(type) { + for rr := range ch { + switch t := rr.(type) { case api.GenerateResponse: - sb.WriteString(r.Response) - final = r + sb.WriteString(t.Response) + r = t case gin.H: - if errorMsg, ok := r["error"].(string); ok { - c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg}) - return - } else { - c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"}) - return + msg, ok := t["error"].(string) + if !ok { + msg = "unexpected error format in response" } + + c.JSON(http.StatusInternalServerError, gin.H{"error": msg}) + return default: - c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"}) + c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"}) return } } - final.Response = sb.String() - c.JSON(http.StatusOK, final) + r.Response = sb.String() + c.JSON(http.StatusOK, r) return } @@ -311,44 +246,17 @@ func (s *Server) GenerateHandler(c *gin.Context) { func (s *Server) EmbeddingsHandler(c *gin.Context) { var req api.EmbeddingRequest - err := c.ShouldBindJSON(&req) - switch { - case errors.Is(err, io.EOF): + if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) return - case err != nil: + } else if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - if req.Model == "" { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"}) - return - } - - model, err := GetModel(req.Model) + r, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive) if err != nil { - var pErr *fs.PathError - if errors.As(err, &pErr) { - c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)}) - return - } - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - opts, err := modelOptions(model, req.Options) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive) - var runner *runnerRef - select { - case runner = <-rCh: - case err = <-eCh: - handleErrorResponse(c, err) + handleScheduleError(c, err) return } @@ -358,17 +266,14 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { return } - embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt) + embedding, err := r.llama.Embedding(c.Request.Context(), req.Prompt) if err != nil { slog.Info(fmt.Sprintf("embedding generation failed: %v", err)) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"}) return } - resp := api.EmbeddingResponse{ - Embedding: embedding, - } - c.JSON(http.StatusOK, resp) + c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: embedding}) } func (s *Server) PullModelHandler(c *gin.Context) { @@ -649,9 +554,9 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { } } - msgs := make([]api.Message, 0) - for _, msg := range m.Messages { - msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content}) + msgs := make([]api.Message, len(m.Messages)) + for i, msg := range m.Messages { + msgs[i] = api.Message{Role: msg.Role, Content: msg.Content} } n := model.ParseName(req.Model) @@ -1214,132 +1119,55 @@ func (s *Server) ProcessHandler(c *gin.Context) { c.JSON(http.StatusOK, api.ProcessResponse{Models: models}) } -// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model -func chatPrompt(ctx context.Context, runner *runnerRef, template *template.Template, messages []api.Message, numCtx int) (string, error) { - encode := func(s string) ([]int, error) { - return runner.llama.Tokenize(ctx, s) - } - - prompt, err := ChatPrompt(template, messages, numCtx, encode) - if err != nil { - return "", err - } - - return prompt, nil -} - func (s *Server) ChatHandler(c *gin.Context) { - checkpointStart := time.Now() - var req api.ChatRequest - err := c.ShouldBindJSON(&req) - switch { - case errors.Is(err, io.EOF): + if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) return - case err != nil: + } else if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - // validate the request - switch { - case req.Model == "": - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"}) + caps := []Capability{CapabilityCompletion} + r, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) + if errors.Is(err, errCapabilityCompletion) { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)}) return - case len(req.Format) > 0 && req.Format != "json": - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"}) + } else if err != nil { + handleScheduleError(c, err) return } - model, err := GetModel(req.Model) - if err != nil { - var pErr *fs.PathError - if errors.As(err, &pErr) { - c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)}) - return - } - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - if !model.Has(CapabilityCompletion) { - c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support chat", req.Model)}) - return - } - - opts, err := modelOptions(model, req.Options) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive) - var runner *runnerRef - select { - case runner = <-rCh: - case err = <-eCh: - handleErrorResponse(c, err) - return - } - - checkpointLoaded := time.Now() - - // if the first message is not a system message, then add the model's default system message - if len(req.Messages) > 0 && req.Messages[0].Role != "system" { - req.Messages = append([]api.Message{ - { - Role: "system", - Content: model.System, - }, - }, req.Messages...) - } - - prompt, err := chatPrompt(c.Request.Context(), runner, model.Template, req.Messages, opts.NumCtx) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - // an empty request loads the model - if len(req.Messages) == 0 || prompt == "" { - resp := api.ChatResponse{ - CreatedAt: time.Now().UTC(), + if len(req.Messages) == 0 { + c.JSON(http.StatusOK, api.ChatResponse{ Model: req.Model, + CreatedAt: time.Now().UTC(), + Message: api.Message{Role: "assistant"}, Done: true, DoneReason: "load", - Message: api.Message{Role: "assistant"}, - } - c.JSON(http.StatusOK, resp) + }) return } - // only send images that are in the prompt - var i int - var images []llm.ImageData - for _, m := range req.Messages { - for _, img := range m.Images { - if !isSupportedImageType(img) { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"}) - return - } - - if strings.Contains(prompt, fmt.Sprintf("[img-%d]", i)) { - images = append(images, llm.ImageData{Data: img, ID: i}) - } - i += 1 - } + prompt, images, err := chatPrompt(c.Request.Context(), r, req.Messages) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return } - slog.Debug("chat handler", "prompt", prompt, "images", len(images)) + slog.Debug("chat request", "images", len(images), "prompt", prompt) ch := make(chan any) - go func() { defer close(ch) - - fn := func(r llm.CompletionResponse) { - resp := api.ChatResponse{ + if err := r.llama.Completion(c.Request.Context(), llm.CompletionRequest{ + Prompt: prompt, + Images: images, + Format: req.Format, + Options: *r.Options, + }, func(r llm.CompletionResponse) { + ch <- api.ChatResponse{ Model: req.Model, CreatedAt: time.Now().UTC(), Message: api.Message{Role: "assistant", Content: r.Content}, @@ -1352,64 +1180,48 @@ func (s *Server) ChatHandler(c *gin.Context) { EvalDuration: r.EvalDuration, }, } - - if r.Done { - resp.TotalDuration = time.Since(checkpointStart) - resp.LoadDuration = checkpointLoaded.Sub(checkpointStart) - } - - ch <- resp - } - - if err := runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{ - Prompt: prompt, - Format: req.Format, - Images: images, - Options: opts, - }, fn); err != nil { + }); err != nil { ch <- gin.H{"error": err.Error()} } }() if req.Stream != nil && !*req.Stream { - // Accumulate responses into the final response - var final api.ChatResponse + var r api.ChatResponse var sb strings.Builder - for resp := range ch { - switch r := resp.(type) { + for rr := range ch { + switch t := rr.(type) { case api.ChatResponse: - sb.WriteString(r.Message.Content) - final = r + sb.WriteString(t.Message.Content) + r = t case gin.H: - if errorMsg, ok := r["error"].(string); ok { - c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg}) - return - } else { - c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"}) - return + msg, ok := t["error"].(string) + if !ok { + msg = "unexpected error format in response" } + + c.JSON(http.StatusInternalServerError, gin.H{"error": msg}) + return default: - c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"}) + c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"}) return } } - final.Message = api.Message{Role: "assistant", Content: sb.String()} - c.JSON(http.StatusOK, final) + r.Message.Content = sb.String() + c.JSON(http.StatusOK, r) return } streamResponse(c, ch) } -func handleErrorResponse(c *gin.Context, err error) { - if errors.Is(err, context.Canceled) { +func handleScheduleError(c *gin.Context, err error) { + switch { + case errors.Is(err, context.Canceled): c.JSON(499, gin.H{"error": "request canceled"}) - return - } - if errors.Is(err, ErrMaxQueue) { + case errors.Is(err, ErrMaxQueue): c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()}) - return + default: + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) } - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) } diff --git a/template/template.go b/template/template.go index d15f7156..cfba5a23 100644 --- a/template/template.go +++ b/template/template.go @@ -5,6 +5,7 @@ import ( "embed" "encoding/json" "errors" + "fmt" "io" "math" "slices" @@ -14,6 +15,7 @@ import ( "text/template/parse" "github.com/agnivade/levenshtein" + "github.com/ollama/ollama/api" "golang.org/x/exp/maps" ) @@ -74,30 +76,78 @@ func Named(s string) (*named, error) { return nil, errors.New("no matching template found") } +var DefaultTemplate, _ = Parse("{{ .Prompt }}") + type Template struct { *template.Template raw string } +var response = parse.ActionNode{ + NodeType: parse.NodeAction, + Pipe: &parse.PipeNode{ + NodeType: parse.NodePipe, + Cmds: []*parse.CommandNode{ + { + NodeType: parse.NodeCommand, + Args: []parse.Node{ + &parse.FieldNode{ + NodeType: parse.NodeField, + Ident: []string{"Response"}, + }, + }, + }, + }, + }, +} + +func Parse(s string) (*Template, error) { + tmpl := template.New("").Option("missingkey=zero").Funcs(template.FuncMap{ + "toJson": func(v any) string { + b, err := json.Marshal(v) + if err != nil { + return "" + } + + return string(b) + }, + "isLastMessage": func(s []*api.Message, m *api.Message) bool { + for i := len(s) - 1; i >= 0; i-- { + if m.Role != s[i].Role { + continue + } + + return m == s[i] + } + + return false + }, + }) + + tmpl, err := tmpl.Parse(s) + if err != nil { + return nil, err + } + + t := Template{Template: tmpl, raw: s} + if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") { + // touch up the template and append {{ .Response }} + tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response) + } + + return &t, nil +} + func (t *Template) String() string { return t.raw } -var DefaultTemplate, _ = Parse("{{ .Prompt }}") - -func Parse(s string) (*Template, error) { - t, err := template.New("").Option("missingkey=zero").Parse(s) - if err != nil { - return nil, err - } - - return &Template{Template: t, raw: s}, nil -} - func (t *Template) Vars() []string { var vars []string - for _, n := range t.Tree.Root.Nodes { - vars = append(vars, parseNode(n)...) + for _, tt := range t.Templates() { + for _, n := range tt.Root.Nodes { + vars = append(vars, parseNode(n)...) + } } set := make(map[string]struct{}) @@ -110,6 +160,97 @@ func (t *Template) Vars() []string { return vars } +type Values struct { + Messages []api.Message +} + +func (t *Template) Execute(w io.Writer, v Values) error { + system, collated := collate(v.Messages) + if slices.Contains(t.Vars(), "messages") { + return t.Template.Execute(w, map[string]any{ + "System": system, + "Messages": collated, + }) + } + + var b bytes.Buffer + var prompt, response string + for i, m := range collated { + if m.Role == "user" { + prompt = m.Content + } else { + response = m.Content + } + + if i != len(collated)-1 && prompt != "" && response != "" { + if err := t.Template.Execute(&b, map[string]any{ + "System": "", + "Prompt": prompt, + "Response": response, + }); err != nil { + return err + } + + prompt = "" + response = "" + } + } + + var cut bool + tree := t.Template.Copy() + // for the last message, cut everything after "{{ .Response }}" + tree.Root.Nodes = slices.DeleteFunc(tree.Root.Nodes, func(n parse.Node) bool { + if slices.Contains(parseNode(n), "Response") { + cut = true + } + + return cut + }) + + if err := template.Must(template.New("").AddParseTree("", tree)).Execute(&b, map[string]any{ + "System": system, + "Prompt": prompt, + }); err != nil { + return err + } + + _, err := io.Copy(w, &b) + return err +} + +func collate(msgs []api.Message) (system string, collated []*api.Message) { + var n int + for i := range msgs { + msg := msgs[i] + if msg.Role == "system" { + if system != "" { + system += "\n\n" + } + + system += msg.Content + continue + } + + for range msg.Images { + imageTag := fmt.Sprintf("[img-%d]", n) + if !strings.Contains(msg.Content, "[img]") { + msg.Content = strings.TrimSpace("[img] " + msg.Content) + } + + msg.Content = strings.Replace(msg.Content, "[img]", imageTag, 1) + n++ + } + + if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role { + collated[len(collated)-1].Content += "\n\n" + msg.Content + } else { + collated = append(collated, &msg) + } + } + + return +} + func parseNode(n parse.Node) []string { switch n := n.(type) { case *parse.ActionNode: @@ -152,6 +293,8 @@ func parseNode(n parse.Node) []string { return names case *parse.FieldNode: return n.Ident + case *parse.TemplateNode: + return parseNode(n.Pipe) } return nil diff --git a/template/template_test.go b/template/template_test.go index eda4634f..5d5dad4b 100644 --- a/template/template_test.go +++ b/template/template_test.go @@ -11,6 +11,7 @@ import ( "testing" "text/template" + "github.com/ollama/ollama/api" "github.com/ollama/ollama/llm" ) @@ -64,13 +65,12 @@ func TestParse(t *testing.T) { template string vars []string }{ - {"{{ .Prompt }}", []string{"prompt"}}, - {"{{ .System }} {{ .Prompt }}", []string{"prompt", "system"}}, + {"{{ .Prompt }}", []string{"prompt", "response"}}, + {"{{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system"}}, {"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}}, - {"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "system", "tools"}}, + {"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}}, {"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}}, {"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}}, - {"{{ .Prompt }} {{ .Suffix }}", []string{"prompt", "suffix"}}, } for _, tt := range cases { @@ -87,3 +87,148 @@ func TestParse(t *testing.T) { }) } } + +func TestExecuteWithMessages(t *testing.T) { + cases := []struct { + templates []string + values Values + expected string + }{ + { + []string{ + `[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `, + `[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`, + `{{- range .Messages }} +{{- if eq .Role "user" }}[INST] {{ if and (isLastMessage $.Messages .) $.System }}{{ $.System }}{{ print "\n\n" }} +{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }} +{{- end }} +{{- end }}`, + }, + Values{ + Messages: []api.Message{ + {Role: "user", Content: "Hello friend!"}, + {Role: "assistant", Content: "Hello human!"}, + {Role: "user", Content: "Yay!"}, + }, + }, + `[INST] Hello friend![/INST] Hello human![INST] Yay![/INST] `, + }, + { + []string{ + `[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `, + `[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`, + ` +{{- range .Messages }} +{{- if eq .Role "user" }}[INST] {{ if and (isLastMessage $.Messages .) $.System }}{{ $.System }}{{ print "\n\n" }} +{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }} +{{- end }} +{{- end }}`, + }, + Values{ + Messages: []api.Message{ + {Role: "system", Content: "You are a helpful assistant!"}, + {Role: "user", Content: "Hello friend!"}, + {Role: "assistant", Content: "Hello human!"}, + {Role: "user", Content: "Yay!"}, + }, + }, + `[INST] Hello friend![/INST] Hello human![INST] You are a helpful assistant! + +Yay![/INST] `, + }, + { + []string{ + `{{ if .System }}<|im_start|>system +{{ .System }}<|im_end|> +{{ end }}{{ if .Prompt }}<|im_start|>user +{{ .Prompt }}<|im_end|> +{{ end }}<|im_start|>assistant +{{ .Response }}<|im_end|> +`, + ` +{{- range .Messages }} +{{- if and (eq .Role "user") (isLastMessage $.Messages .) $.System }}<|im_start|>system +{{ $.System }}<|im_end|>{{ print "\n" }} +{{- end }}<|im_start|>{{ .Role }} +{{ .Content }}<|im_end|>{{ print "\n" }} +{{- end }}<|im_start|>assistant +`, + }, + Values{ + Messages: []api.Message{ + {Role: "system", Content: "You are a helpful assistant!"}, + {Role: "user", Content: "Hello friend!"}, + {Role: "assistant", Content: "Hello human!"}, + {Role: "user", Content: "Yay!"}, + }, + }, + `<|im_start|>user +Hello friend!<|im_end|> +<|im_start|>assistant +Hello human!<|im_end|> +<|im_start|>system +You are a helpful assistant!<|im_end|> +<|im_start|>user +Yay!<|im_end|> +<|im_start|>assistant +`, + }, + { + []string{ + `{{ if .Prompt }}Question: {{ .Prompt }} + +{{ end }}Answer: {{ .Response }} + +`, + ` +{{- range .Messages }} +{{- if eq .Role "user" }}Question: {{ .Content }}{{ print "\n\n" }} +{{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ print "\n\n" }} +{{- end }} +{{- end }}Answer: `, + }, + Values{ + Messages: []api.Message{ + {Role: "user", Content: "What's in this image?", Images: []api.ImageData{[]byte("")}}, + {Role: "assistant", Content: "It's a hot dog."}, + {Role: "user", Content: "What's in _this_ image?"}, + {Role: "user", Images: []api.ImageData{[]byte("")}}, + {Role: "user", Content: "Is it a hot dog?"}, + }, + }, + `Question: [img-0] What's in this image? + +Answer: It's a hot dog. + +Question: What's in _this_ image? + +[img-1] + +Is it a hot dog? + +Answer: `, + }, + } + + for _, tt := range cases { + t.Run("", func(t *testing.T) { + for _, tmpl := range tt.templates { + t.Run("", func(t *testing.T) { + tmpl, err := Parse(tmpl) + if err != nil { + t.Fatal(err) + } + + var b bytes.Buffer + if err := tmpl.Execute(&b, tt.values); err != nil { + t.Fatal(err) + } + + if b.String() != tt.expected { + t.Errorf("expected\n%s,\ngot\n%s", tt.expected, b.String()) + } + }) + } + }) + } +} From 2c3fe1fd972b7810091120f844afc35bc98accbd Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 20 Jun 2024 11:00:08 -0700 Subject: [PATCH 2/4] comments --- server/prompt.go | 29 +++--- server/prompt_test.go | 34 +++---- server/routes.go | 46 +++++----- template/template.go | 48 +++++----- template/template_test.go | 180 ++++++++++++++++++++++++++++++-------- 5 files changed, 224 insertions(+), 113 deletions(-) diff --git a/server/prompt.go b/server/prompt.go index 5016fbe1..51d691a9 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -11,8 +11,13 @@ import ( "github.com/ollama/ollama/template" ) -func chatPrompt(ctx context.Context, r *runnerRef, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) { - // extract system messages which should always be included +type tokenizeFunc func(context.Context, string) ([]int, error) + +// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn. +// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the +// latest message and 2) system messages +func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) { + // pull out any system messages which should always be included in the prompt var system []api.Message msgs = slices.DeleteFunc(msgs, func(m api.Message) bool { if m.Role == "system" { @@ -23,32 +28,35 @@ func chatPrompt(ctx context.Context, r *runnerRef, msgs []api.Message) (prompt s return false }) - if len(system) == 0 && r.model.System != "" { + if len(system) == 0 && m.System != "" { // add model system prompt since it wasn't provided - system = append(system, api.Message{Role: "system", Content: r.model.System}) + system = append(system, api.Message{Role: "system", Content: m.System}) } + // always include the last message n := len(msgs) - 1 + // in reverse, find all messages that fit into context window for i := n - 1; i >= 0; i-- { var b bytes.Buffer - if err := r.model.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil { + if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil { return "", nil, err } - s, err := r.llama.Tokenize(ctx, b.String()) + s, err := tokenize(ctx, b.String()) if err != nil { return "", nil, err } c := len(s) - if r.model.ProjectorPaths != nil { + if m.ProjectorPaths != nil { for _, m := range msgs[i:] { - // TODO: get image embedding length from project metadata + // images are represented as 768 sized embeddings + // TODO: get embedding length from project metadata c += 768 * len(m.Images) } } - if c > r.NumCtx { + if c > opts.NumCtx { slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:])) break } else { @@ -56,8 +64,9 @@ func chatPrompt(ctx context.Context, r *runnerRef, msgs []api.Message) (prompt s } } + // truncate any messages that do not fit into the context window var b bytes.Buffer - if err := r.model.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil { + if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil { return "", nil, err } diff --git a/server/prompt_test.go b/server/prompt_test.go index 59288b46..d4cee98c 100644 --- a/server/prompt_test.go +++ b/server/prompt_test.go @@ -7,15 +7,10 @@ import ( "testing" "github.com/ollama/ollama/api" - "github.com/ollama/ollama/llm" "github.com/ollama/ollama/template" ) -type mock struct { - llm.LlamaServer -} - -func (m mock) Tokenize(_ context.Context, s string) (tokens []int, err error) { +func tokenize(_ context.Context, s string) (tokens []int, err error) { for range strings.Fields(s) { tokens = append(tokens, len(tokens)) } @@ -48,7 +43,7 @@ func TestChatPrompt(t *testing.T) { }, }, { - name: "truncate messages", + name: "truncate messages", limit: 1, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry!"}, @@ -60,7 +55,7 @@ func TestChatPrompt(t *testing.T) { }, }, { - name: "truncate messages with image", + name: "truncate messages with image", limit: 64, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry!"}, @@ -75,7 +70,7 @@ func TestChatPrompt(t *testing.T) { }, }, { - name: "truncate messages with images", + name: "truncate messages with images", limit: 64, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}}, @@ -90,7 +85,7 @@ func TestChatPrompt(t *testing.T) { }, }, { - name: "messages with images", + name: "messages with images", limit: 2048, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}}, @@ -106,7 +101,7 @@ func TestChatPrompt(t *testing.T) { }, }, { - name: "message with image tag", + name: "message with image tag", limit: 2048, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}}, @@ -122,7 +117,7 @@ func TestChatPrompt(t *testing.T) { }, }, { - name: "messages with interleaved images", + name: "messages with interleaved images", limit: 2048, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry!"}, @@ -140,7 +135,7 @@ func TestChatPrompt(t *testing.T) { }, }, { - name: "truncate message with interleaved images", + name: "truncate message with interleaved images", limit: 1024, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry!"}, @@ -157,7 +152,7 @@ func TestChatPrompt(t *testing.T) { }, }, { - name: "message with system prompt", + name: "message with system prompt", limit: 2048, msgs: []api.Message{ {Role: "system", Content: "You are the Test Who Lived."}, @@ -181,14 +176,9 @@ func TestChatPrompt(t *testing.T) { for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { - r := runnerRef{ - llama: mock{}, - model: &Model{Template: tmpl, ProjectorPaths: []string{"vision"}}, - Options: &api.Options{}, - } - - r.NumCtx = tt.limit - prompt, images, err := chatPrompt(context.TODO(), &r, tt.msgs) + model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}} + opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}} + prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs) if err != nil { t.Fatal(err) } diff --git a/server/routes.go b/server/routes.go index 35e64511..1a93e977 100644 --- a/server/routes.go +++ b/server/routes.go @@ -54,6 +54,8 @@ func init() { gin.SetMode(mode) } +var errRequired = errors.New("is required") + func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) { opts := api.DefaultOptions() if err := opts.FromMap(model.Options); err != nil { @@ -69,7 +71,7 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (*runnerRef, error) { if name == "" { - return nil, errors.New("model is required") + return nil, fmt.Errorf("model %w", errRequired) } model, err := GetModel(name) @@ -121,7 +123,17 @@ func (s *Server) GenerateHandler(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)}) return } else if err != nil { - handleScheduleError(c, err) + handleScheduleError(c, req.Model, err) + return + } + + if req.Prompt == "" { + c.JSON(http.StatusOK, api.GenerateResponse{ + Model: req.Model, + CreatedAt: time.Now().UTC(), + Done: true, + DoneReason: "load", + }) return } @@ -139,23 +151,11 @@ func (s *Server) GenerateHandler(c *gin.Context) { msgs = append(msgs, api.Message{Role: "system", Content: r.model.System}) } - if req.Prompt != "" { - for _, i := range images { - msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)}) - } - - msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt}) + for _, i := range images { + msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)}) } - if len(msgs) == 0 { - c.JSON(http.StatusOK, api.GenerateResponse{ - Model: req.Model, - CreatedAt: time.Now().UTC(), - Done: true, - DoneReason: "load", - }) - return - } + msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt}) tmpl := r.model.Template if req.Template != "" { @@ -256,7 +256,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { r, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive) if err != nil { - handleScheduleError(c, err) + handleScheduleError(c, req.Model, err) return } @@ -1135,7 +1135,7 @@ func (s *Server) ChatHandler(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)}) return } else if err != nil { - handleScheduleError(c, err) + handleScheduleError(c, req.Model, err) return } @@ -1150,7 +1150,7 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - prompt, images, err := chatPrompt(c.Request.Context(), r, req.Messages) + prompt, images, err := chatPrompt(c.Request.Context(), r.model, r.llama.Tokenize, r.Options, req.Messages) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -1215,12 +1215,16 @@ func (s *Server) ChatHandler(c *gin.Context) { streamResponse(c, ch) } -func handleScheduleError(c *gin.Context, err error) { +func handleScheduleError(c *gin.Context, name string, err error) { switch { + case errors.Is(err, errRequired): + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) case errors.Is(err, context.Canceled): c.JSON(499, gin.H{"error": "request canceled"}) case errors.Is(err, ErrMaxQueue): c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()}) + case errors.Is(err, os.ErrNotExist): + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found, try pulling it first", name)}) default: c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) } diff --git a/template/template.go b/template/template.go index cfba5a23..c8f8f6d0 100644 --- a/template/template.go +++ b/template/template.go @@ -83,6 +83,7 @@ type Template struct { raw string } +// response is a template node that can be added to templates that don't already have one var response = parse.ActionNode{ NodeType: parse.NodeAction, Pipe: &parse.PipeNode{ @@ -101,28 +102,25 @@ var response = parse.ActionNode{ }, } +var funcs = template.FuncMap{ + "toJson": func(v any) string { + b, err := json.Marshal(v) + if err != nil { + return "" + } + + return string(b) + }, + "add": func(a, b int) int { + return a + b + }, + "sub": func(a, b int) int { + return a - b + }, +} + func Parse(s string) (*Template, error) { - tmpl := template.New("").Option("missingkey=zero").Funcs(template.FuncMap{ - "toJson": func(v any) string { - b, err := json.Marshal(v) - if err != nil { - return "" - } - - return string(b) - }, - "isLastMessage": func(s []*api.Message, m *api.Message) bool { - for i := len(s) - 1; i >= 0; i-- { - if m.Role != s[i].Role { - continue - } - - return m == s[i] - } - - return false - }, - }) + tmpl := template.New("").Option("missingkey=zero").Funcs(funcs) tmpl, err := tmpl.Parse(s) if err != nil { @@ -218,7 +216,13 @@ func (t *Template) Execute(w io.Writer, v Values) error { return err } -func collate(msgs []api.Message) (system string, collated []*api.Message) { +type messages []*api.Message + +// collate messages based on role. consecutive messages of the same role are merged +// into a single message. collate also pulls out and merges messages with Role == "system" +// which are templated separately. As a side effect, it mangles message content adding image +// tags ([img-%d]) as needed +func collate(msgs []api.Message) (system string, collated messages) { var n int for i := range msgs { msg := msgs[i] diff --git a/template/template_test.go b/template/template_test.go index 5d5dad4b..ac92bf48 100644 --- a/template/template_test.go +++ b/template/template_test.go @@ -8,6 +8,7 @@ import ( "os" "path/filepath" "slices" + "strconv" "testing" "text/template" @@ -15,6 +16,98 @@ import ( "github.com/ollama/ollama/llm" ) +func TestFuncs(t *testing.T) { + t.Run("toJson", func(t *testing.T) { + cases := []struct { + input any + expected string + }{ + {nil, "null"}, + {true, "true"}, + {false, "false"}, + {0, "0"}, + {1, "1"}, + {1.0, "1"}, + {1.1, "1.1"}, + {"", `""`}, + {"hello", `"hello"`}, + {[]int{1, 2, 3}, "[1,2,3]"}, + {[]string{"a", "b", "c"}, `["a","b","c"]`}, + {map[string]int{"a": 1, "b": 2}, `{"a":1,"b":2}`}, + {map[string]string{"a": "b", "c": "d"}, `{"a":"b","c":"d"}`}, + } + + for _, tt := range cases { + t.Run(tt.expected, func(t *testing.T) { + toJson, ok := funcs["toJson"].(func(any) string) + if !ok { + t.Fatal("toJson is not a function") + } + + if s := toJson(tt.input); s != tt.expected { + t.Errorf("expected %q, got %q", tt.expected, s) + } + }) + } + }) + + t.Run("add", func(t *testing.T) { + cases := []struct { + a, b int + expected int + }{ + {0, 0, 0}, + {0, 1, 1}, + {1, 0, 1}, + {1, 1, 2}, + {1, -1, 0}, + {-1, 1, 0}, + {-1, -1, -2}, + } + + for _, tt := range cases { + t.Run(strconv.Itoa(tt.expected), func(t *testing.T) { + add, ok := funcs["add"].(func(int, int) int) + if !ok { + t.Fatal("add is not a function") + } + + if n := add(tt.a, tt.b); n != tt.expected { + t.Errorf("expected %d, got %d", tt.expected, n) + } + }) + } + }) + + t.Run("sub", func(t *testing.T) { + cases := []struct { + a, b int + expected int + }{ + {0, 0, 0}, + {0, 1, -1}, + {1, 0, 1}, + {1, 1, 0}, + {1, -1, 2}, + {-1, 1, -2}, + {-1, -1, 0}, + } + + for _, tt := range cases { + t.Run(strconv.Itoa(tt.expected), func(t *testing.T) { + sub, ok := funcs["sub"].(func(int, int) int) + if !ok { + t.Fatal("sub is not a function") + } + + if n := sub(tt.a, tt.b); n != tt.expected { + t.Errorf("expected %d, got %d", tt.expected, n) + } + }) + } + }) +} + func TestNamed(t *testing.T) { f, err := os.Open(filepath.Join("testdata", "templates.jsonl")) if err != nil { @@ -89,77 +182,86 @@ func TestParse(t *testing.T) { } func TestExecuteWithMessages(t *testing.T) { + type template struct { + name string + template string + } cases := []struct { - templates []string + name string + templates []template values Values expected string }{ { - []string{ - `[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `, - `[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`, - `{{- range .Messages }} -{{- if eq .Role "user" }}[INST] {{ if and (isLastMessage $.Messages .) $.System }}{{ $.System }}{{ print "\n\n" }} + "mistral", + []template{ + {"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `}, + {"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`}, + {"messages", `{{- range .Messages }} +{{- if eq .Role "user" }}[INST] {{ if and (eq (index $.Messages (sub (len $.Messages) 1)) .) $.System }}{{ $.System }}{{ "\n\n" }} {{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }} {{- end }} -{{- end }}`, +{{- end }}`}, }, Values{ Messages: []api.Message{ {Role: "user", Content: "Hello friend!"}, {Role: "assistant", Content: "Hello human!"}, - {Role: "user", Content: "Yay!"}, + {Role: "user", Content: "What is your name?"}, }, }, - `[INST] Hello friend![/INST] Hello human![INST] Yay![/INST] `, + `[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] `, }, { - []string{ - `[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `, - `[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`, - ` + "mistral system", + []template{ + {"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `}, + {"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`}, + {"messages", ` {{- range .Messages }} -{{- if eq .Role "user" }}[INST] {{ if and (isLastMessage $.Messages .) $.System }}{{ $.System }}{{ print "\n\n" }} +{{- if eq .Role "user" }}[INST] {{ if and (eq (index $.Messages (sub (len $.Messages) 1)) .) $.System }}{{ $.System }}{{ "\n\n" }} {{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }} {{- end }} -{{- end }}`, +{{- end }}`}, }, Values{ Messages: []api.Message{ {Role: "system", Content: "You are a helpful assistant!"}, {Role: "user", Content: "Hello friend!"}, {Role: "assistant", Content: "Hello human!"}, - {Role: "user", Content: "Yay!"}, + {Role: "user", Content: "What is your name?"}, }, }, `[INST] Hello friend![/INST] Hello human![INST] You are a helpful assistant! -Yay![/INST] `, +What is your name?[/INST] `, }, { - []string{ - `{{ if .System }}<|im_start|>system + "chatml", + []template{ + // this does not have a "no response" test because it's impossible to render the same output + {"response", `{{ if .System }}<|im_start|>system {{ .System }}<|im_end|> {{ end }}{{ if .Prompt }}<|im_start|>user {{ .Prompt }}<|im_end|> {{ end }}<|im_start|>assistant {{ .Response }}<|im_end|> -`, - ` +`}, + {"messages", ` {{- range .Messages }} -{{- if and (eq .Role "user") (isLastMessage $.Messages .) $.System }}<|im_start|>system -{{ $.System }}<|im_end|>{{ print "\n" }} +{{- if and (eq .Role "user") (eq (index $.Messages (sub (len $.Messages) 1)) .) $.System }}<|im_start|>system +{{ $.System }}<|im_end|>{{ "\n" }} {{- end }}<|im_start|>{{ .Role }} -{{ .Content }}<|im_end|>{{ print "\n" }} +{{ .Content }}<|im_end|>{{ "\n" }} {{- end }}<|im_start|>assistant -`, +`}, }, Values{ Messages: []api.Message{ {Role: "system", Content: "You are a helpful assistant!"}, {Role: "user", Content: "Hello friend!"}, {Role: "assistant", Content: "Hello human!"}, - {Role: "user", Content: "Yay!"}, + {Role: "user", Content: "What is your name?"}, }, }, `<|im_start|>user @@ -169,23 +271,25 @@ Hello human!<|im_end|> <|im_start|>system You are a helpful assistant!<|im_end|> <|im_start|>user -Yay!<|im_end|> +What is your name?<|im_end|> <|im_start|>assistant `, }, { - []string{ - `{{ if .Prompt }}Question: {{ .Prompt }} + "moondream", + []template{ + // this does not have a "no response" test because it's impossible to render the same output + {"response", `{{ if .Prompt }}Question: {{ .Prompt }} {{ end }}Answer: {{ .Response }} -`, - ` +`}, + {"messages", ` {{- range .Messages }} -{{- if eq .Role "user" }}Question: {{ .Content }}{{ print "\n\n" }} -{{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ print "\n\n" }} +{{- if eq .Role "user" }}Question: {{ .Content }}{{ "\n\n" }} +{{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ "\n\n" }} {{- end }} -{{- end }}Answer: `, +{{- end }}Answer: `}, }, Values{ Messages: []api.Message{ @@ -211,10 +315,10 @@ Answer: `, } for _, tt := range cases { - t.Run("", func(t *testing.T) { - for _, tmpl := range tt.templates { - t.Run("", func(t *testing.T) { - tmpl, err := Parse(tmpl) + t.Run(tt.name, func(t *testing.T) { + for _, ttt := range tt.templates { + t.Run(ttt.name, func(t *testing.T) { + tmpl, err := Parse(ttt.template) if err != nil { t.Fatal(err) } From ac7a842e550721fbc00e36e416e7cf6606993149 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 3 Jul 2024 09:00:07 -0700 Subject: [PATCH 3/4] fix model reloading ensure runtime model changes (template, system prompt, messages, options) are captured on model updates without needing to reload the server --- llm/server.go | 2 +- server/routes.go | 42 ++++++++++++++++++++++-------------------- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/llm/server.go b/llm/server.go index 206f9e39..229d61e4 100644 --- a/llm/server.go +++ b/llm/server.go @@ -679,7 +679,7 @@ type CompletionRequest struct { Prompt string Format string Images []ImageData - Options api.Options + Options *api.Options } type CompletionResponse struct { diff --git a/server/routes.go b/server/routes.go index 1a93e977..4059c7c5 100644 --- a/server/routes.go +++ b/server/routes.go @@ -69,23 +69,25 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options return opts, nil } -func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (*runnerRef, error) { +// scheduleRunner schedules a runner after validating inputs such as capabilities and model options. +// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise. +func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) { if name == "" { - return nil, fmt.Errorf("model %w", errRequired) + return nil, nil, nil, fmt.Errorf("model %w", errRequired) } model, err := GetModel(name) if err != nil { - return nil, err + return nil, nil, nil, err } if err := model.CheckCapabilities(caps...); err != nil { - return nil, fmt.Errorf("%s %w", name, err) + return nil, nil, nil, fmt.Errorf("%s %w", name, err) } opts, err := modelOptions(model, requestOpts) if err != nil { - return nil, err + return nil, nil, nil, err } runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive) @@ -93,10 +95,10 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil select { case runner = <-runnerCh: case err = <-errCh: - return nil, err + return nil, nil, nil, err } - return runner, nil + return runner.llama, model, &opts, nil } func (s *Server) GenerateHandler(c *gin.Context) { @@ -118,7 +120,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { } caps := []Capability{CapabilityCompletion} - r, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) + r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) if errors.Is(err, errCapabilityCompletion) { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)}) return @@ -147,8 +149,8 @@ func (s *Server) GenerateHandler(c *gin.Context) { var msgs []api.Message if req.System != "" { msgs = append(msgs, api.Message{Role: "system", Content: req.System}) - } else if r.model.System != "" { - msgs = append(msgs, api.Message{Role: "system", Content: r.model.System}) + } else if m.System != "" { + msgs = append(msgs, api.Message{Role: "system", Content: m.System}) } for _, i := range images { @@ -157,7 +159,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt}) - tmpl := r.model.Template + tmpl := m.Template if req.Template != "" { tmpl, err = template.Parse(req.Template) if err != nil { @@ -168,7 +170,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { var b bytes.Buffer if req.Context != nil { - s, err := r.llama.Detokenize(c.Request.Context(), req.Context) + s, err := r.Detokenize(c.Request.Context(), req.Context) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -190,11 +192,11 @@ func (s *Server) GenerateHandler(c *gin.Context) { ch := make(chan any) go func() { defer close(ch) - if err := r.llama.Completion(c.Request.Context(), llm.CompletionRequest{ + if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ Prompt: prompt, Images: images, Format: req.Format, - Options: *r.Options, + Options: opts, }, func(r llm.CompletionResponse) { ch <- api.GenerateResponse{ Model: req.Model, @@ -254,7 +256,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { return } - r, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive) + r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive) if err != nil { handleScheduleError(c, req.Model, err) return @@ -266,7 +268,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { return } - embedding, err := r.llama.Embedding(c.Request.Context(), req.Prompt) + embedding, err := r.Embedding(c.Request.Context(), req.Prompt) if err != nil { slog.Info(fmt.Sprintf("embedding generation failed: %v", err)) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"}) @@ -1130,7 +1132,7 @@ func (s *Server) ChatHandler(c *gin.Context) { } caps := []Capability{CapabilityCompletion} - r, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) + r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) if errors.Is(err, errCapabilityCompletion) { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)}) return @@ -1150,7 +1152,7 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - prompt, images, err := chatPrompt(c.Request.Context(), r.model, r.llama.Tokenize, r.Options, req.Messages) + prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -1161,11 +1163,11 @@ func (s *Server) ChatHandler(c *gin.Context) { ch := make(chan any) go func() { defer close(ch) - if err := r.llama.Completion(c.Request.Context(), llm.CompletionRequest{ + if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ Prompt: prompt, Images: images, Format: req.Format, - Options: *r.Options, + Options: opts, }, func(r llm.CompletionResponse) { ch <- api.ChatResponse{ Model: req.Model, From 326363b3a72d9e2972a019dfc4c6147ea901f501 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 3 Jul 2024 13:49:14 -0700 Subject: [PATCH 4/4] no funcs --- template/template.go | 19 +------ template/template_test.go | 105 +++----------------------------------- 2 files changed, 7 insertions(+), 117 deletions(-) diff --git a/template/template.go b/template/template.go index c8f8f6d0..b133b97e 100644 --- a/template/template.go +++ b/template/template.go @@ -102,25 +102,8 @@ var response = parse.ActionNode{ }, } -var funcs = template.FuncMap{ - "toJson": func(v any) string { - b, err := json.Marshal(v) - if err != nil { - return "" - } - - return string(b) - }, - "add": func(a, b int) int { - return a + b - }, - "sub": func(a, b int) int { - return a - b - }, -} - func Parse(s string) (*Template, error) { - tmpl := template.New("").Option("missingkey=zero").Funcs(funcs) + tmpl := template.New("").Option("missingkey=zero") tmpl, err := tmpl.Parse(s) if err != nil { diff --git a/template/template_test.go b/template/template_test.go index ac92bf48..ac16bd60 100644 --- a/template/template_test.go +++ b/template/template_test.go @@ -8,7 +8,6 @@ import ( "os" "path/filepath" "slices" - "strconv" "testing" "text/template" @@ -16,98 +15,6 @@ import ( "github.com/ollama/ollama/llm" ) -func TestFuncs(t *testing.T) { - t.Run("toJson", func(t *testing.T) { - cases := []struct { - input any - expected string - }{ - {nil, "null"}, - {true, "true"}, - {false, "false"}, - {0, "0"}, - {1, "1"}, - {1.0, "1"}, - {1.1, "1.1"}, - {"", `""`}, - {"hello", `"hello"`}, - {[]int{1, 2, 3}, "[1,2,3]"}, - {[]string{"a", "b", "c"}, `["a","b","c"]`}, - {map[string]int{"a": 1, "b": 2}, `{"a":1,"b":2}`}, - {map[string]string{"a": "b", "c": "d"}, `{"a":"b","c":"d"}`}, - } - - for _, tt := range cases { - t.Run(tt.expected, func(t *testing.T) { - toJson, ok := funcs["toJson"].(func(any) string) - if !ok { - t.Fatal("toJson is not a function") - } - - if s := toJson(tt.input); s != tt.expected { - t.Errorf("expected %q, got %q", tt.expected, s) - } - }) - } - }) - - t.Run("add", func(t *testing.T) { - cases := []struct { - a, b int - expected int - }{ - {0, 0, 0}, - {0, 1, 1}, - {1, 0, 1}, - {1, 1, 2}, - {1, -1, 0}, - {-1, 1, 0}, - {-1, -1, -2}, - } - - for _, tt := range cases { - t.Run(strconv.Itoa(tt.expected), func(t *testing.T) { - add, ok := funcs["add"].(func(int, int) int) - if !ok { - t.Fatal("add is not a function") - } - - if n := add(tt.a, tt.b); n != tt.expected { - t.Errorf("expected %d, got %d", tt.expected, n) - } - }) - } - }) - - t.Run("sub", func(t *testing.T) { - cases := []struct { - a, b int - expected int - }{ - {0, 0, 0}, - {0, 1, -1}, - {1, 0, 1}, - {1, 1, 0}, - {1, -1, 2}, - {-1, 1, -2}, - {-1, -1, 0}, - } - - for _, tt := range cases { - t.Run(strconv.Itoa(tt.expected), func(t *testing.T) { - sub, ok := funcs["sub"].(func(int, int) int) - if !ok { - t.Fatal("sub is not a function") - } - - if n := sub(tt.a, tt.b); n != tt.expected { - t.Errorf("expected %d, got %d", tt.expected, n) - } - }) - } - }) -} - func TestNamed(t *testing.T) { f, err := os.Open(filepath.Join("testdata", "templates.jsonl")) if err != nil { @@ -197,8 +104,8 @@ func TestExecuteWithMessages(t *testing.T) { []template{ {"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `}, {"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`}, - {"messages", `{{- range .Messages }} -{{- if eq .Role "user" }}[INST] {{ if and (eq (index $.Messages (sub (len $.Messages) 1)) .) $.System }}{{ $.System }}{{ "\n\n" }} + {"messages", `{{- range $index, $_ := .Messages }} +{{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }} {{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }} {{- end }} {{- end }}`}, @@ -218,8 +125,8 @@ func TestExecuteWithMessages(t *testing.T) { {"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `}, {"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`}, {"messages", ` -{{- range .Messages }} -{{- if eq .Role "user" }}[INST] {{ if and (eq (index $.Messages (sub (len $.Messages) 1)) .) $.System }}{{ $.System }}{{ "\n\n" }} +{{- range $index, $_ := .Messages }} +{{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }} {{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }} {{- end }} {{- end }}`}, @@ -248,8 +155,8 @@ What is your name?[/INST] `, {{ .Response }}<|im_end|> `}, {"messages", ` -{{- range .Messages }} -{{- if and (eq .Role "user") (eq (index $.Messages (sub (len $.Messages) 1)) .) $.System }}<|im_start|>system +{{- range $index, $_ := .Messages }} +{{- if and (eq .Role "user") (eq (len (slice $.Messages $index)) 1) $.System }}<|im_start|>system {{ $.System }}<|im_end|>{{ "\n" }} {{- end }}<|im_start|>{{ .Role }} {{ .Content }}<|im_end|>{{ "\n" }}