Merge pull request #5126 from ollama/mxyng/messages

update message processing

commit 9bbddc37a7

7 changed files with 681 additions and 713 deletions
@@ -679,7 +679,7 @@ type CompletionRequest struct {
 	Prompt string
 	Format string
 	Images []ImageData
-	Options api.Options
+	Options *api.Options
 }

 type CompletionResponse struct {
@@ -34,6 +34,8 @@ import (
 	"github.com/ollama/ollama/version"
 )

+var errCapabilityCompletion = errors.New("completion")
+
 type Capability string

 const CapabilityCompletion = Capability("completion")
@@ -62,7 +64,10 @@ type Model struct {
 	Template *template.Template
 }

-func (m *Model) Has(caps ...Capability) bool {
+// CheckCapabilities checks if the model has the specified capabilities returning an error describing
+// any missing or unknown capabilities
+func (m *Model) CheckCapabilities(caps ...Capability) error {
+	var errs []error
 	for _, cap := range caps {
 		switch cap {
 		case CapabilityCompletion:
@@ -81,15 +86,19 @@ func (m *Model) Has(caps ...Capability) bool {
 			}

 			if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
-				return false
+				errs = append(errs, errCapabilityCompletion)
 			}
 		default:
 			slog.Error("unknown capability", "capability", cap)
-			return false
+			return fmt.Errorf("unknown capability: %s", cap)
 		}
 	}

-	return true
+	if err := errors.Join(errs...); err != nil {
+		return fmt.Errorf("missing capabilities: %w", errors.Join(errs...))
+	}
+
+	return nil
 }

 func (m *Model) String() string {
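For orientation, a minimal sketch (not part of the diff; it assumes a *Model value m obtained elsewhere) of how a caller could use the new capability check, reacting to the errCapabilityCompletion sentinel with errors.Is:

```go
// Hypothetical caller of the API introduced above.
if err := m.CheckCapabilities(CapabilityCompletion); err != nil {
	if errors.Is(err, errCapabilityCompletion) {
		// embedding-only model: reject generate/chat requests up front
	}
	return err
}
```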
server/prompt.go (250)
@@ -1,217 +1,83 @@
 package server

 import (
-	"fmt"
+	"bytes"
+	"context"
 	"log/slog"
-	"strings"
+	"slices"

-	"text/template/parse"
-
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/template"
 )

-// isResponseNode checks if the node contains .Response
-func isResponseNode(node *parse.ActionNode) bool {
-	for _, cmd := range node.Pipe.Cmds {
-		for _, arg := range cmd.Args {
-			if fieldNode, ok := arg.(*parse.FieldNode); ok && len(fieldNode.Ident) > 0 {
-				if fieldNode.Ident[0] == "Response" {
-					return true
-				}
-			}
-		}
-	}
-
-	return false
-}
-
-// formatTemplateForResponse formats the template AST to:
-// 1. remove all nodes after the first .Response (if generate=true)
-// 2. add a .Response node to the end if it doesn't exist
-// TODO(jmorganca): this should recursively cut the template before the first .Response
-func formatTemplateForResponse(tmpl *template.Template, generate bool) {
-	var found bool
-	for i, node := range tmpl.Tree.Root.Nodes {
-		if actionNode, ok := node.(*parse.ActionNode); ok {
-			if isResponseNode(actionNode) {
-				found = true
-				if generate {
-					tmpl.Tree.Root.Nodes = tmpl.Tree.Root.Nodes[:i+1]
-					break
-				}
-			}
-		}
-	}
-
-	if !found {
-		// add the response node if it doesn't exist
-		responseFieldNode := &parse.FieldNode{NodeType: parse.NodeField, Ident: []string{"Response"}}
-		responsePipeNode := &parse.PipeNode{NodeType: parse.NodePipe, Cmds: []*parse.CommandNode{{NodeType: parse.NodeCommand, Args: []parse.Node{responseFieldNode}}}}
-		responseActionNode := &parse.ActionNode{NodeType: parse.NodeAction, Pipe: responsePipeNode}
-		tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, responseActionNode)
-	}
-}
-
-// Prompt renders a prompt from a template. If generate is set to true,
-// the response and parts of the template following it are not rendered
-func Prompt(tmpl *template.Template, system, prompt, response string, generate bool) (string, error) {
-	formatTemplateForResponse(tmpl, generate)
-
-	vars := map[string]any{
-		"System":   system,
-		"Prompt":   prompt,
-		"Response": response,
-	}
-
-	var sb strings.Builder
-	if err := tmpl.Execute(&sb, vars); err != nil {
-		return "", err
-	}
-
-	return sb.String(), nil
-}
-
-func countTokens(tmpl *template.Template, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
-	rendered, err := Prompt(tmpl, system, prompt, response, false)
-	if err != nil {
-		return 0, err
-	}
-
-	tokens, err := encode(rendered)
-	if err != nil {
-		slog.Error("failed to encode prompt", "err", err)
-		return 0, err
-	}
-
-	return len(tokens), err
-}
-
-// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
-func ChatPrompt(tmpl *template.Template, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
-	type prompt struct {
-		System   string
-		Prompt   string
-		Response string
-
-		images []int
-		tokens int
-	}
-
-	var p prompt
-
-	// iterate through messages to build up {system,user,response} prompts
-	var imgId int
-	var prompts []prompt
-	for _, msg := range messages {
-		switch strings.ToLower(msg.Role) {
-		case "system":
-			if p.System != "" || p.Prompt != "" || p.Response != "" {
-				prompts = append(prompts, p)
-				p = prompt{}
-			}
-
-			p.System = msg.Content
-		case "user":
-			if p.Prompt != "" || p.Response != "" {
-				prompts = append(prompts, p)
-				p = prompt{}
-			}
-
-			var sb strings.Builder
-			for range msg.Images {
-				fmt.Fprintf(&sb, "[img-%d] ", imgId)
-				p.images = append(p.images, imgId)
-				imgId += 1
-			}
-
-			sb.WriteString(msg.Content)
-			p.Prompt = sb.String()
-		case "assistant":
-			if p.Response != "" {
-				prompts = append(prompts, p)
-				p = prompt{}
-			}
-
-			p.Response = msg.Content
-		default:
-			return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
-		}
-	}
-
-	// add final prompt
-	if p.System != "" || p.Prompt != "" || p.Response != "" {
-		prompts = append(prompts, p)
-	}
-
-	// calculate token lengths for each prompt, estimating 768 tokens per images
-	for i, p := range prompts {
-		tokens, err := countTokens(tmpl, p.System, p.Prompt, p.Response, encode)
-		if err != nil {
-			return "", err
-		}
-
-		prompts[i].tokens = tokens + len(prompts[i].images)*768
-	}
-
-	// truncate images and prompts starting from the beginning of the list
-	// until either one prompt remains or the total tokens fits the context window
-	// TODO (jmorganca): this doesn't account for the context window room required for the response
-	for {
-		var required int
-		for _, p := range prompts {
-			required += p.tokens
-		}
-
-		required += 1 // for bos token
-
-		if required <= window {
-			slog.Debug("prompt now fits in context window", "required", required, "window", window)
-			break
-		}
-
-		prompt := &prompts[0]
-
-		if len(prompt.images) > 1 {
-			img := prompt.images[0]
-			slog.Debug("prompt longer than context window, removing image", "id", img, "required", required, "window", window)
-			prompt.images = prompt.images[1:]
-			prompt.Prompt = strings.Replace(prompt.Prompt, fmt.Sprintf(" [img-%d]", img), "", 1)
-			prompt.tokens -= 768
-			continue
-		}
-
-		if len(prompts) > 1 {
-			slog.Debug("required tokens longer than context window, removing first prompt", "prompt", prompts[0].tokens, "required", required, "window", window)
-			system := prompt.System
-			prompts = prompts[1:]
-
-			if system != "" && prompts[0].System == "" {
-				prompts[0].System = system
-
-				tokens, err := countTokens(tmpl, prompts[0].System, prompts[0].Prompt, prompts[0].Response, encode)
-				if err != nil {
-					return "", err
-				}
-
-				prompts[0].tokens = tokens + len(prompts[0].images)*768
-			}
-
-			continue
-		}
-
-		// stop truncating if there's only one prompt left
-		break
-	}
-
-	var sb strings.Builder
-	for i, p := range prompts {
-		// last prompt should leave the response unrendered (for completion)
-		rendered, err := Prompt(tmpl, p.System, p.Prompt, p.Response, i == len(prompts)-1)
-		if err != nil {
-			return "", err
-		}
-		sb.WriteString(rendered)
-	}
-
-	return sb.String(), nil
-}
+type tokenizeFunc func(context.Context, string) ([]int, error)
+
+// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
+// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
+// latest message and 2) system messages
+func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
+	// pull out any system messages which should always be included in the prompt
+	var system []api.Message
+	msgs = slices.DeleteFunc(msgs, func(m api.Message) bool {
+		if m.Role == "system" {
+			system = append(system, m)
+			return true
+		}
+
+		return false
+	})
+
+	if len(system) == 0 && m.System != "" {
+		// add model system prompt since it wasn't provided
+		system = append(system, api.Message{Role: "system", Content: m.System})
+	}
+
+	// always include the last message
+	n := len(msgs) - 1
+	// in reverse, find all messages that fit into context window
+	for i := n - 1; i >= 0; i-- {
+		var b bytes.Buffer
+		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
+			return "", nil, err
+		}
+
+		s, err := tokenize(ctx, b.String())
+		if err != nil {
+			return "", nil, err
+		}
+
+		c := len(s)
+		if m.ProjectorPaths != nil {
+			for _, m := range msgs[i:] {
+				// images are represented as 768 sized embeddings
+				// TODO: get embedding length from project metadata
+				c += 768 * len(m.Images)
+			}
+		}
+
+		if c > opts.NumCtx {
+			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
+			break
+		} else {
+			n = i
+		}
+	}
+
+	// truncate any messages that do not fit into the context window
+	var b bytes.Buffer
+	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
+		return "", nil, err
+	}
+
+	for _, m := range msgs[n:] {
+		for _, i := range m.Images {
+			images = append(images, llm.ImageData{
+				ID:   len(images),
+				Data: i,
+			})
+		}
+	}
+
+	return b.String(), images, nil
+}
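A rough usage sketch of the new chatPrompt entry point (a fragment inside the server package, not from the diff; it assumes a scheduled runner r whose Tokenize method satisfies tokenizeFunc, plus handler-scoped model, opts, and req values):

```go
// Hypothetical call site; truncation is driven by opts.NumCtx.
prompt, images, err := chatPrompt(c.Request.Context(), model, r.Tokenize, opts, req.Messages)
if err != nil {
	return err
}
// prompt is rendered from the newest messages that still fit in the context
// window (system messages are always kept); images holds only the image data
// referenced by that rendered prompt, with IDs matching the [img-N] tags.
_ = images
```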
@@ -1,6 +1,8 @@
 package server

 import (
+	"bytes"
+	"context"
 	"strings"
 	"testing"

@@ -8,208 +10,195 @@ import (
 	"github.com/ollama/ollama/template"
 )

-func TestPrompt(t *testing.T) {
-	tests := []struct {
-		name     string
-		template string
-		system   string
-		prompt   string
-		response string
-		generate bool
-		want     string
-	}{
-		{
-			name:     "simple prompt",
-			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
-			system:   "You are a Wizard.",
-			prompt:   "What are the potion ingredients?",
-			want:     "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
-		},
-		{
-			name:     "implicit response",
-			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
-			system:   "You are a Wizard.",
-			prompt:   "What are the potion ingredients?",
-			response: "I don't know.",
-			want:     "[INST] You are a Wizard. What are the potion ingredients? [/INST]I don't know.",
-		},
-		{
-			name:     "response",
-			template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
-			system:   "You are a Wizard.",
-			prompt:   "What are the potion ingredients?",
-			response: "I don't know.",
-			want:     "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
-		},
-		{
-			name:     "cut",
-			template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
-			system:   "You are a Wizard.",
-			prompt:   "What are the potion ingredients?",
-			response: "I don't know.",
-			generate: true,
-			want:     "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.",
-		},
-		{
-			name:     "nocut",
-			template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
-			system:   "You are a Wizard.",
-			prompt:   "What are the potion ingredients?",
-			response: "I don't know.",
-			want:     "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.</assistant>",
-		},
-	}
-
-	for _, tc := range tests {
-		t.Run(tc.name, func(t *testing.T) {
-			tmpl, err := template.Parse(tc.template)
-			if err != nil {
-				t.Fatal(err)
-			}
-
-			got, err := Prompt(tmpl, tc.system, tc.prompt, tc.response, tc.generate)
-			if err != nil {
-				t.Errorf("error = %v", err)
-			}
-
-			if got != tc.want {
-				t.Errorf("got = %v, want %v", got, tc.want)
-			}
-		})
-	}
-}
+func tokenize(_ context.Context, s string) (tokens []int, err error) {
+	for range strings.Fields(s) {
+		tokens = append(tokens, len(tokens))
+	}
+
+	return
+}

 func TestChatPrompt(t *testing.T) {
-	tests := []struct {
-		name     string
-		template string
-		messages []api.Message
-		window   int
-		want     string
-	}{
-		{
-			name:     "simple prompt",
-			template: "[INST] {{ .Prompt }} [/INST]",
-			messages: []api.Message{
-				{Role: "user", Content: "Hello"},
-			},
-			window: 1024,
-			want:   "[INST] Hello [/INST]",
-		},
-		{
-			name:     "with system message",
-			template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
-			messages: []api.Message{
-				{Role: "system", Content: "You are a Wizard."},
-				{Role: "user", Content: "Hello"},
-			},
-			window: 1024,
-			want:   "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
-		},
-		{
-			name:     "with response",
-			template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }}",
-			messages: []api.Message{
-				{Role: "system", Content: "You are a Wizard."},
-				{Role: "user", Content: "Hello"},
-				{Role: "assistant", Content: "I am?"},
-			},
-			window: 1024,
-			want:   "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST] I am?",
-		},
-		{
-			name:     "with implicit response",
-			template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
-			messages: []api.Message{
-				{Role: "system", Content: "You are a Wizard."},
-				{Role: "user", Content: "Hello"},
-				{Role: "assistant", Content: "I am?"},
-			},
-			window: 1024,
-			want:   "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]I am?",
-		},
-		{
-			name:     "with conversation",
-			template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
-			messages: []api.Message{
-				{Role: "system", Content: "You are a Wizard."},
-				{Role: "user", Content: "What are the potion ingredients?"},
-				{Role: "assistant", Content: "sugar"},
-				{Role: "user", Content: "Anything else?"},
-			},
-			window: 1024,
-			want:   "[INST] <<SYS>>You are a Wizard.<</SYS>> What are the potion ingredients? [/INST] sugar [INST] Anything else? [/INST] ",
-		},
-		{
-			name:     "with truncation",
-			template: "{{ .System }} {{ .Prompt }} {{ .Response }} ",
-			messages: []api.Message{
-				{Role: "system", Content: "You are a Wizard."},
-				{Role: "user", Content: "Hello"},
-				{Role: "assistant", Content: "I am?"},
-				{Role: "user", Content: "Why is the sky blue?"},
-				{Role: "assistant", Content: "The sky is blue from rayleigh scattering"},
-			},
-			window: 10,
-			want:   "You are a Wizard. Why is the sky blue? The sky is blue from rayleigh scattering",
-		},
-		{
-			name:     "images",
-			template: "{{ .System }} {{ .Prompt }}",
-			messages: []api.Message{
-				{Role: "system", Content: "You are a Wizard."},
-				{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}},
-			},
-			window: 1024,
-			want:   "You are a Wizard. [img-0] Hello",
-		},
-		{
-			name:     "images truncated",
-			template: "{{ .System }} {{ .Prompt }}",
-			messages: []api.Message{
-				{Role: "system", Content: "You are a Wizard."},
-				{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}},
-			},
-			window: 1024,
-			want:   "You are a Wizard. [img-0] [img-1] Hello",
-		},
-		{
-			name:     "empty list",
-			template: "{{ .System }} {{ .Prompt }}",
-			messages: []api.Message{},
-			window:   1024,
-			want:     "",
-		},
-		{
-			name:     "empty prompt",
-			template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
-			messages: []api.Message{
-				{Role: "user", Content: ""},
-			},
-			window: 1024,
-			want:   "",
-		},
-	}
-
-	encode := func(s string) ([]int, error) {
-		words := strings.Fields(s)
-		return make([]int, len(words)), nil
-	}
-
-	for _, tc := range tests {
-		t.Run(tc.name, func(t *testing.T) {
-			tmpl, err := template.Parse(tc.template)
-			if err != nil {
-				t.Fatal(err)
-			}
-
-			got, err := ChatPrompt(tmpl, tc.messages, tc.window, encode)
-			if err != nil {
-				t.Errorf("error = %v", err)
-			}
-
-			if got != tc.want {
-				t.Errorf("got: %q, want: %q", got, tc.want)
-			}
-		})
-	}
-}
+	type expect struct {
+		prompt string
+		images [][]byte
+	}
+
+	cases := []struct {
+		name  string
+		limit int
+		msgs  []api.Message
+		expect
+	}{
+		{
+			name:  "messages",
+			limit: 64,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!"},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
+			},
+			expect: expect{
+				prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
+			},
+		},
+		{
+			name:  "truncate messages",
+			limit: 1,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!"},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
+			},
+			expect: expect{
+				prompt: "A test. And a thumping good one at that, I'd wager. ",
+			},
+		},
+		{
+			name:  "truncate messages with image",
+			limit: 64,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!"},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("something")}},
+			},
+			expect: expect{
+				prompt: "[img-0] A test. And a thumping good one at that, I'd wager. ",
+				images: [][]byte{
+					[]byte("something"),
+				},
+			},
+		},
+		{
+			name:  "truncate messages with images",
+			limit: 64,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
+			},
+			expect: expect{
+				prompt: "[img-0] A test. And a thumping good one at that, I'd wager. ",
+				images: [][]byte{
+					[]byte("somethingelse"),
+				},
+			},
+		},
+		{
+			name:  "messages with images",
+			limit: 2048,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
+			},
+			expect: expect{
+				prompt: "[img-0] You're a test, Harry! I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. ",
+				images: [][]byte{
+					[]byte("something"),
+					[]byte("somethingelse"),
+				},
+			},
+		},
+		{
+			name:  "message with image tag",
+			limit: 2048,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
+			},
+			expect: expect{
+				prompt: "You're a test, Harry! [img-0] I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. ",
+				images: [][]byte{
+					[]byte("something"),
+					[]byte("somethingelse"),
+				},
+			},
+		},
+		{
+			name:  "messages with interleaved images",
+			limit: 2048,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!"},
+				{Role: "user", Images: []api.ImageData{[]byte("something")}},
+				{Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
+			},
+			expect: expect{
+				prompt: "You're a test, Harry!\n\n[img-0]\n\n[img-1] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
+				images: [][]byte{
+					[]byte("something"),
+					[]byte("somethingelse"),
+				},
+			},
+		},
+		{
+			name:  "truncate message with interleaved images",
+			limit: 1024,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!"},
+				{Role: "user", Images: []api.ImageData{[]byte("something")}},
+				{Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
+			},
+			expect: expect{
+				prompt: "[img-0] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
+				images: [][]byte{
+					[]byte("somethingelse"),
+				},
+			},
+		},
+		{
+			name:  "message with system prompt",
+			limit: 2048,
+			msgs: []api.Message{
+				{Role: "system", Content: "You are the Test Who Lived."},
+				{Role: "user", Content: "You're a test, Harry!"},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
+			},
+			expect: expect{
+				prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
+			},
+		},
+	}
+
+	tmpl, err := template.Parse(`
+{{- if .System }}{{ .System }} {{ end }}
+{{- if .Prompt }}{{ .Prompt }} {{ end }}
+{{- if .Response }}{{ .Response }} {{ end }}`)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	for _, tt := range cases {
+		t.Run(tt.name, func(t *testing.T) {
+			model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
+			opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
+			prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			if tt.prompt != prompt {
+				t.Errorf("expected %q, got %q", tt.prompt, prompt)
+			}
+
+			if len(images) != len(tt.images) {
+				t.Fatalf("expected %d images, got %d", len(tt.images), len(images))
+			}
+
+			for i := range images {
+				if images[i].ID != i {
+					t.Errorf("expected ID %d, got %d", i, images[i].ID)
+				}
+
+				if !bytes.Equal(images[i].Data, tt.images[i]) {
+					t.Errorf("expected %q, got %q", tt.images[i], images[i])
+				}
+			}
+		})
+	}
+}
server/routes.go (498)
@@ -1,13 +1,13 @@
 package server

 import (
+	"bytes"
 	"cmp"
 	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
-	"io/fs"
 	"log/slog"
 	"net"
 	"net/http"
@@ -54,6 +54,8 @@ func init() {
 	gin.SetMode(mode)
 }

+var errRequired = errors.New("is required")
+
 func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
 	opts := api.DefaultOptions()
 	if err := opts.FromMap(model.Options); err != nil {
@@ -67,163 +69,140 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
 	return opts, nil
 }

-func isSupportedImageType(image []byte) bool {
-	contentType := http.DetectContentType(image)
-	allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"}
-	return slices.Contains(allowedTypes, contentType)
+// scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
+// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
+func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
+	if name == "" {
+		return nil, nil, nil, fmt.Errorf("model %w", errRequired)
+	}
+
+	model, err := GetModel(name)
+	if err != nil {
+		return nil, nil, nil, err
+	}
+
+	if err := model.CheckCapabilities(caps...); err != nil {
+		return nil, nil, nil, fmt.Errorf("%s %w", name, err)
+	}
+
+	opts, err := modelOptions(model, requestOpts)
+	if err != nil {
+		return nil, nil, nil, err
+	}
+
+	runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
+	var runner *runnerRef
+	select {
+	case runner = <-runnerCh:
+	case err = <-errCh:
+		return nil, nil, nil, err
+	}
+
+	return runner.llama, model, &opts, nil
 }

 func (s *Server) GenerateHandler(c *gin.Context) {
-	checkpointStart := time.Now()
-
 	var req api.GenerateRequest
-	err := c.ShouldBindJSON(&req)
-
-	switch {
-	case errors.Is(err, io.EOF):
+	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
 		return
-	case err != nil:
+	} else if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}

-	// validate the request
-	switch {
-	case req.Model == "":
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
-		return
-	case len(req.Format) > 0 && req.Format != "json":
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
+	if req.Format != "" && req.Format != "json" {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
 		return
-	case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0):
+	} else if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
 		return
 	}

-	for _, img := range req.Images {
-		if !isSupportedImageType(img) {
-			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
-			return
-		}
-	}
-
-	model, err := GetModel(req.Model)
-	if err != nil {
-		var pErr *fs.PathError
-		if errors.As(err, &pErr) {
-			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-			return
-		}
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	if !model.Has(CapabilityCompletion) {
-		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support generate", req.Model)})
-		return
-	}
-
-	opts, err := modelOptions(model, req.Options)
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
-	var runner *runnerRef
-	select {
-	case runner = <-rCh:
-	case err = <-eCh:
-		handleErrorResponse(c, err)
+	caps := []Capability{CapabilityCompletion}
+	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
+	if errors.Is(err, errCapabilityCompletion) {
+		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
+		return
+	} else if err != nil {
+		handleScheduleError(c, req.Model, err)
 		return
 	}

-	// an empty request loads the model
-	// note: for a short while template was used in lieu
-	// of `raw` mode so we need to check for it too
-	if req.Prompt == "" && req.Template == "" && req.System == "" {
+	if req.Prompt == "" {
 		c.JSON(http.StatusOK, api.GenerateResponse{
-			CreatedAt:  time.Now().UTC(),
 			Model:      req.Model,
+			CreatedAt:  time.Now().UTC(),
 			Done:       true,
 			DoneReason: "load",
 		})
 		return
 	}

-	tmpl, err := template.Parse(req.Template)
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
+	images := make([]llm.ImageData, len(req.Images))
+	for i := range req.Images {
+		images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
 	}

-	checkpointLoaded := time.Now()
-
-	var prompt string
-	switch {
-	case req.Raw:
-		prompt = req.Prompt
-	case req.Prompt != "":
-		if req.Template == "" {
-			tmpl = model.Template
+	prompt := req.Prompt
+	if !req.Raw {
+		var msgs []api.Message
+		if req.System != "" {
+			msgs = append(msgs, api.Message{Role: "system", Content: req.System})
+		} else if m.System != "" {
+			msgs = append(msgs, api.Message{Role: "system", Content: m.System})
 		}

-		if req.System == "" {
-			req.System = model.System
+		for _, i := range images {
+			msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
 		}

-		slog.Debug("generate handler", "prompt", req.Prompt)
-		slog.Debug("generate handler", "template", req.Template)
-		slog.Debug("generate handler", "system", req.System)
-
-		var sb strings.Builder
-		for i := range req.Images {
-			fmt.Fprintf(&sb, "[img-%d] ", i)
+		msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
+
+		tmpl := m.Template
+		if req.Template != "" {
+			tmpl, err = template.Parse(req.Template)
+			if err != nil {
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+				return
+			}
 		}

-		sb.WriteString(req.Prompt)
-
-		p, err := Prompt(tmpl, req.System, sb.String(), "", true)
-		if err != nil {
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-			return
-		}
-
-		sb.Reset()
+		var b bytes.Buffer
 		if req.Context != nil {
-			prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context)
+			s, err := r.Detokenize(c.Request.Context(), req.Context)
 			if err != nil {
 				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 				return
 			}

-			sb.WriteString(prev)
+			b.WriteString(s)
 		}

-		sb.WriteString(p)
-
-		prompt = sb.String()
+		if err := tmpl.Execute(&b, template.Values{Messages: msgs}); err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+
+		prompt = b.String()
 	}

-	slog.Debug("generate handler", "prompt", prompt)
+	slog.Debug("generate request", "prompt", prompt, "images", images)

 	ch := make(chan any)
-	var generated strings.Builder
 	go func() {
 		defer close(ch)
-
-		fn := func(r llm.CompletionResponse) {
-			// Build up the full response
-			if _, err := generated.WriteString(r.Content); err != nil {
-				ch <- gin.H{"error": err.Error()}
-				return
-			}
-
-			resp := api.GenerateResponse{
+		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
+			Prompt:  prompt,
+			Images:  images,
+			Format:  req.Format,
+			Options: opts,
+		}, func(r llm.CompletionResponse) {
+			ch <- api.GenerateResponse{
 				Model:      req.Model,
 				CreatedAt:  time.Now().UTC(),
-				Done:       r.Done,
 				Response:   r.Content,
+				Done:       r.Done,
 				DoneReason: r.DoneReason,
 				Metrics: api.Metrics{
 					PromptEvalCount: r.PromptEvalCount,
@@ -232,77 +211,35 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 					EvalDuration:    r.EvalDuration,
 				},
 			}
-
-			if r.Done {
-				resp.TotalDuration = time.Since(checkpointStart)
-				resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
-
-				if !req.Raw {
-					p, err := Prompt(tmpl, req.System, req.Prompt, generated.String(), false)
-					if err != nil {
-						c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-						return
-					}
-
-					// TODO (jmorganca): encode() should not strip special tokens
-					tokens, err := runner.llama.Tokenize(c.Request.Context(), p)
-					if err != nil {
-						ch <- gin.H{"error": err.Error()}
-						return
-					}
-
-					resp.Context = append(req.Context, tokens...)
-				}
-			}
-
-			ch <- resp
-		}
-
-		var images []llm.ImageData
-		for i := range req.Images {
-			images = append(images, llm.ImageData{
-				ID:   i,
-				Data: req.Images[i],
-			})
-		}
-
-		// Start prediction
-		req := llm.CompletionRequest{
-			Prompt:  prompt,
-			Format:  req.Format,
-			Images:  images,
-			Options: opts,
-		}
-		if err := runner.llama.Completion(c.Request.Context(), req, fn); err != nil {
+		}); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()

 	if req.Stream != nil && !*req.Stream {
-		// Accumulate responses into the final response
-		var final api.GenerateResponse
+		var r api.GenerateResponse
 		var sb strings.Builder
-		for resp := range ch {
-			switch r := resp.(type) {
+		for rr := range ch {
+			switch t := rr.(type) {
 			case api.GenerateResponse:
-				sb.WriteString(r.Response)
-				final = r
+				sb.WriteString(t.Response)
+				r = t
 			case gin.H:
-				if errorMsg, ok := r["error"].(string); ok {
-					c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
-					return
-				} else {
-					c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
-					return
+				msg, ok := t["error"].(string)
+				if !ok {
+					msg = "unexpected error format in response"
 				}
+
+				c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
+				return
 			default:
-				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
+				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
 				return
 			}
 		}

-		final.Response = sb.String()
-		c.JSON(http.StatusOK, final)
+		r.Response = sb.String()
+		c.JSON(http.StatusOK, r)
 		return
 	}
@@ -311,44 +248,17 @@ func (s *Server) GenerateHandler(c *gin.Context) {

 func (s *Server) EmbeddingsHandler(c *gin.Context) {
 	var req api.EmbeddingRequest
-	err := c.ShouldBindJSON(&req)
-	switch {
-	case errors.Is(err, io.EOF):
+	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
 		return
-	case err != nil:
+	} else if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}

-	if req.Model == "" {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
-		return
-	}
-
-	model, err := GetModel(req.Model)
+	r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
 	if err != nil {
-		var pErr *fs.PathError
-		if errors.As(err, &pErr) {
-			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-			return
-		}
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	opts, err := modelOptions(model, req.Options)
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
-	var runner *runnerRef
-	select {
-	case runner = <-rCh:
-	case err = <-eCh:
-		handleErrorResponse(c, err)
+		handleScheduleError(c, req.Model, err)
 		return
 	}
@@ -358,17 +268,14 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}

-	embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
+	embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
 		return
 	}

-	resp := api.EmbeddingResponse{
-		Embedding: embedding,
-	}
-	c.JSON(http.StatusOK, resp)
+	c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: embedding})
 }

 func (s *Server) PullModelHandler(c *gin.Context) {
@@ -649,9 +556,9 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 		}
 	}

-	msgs := make([]api.Message, 0)
-	for _, msg := range m.Messages {
-		msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
+	msgs := make([]api.Message, len(m.Messages))
+	for i, msg := range m.Messages {
+		msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
 	}

 	n := model.ParseName(req.Model)
@@ -1214,132 +1121,55 @@ func (s *Server) ProcessHandler(c *gin.Context) {
 	c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
 }

-// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
-func chatPrompt(ctx context.Context, runner *runnerRef, template *template.Template, messages []api.Message, numCtx int) (string, error) {
-	encode := func(s string) ([]int, error) {
-		return runner.llama.Tokenize(ctx, s)
-	}
-
-	prompt, err := ChatPrompt(template, messages, numCtx, encode)
-	if err != nil {
-		return "", err
-	}
-
-	return prompt, nil
-}
-
 func (s *Server) ChatHandler(c *gin.Context) {
-	checkpointStart := time.Now()
-
 	var req api.ChatRequest
-	err := c.ShouldBindJSON(&req)
-	switch {
-	case errors.Is(err, io.EOF):
+	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
 		return
-	case err != nil:
+	} else if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}

-	// validate the request
-	switch {
-	case req.Model == "":
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
-		return
-	case len(req.Format) > 0 && req.Format != "json":
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
-		return
-	}
-
-	model, err := GetModel(req.Model)
-	if err != nil {
-		var pErr *fs.PathError
-		if errors.As(err, &pErr) {
-			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-			return
-		}
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	if !model.Has(CapabilityCompletion) {
-		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support chat", req.Model)})
-		return
-	}
-
-	opts, err := modelOptions(model, req.Options)
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
-	var runner *runnerRef
-	select {
-	case runner = <-rCh:
-	case err = <-eCh:
-		handleErrorResponse(c, err)
-		return
-	}
-
-	checkpointLoaded := time.Now()
-
-	// if the first message is not a system message, then add the model's default system message
-	if len(req.Messages) > 0 && req.Messages[0].Role != "system" {
-		req.Messages = append([]api.Message{
-			{
-				Role:    "system",
-				Content: model.System,
-			},
-		}, req.Messages...)
-	}
-
-	prompt, err := chatPrompt(c.Request.Context(), runner, model.Template, req.Messages, opts.NumCtx)
-	if err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		return
-	}
-
-	// an empty request loads the model
-	if len(req.Messages) == 0 || prompt == "" {
-		resp := api.ChatResponse{
-			CreatedAt:  time.Now().UTC(),
+	caps := []Capability{CapabilityCompletion}
+	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
+	if errors.Is(err, errCapabilityCompletion) {
+		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
+		return
+	} else if err != nil {
+		handleScheduleError(c, req.Model, err)
+		return
+	}
+
+	if len(req.Messages) == 0 {
+		c.JSON(http.StatusOK, api.ChatResponse{
 			Model:      req.Model,
+			CreatedAt:  time.Now().UTC(),
+			Message:    api.Message{Role: "assistant"},
 			Done:       true,
 			DoneReason: "load",
-			Message:    api.Message{Role: "assistant"},
-		}
-		c.JSON(http.StatusOK, resp)
+		})
 		return
 	}

-	// only send images that are in the prompt
-	var i int
-	var images []llm.ImageData
-	for _, m := range req.Messages {
-		for _, img := range m.Images {
-			if !isSupportedImageType(img) {
-				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
-				return
-			}
-
-			if strings.Contains(prompt, fmt.Sprintf("[img-%d]", i)) {
-				images = append(images, llm.ImageData{Data: img, ID: i})
-			}
-			i += 1
-		}
+	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
 	}

-	slog.Debug("chat handler", "prompt", prompt, "images", len(images))
+	slog.Debug("chat request", "images", len(images), "prompt", prompt)

 	ch := make(chan any)

 	go func() {
 		defer close(ch)
-
-		fn := func(r llm.CompletionResponse) {
-			resp := api.ChatResponse{
+		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
+			Prompt:  prompt,
+			Images:  images,
+			Format:  req.Format,
+			Options: opts,
+		}, func(r llm.CompletionResponse) {
+			ch <- api.ChatResponse{
 				Model:      req.Model,
 				CreatedAt:  time.Now().UTC(),
 				Message:    api.Message{Role: "assistant", Content: r.Content},
@@ -1352,64 +1182,52 @@ func (s *Server) ChatHandler(c *gin.Context) {
 					EvalDuration:    r.EvalDuration,
 				},
 			}
-
-			if r.Done {
-				resp.TotalDuration = time.Since(checkpointStart)
-				resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
-			}
-
-			ch <- resp
-		}
-
-		if err := runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
-			Prompt:  prompt,
-			Format:  req.Format,
-			Images:  images,
-			Options: opts,
-		}, fn); err != nil {
+		}); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()

 	if req.Stream != nil && !*req.Stream {
-		// Accumulate responses into the final response
-		var final api.ChatResponse
+		var r api.ChatResponse
 		var sb strings.Builder
-		for resp := range ch {
-			switch r := resp.(type) {
+		for rr := range ch {
+			switch t := rr.(type) {
 			case api.ChatResponse:
-				sb.WriteString(r.Message.Content)
-				final = r
+				sb.WriteString(t.Message.Content)
+				r = t
 			case gin.H:
-				if errorMsg, ok := r["error"].(string); ok {
-					c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
-					return
-				} else {
-					c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
-					return
+				msg, ok := t["error"].(string)
+				if !ok {
+					msg = "unexpected error format in response"
 				}
+
+				c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
+				return
 			default:
-				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
+				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
 				return
 			}
 		}

-		final.Message = api.Message{Role: "assistant", Content: sb.String()}
-		c.JSON(http.StatusOK, final)
+		r.Message.Content = sb.String()
+		c.JSON(http.StatusOK, r)
 		return
 	}

 	streamResponse(c, ch)
 }

-func handleErrorResponse(c *gin.Context, err error) {
-	if errors.Is(err, context.Canceled) {
+func handleScheduleError(c *gin.Context, name string, err error) {
+	switch {
+	case errors.Is(err, errRequired):
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+	case errors.Is(err, context.Canceled):
 		c.JSON(499, gin.H{"error": "request canceled"})
-		return
-	}
-	if errors.Is(err, ErrMaxQueue) {
+	case errors.Is(err, ErrMaxQueue):
 		c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
-		return
+	case errors.Is(err, os.ErrNotExist):
+		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found, try pulling it first", name)})
+	default:
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 	}
-	c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 }
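Taken together, the handlers now share one flow; a condensed, hypothetical sketch (a fragment within the server package, reusing the names from the diff above):

```go
// Hypothetical condensed handler body.
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{CapabilityCompletion}, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
	c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
	return
} else if err != nil {
	// handleScheduleError maps errRequired, context.Canceled, ErrMaxQueue and
	// os.ErrNotExist to 400, 499, 503 and 404 respectively, else 500.
	handleScheduleError(c, req.Model, err)
	return
}
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
```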
@@ -5,6 +5,7 @@ import (
 	"embed"
 	"encoding/json"
 	"errors"
+	"fmt"
 	"io"
 	"math"
 	"slices"
@@ -14,6 +15,7 @@ import (
 	"text/template/parse"

 	"github.com/agnivade/levenshtein"
+	"github.com/ollama/ollama/api"
 	"golang.org/x/exp/maps"
 )

@@ -74,30 +76,59 @@ func Named(s string) (*named, error) {
 	return nil, errors.New("no matching template found")
 }

+var DefaultTemplate, _ = Parse("{{ .Prompt }}")
+
 type Template struct {
 	*template.Template
 	raw string
 }

+// response is a template node that can be added to templates that don't already have one
+var response = parse.ActionNode{
+	NodeType: parse.NodeAction,
+	Pipe: &parse.PipeNode{
+		NodeType: parse.NodePipe,
+		Cmds: []*parse.CommandNode{
+			{
+				NodeType: parse.NodeCommand,
+				Args: []parse.Node{
+					&parse.FieldNode{
+						NodeType: parse.NodeField,
+						Ident:    []string{"Response"},
+					},
+				},
+			},
+		},
+	},
+}
+
+func Parse(s string) (*Template, error) {
+	tmpl := template.New("").Option("missingkey=zero")
+
+	tmpl, err := tmpl.Parse(s)
+	if err != nil {
+		return nil, err
+	}
+
+	t := Template{Template: tmpl, raw: s}
+	if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
+		// touch up the template and append {{ .Response }}
+		tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
+	}
+
+	return &t, nil
+}
+
 func (t *Template) String() string {
 	return t.raw
 }

-var DefaultTemplate, _ = Parse("{{ .Prompt }}")
-
-func Parse(s string) (*Template, error) {
-	t, err := template.New("").Option("missingkey=zero").Parse(s)
-	if err != nil {
-		return nil, err
-	}
-
-	return &Template{Template: t, raw: s}, nil
-}
-
 func (t *Template) Vars() []string {
 	var vars []string
-	for _, n := range t.Tree.Root.Nodes {
-		vars = append(vars, parseNode(n)...)
+	for _, tt := range t.Templates() {
+		for _, n := range tt.Root.Nodes {
+			vars = append(vars, parseNode(n)...)
+		}
 	}

 	set := make(map[string]struct{})
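The rewritten Parse keeps accepting legacy prompt-style templates: if the source references neither .Messages nor .Response, it appends a synthetic {{ .Response }} action node to the parse tree, so generated text can be rendered through the same template, and Vars now walks every associated template instead of just the root. Below is a standard-library-only sketch of that tree surgery; it mirrors the response node from the diff but is not the ollama code itself.

	package main

	import (
		"os"
		"text/template"
		"text/template/parse"
	)

	func main() {
		// a legacy template that only mentions {{ .Prompt }}
		tmpl := template.Must(template.New("").Option("missingkey=zero").Parse("[INST] {{ .Prompt }} [/INST] "))

		// the same shape as the `response` node in the diff above
		responseNode := &parse.ActionNode{
			NodeType: parse.NodeAction,
			Pipe: &parse.PipeNode{
				NodeType: parse.NodePipe,
				Cmds: []*parse.CommandNode{{
					NodeType: parse.NodeCommand,
					Args: []parse.Node{&parse.FieldNode{
						NodeType: parse.NodeField,
						Ident:    []string{"Response"},
					}},
				}},
			},
		}
		tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, responseNode)

		// prints: [INST] Hello [/INST] Hi there!
		_ = tmpl.Execute(os.Stdout, map[string]any{
			"Prompt":   "Hello",
			"Response": "Hi there!",
		})
	}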
@@ -110,6 +141,103 @@ func (t *Template) Vars() []string {
 	return vars
 }

+type Values struct {
+	Messages []api.Message
+}
+
+func (t *Template) Execute(w io.Writer, v Values) error {
+	system, collated := collate(v.Messages)
+	if slices.Contains(t.Vars(), "messages") {
+		return t.Template.Execute(w, map[string]any{
+			"System":   system,
+			"Messages": collated,
+		})
+	}
+
+	var b bytes.Buffer
+	var prompt, response string
+	for i, m := range collated {
+		if m.Role == "user" {
+			prompt = m.Content
+		} else {
+			response = m.Content
+		}
+
+		if i != len(collated)-1 && prompt != "" && response != "" {
+			if err := t.Template.Execute(&b, map[string]any{
+				"System":   "",
+				"Prompt":   prompt,
+				"Response": response,
+			}); err != nil {
+				return err
+			}
+
+			prompt = ""
+			response = ""
+		}
+	}
+
+	var cut bool
+	tree := t.Template.Copy()
+	// for the last message, cut everything after "{{ .Response }}"
+	tree.Root.Nodes = slices.DeleteFunc(tree.Root.Nodes, func(n parse.Node) bool {
+		if slices.Contains(parseNode(n), "Response") {
+			cut = true
+		}
+
+		return cut
+	})
+
+	if err := template.Must(template.New("").AddParseTree("", tree)).Execute(&b, map[string]any{
+		"System": system,
+		"Prompt": prompt,
+	}); err != nil {
+		return err
+	}
+
+	_, err := io.Copy(w, &b)
+	return err
+}
+
+type messages []*api.Message
+
+// collate messages based on role. consecutive messages of the same role are merged
+// into a single message. collate also pulls out and merges messages with Role == "system"
+// which are templated separately. As a side effect, it mangles message content adding image
+// tags ([img-%d]) as needed
+func collate(msgs []api.Message) (system string, collated messages) {
+	var n int
+	for i := range msgs {
+		msg := msgs[i]
+		if msg.Role == "system" {
+			if system != "" {
+				system += "\n\n"
+			}
+
+			system += msg.Content
+			continue
+		}
+
+		for range msg.Images {
+			imageTag := fmt.Sprintf("[img-%d]", n)
+			if !strings.Contains(msg.Content, "[img]") {
+				msg.Content = strings.TrimSpace("[img] " + msg.Content)
+			}
+
+			msg.Content = strings.Replace(msg.Content, "[img]", imageTag, 1)
+			n++
+		}
+
+		if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
+			collated[len(collated)-1].Content += "\n\n" + msg.Content
+		} else {
+			collated = append(collated, &msg)
+		}
+	}
+
+	return
+}
+
 func parseNode(n parse.Node) []string {
 	switch n := n.(type) {
 	case *parse.ActionNode:
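Execute is the heart of the change: templates that iterate over .Messages receive the collated conversation directly, while legacy .Prompt/.Response templates are driven one completed turn at a time, with everything after {{ .Response }} cut off for the final turn so rendering stops at the generation point. Below is a usage sketch of the new API, assuming the github.com/ollama/ollama module at this revision is on the module path; the template and messages are copied from the "mistral system" test case shown further down, so the commented output is the one asserted there.

	package main

	import (
		"fmt"
		"strings"

		"github.com/ollama/ollama/api"
		"github.com/ollama/ollama/template"
	)

	func main() {
		tmpl, err := template.Parse(`[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`)
		if err != nil {
			panic(err)
		}

		var sb strings.Builder
		err = tmpl.Execute(&sb, template.Values{Messages: []api.Message{
			{Role: "system", Content: "You are a helpful assistant!"},
			{Role: "user", Content: "Hello friend!"},
			{Role: "assistant", Content: "Hello human!"},
			{Role: "user", Content: "What is your name?"},
		}})
		if err != nil {
			panic(err)
		}

		// collate pulls the system message out and passes it as .System for the
		// final turn; earlier user/assistant pairs render with the full template,
		// and the last turn is rendered with {{ .Response }} and everything after
		// it removed, leaving (with a trailing space):
		//
		// [INST] Hello friend![/INST] Hello human![INST] You are a helpful assistant!
		//
		// What is your name?[/INST]
		fmt.Print(sb.String())
	}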
@@ -152,6 +280,8 @@ func parseNode(n parse.Node) []string {
 		return names
 	case *parse.FieldNode:
 		return n.Ident
+	case *parse.TemplateNode:
+		return parseNode(n.Pipe)
 	}

 	return nil
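The extra TemplateNode case matters because a field can be referenced only inside the pipeline of a {{ template "name" pipeline }} call; combined with the Vars change above, fields inside {{ define }} blocks and template invocations are now both discovered. Here is a standard-library sketch of that traversal; fields is a simplified stand-in for parseNode, not the real function.

	package main

	import (
		"fmt"
		"text/template"
		"text/template/parse"
	)

	func fields(n parse.Node) []string {
		switch n := n.(type) {
		case *parse.ActionNode:
			return fields(n.Pipe)
		case *parse.PipeNode:
			var out []string
			for _, c := range n.Cmds {
				for _, a := range c.Args {
					out = append(out, fields(a)...)
				}
			}
			return out
		case *parse.FieldNode:
			return n.Ident
		case *parse.TemplateNode:
			return fields(n.Pipe) // the new case: look inside the call's pipeline
		}
		return nil
	}

	func main() {
		t := template.Must(template.New("root").Parse(
			`{{ define "inst" }}[INST] {{ .System }} {{ .Prompt }}[/INST]{{ end }}{{ template "inst" .Turn }}`))

		var vars []string
		for _, tt := range t.Templates() { // includes both "root" and "inst"
			for _, n := range tt.Root.Nodes {
				vars = append(vars, fields(n)...)
			}
		}
		fmt.Println(vars) // e.g. [System Prompt Turn]; order depends on map iteration
	}

The remaining hunks appear to come from the package's tests.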
@@ -11,6 +11,7 @@ import (
 	"testing"
 	"text/template"

+	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llm"
 )

@@ -64,13 +65,12 @@ func TestParse(t *testing.T) {
 		template string
 		vars     []string
 	}{
-		{"{{ .Prompt }}", []string{"prompt"}},
+		{"{{ .Prompt }}", []string{"prompt", "response"}},
-		{"{{ .System }} {{ .Prompt }}", []string{"prompt", "system"}},
+		{"{{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system"}},
 		{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
-		{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "system", "tools"}},
+		{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
 		{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
 		{"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}},
-		{"{{ .Prompt }} {{ .Suffix }}", []string{"prompt", "suffix"}},
 	}

 	for _, tt := range cases {
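The updated expectations reflect the synthetic {{ .Response }} node: every template that mentions neither .Messages nor .Response now reports "response" among its variables. A quick check of that behavior, assuming the github.com/ollama/ollama module at this revision:

	package main

	import (
		"fmt"

		"github.com/ollama/ollama/template"
	)

	func main() {
		tmpl, err := template.Parse("{{ .System }} {{ .Prompt }}")
		if err != nil {
			panic(err)
		}

		// "response" shows up even though the source never mentions it, because
		// Parse appended the synthetic {{ .Response }} node.
		// expected per the test table: [prompt response system]
		fmt.Println(tmpl.Vars())
	}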
@@ -87,3 +87,159 @@ func TestParse(t *testing.T) {
 		})
 	}
 }
+
+func TestExecuteWithMessages(t *testing.T) {
+	type template struct {
+		name     string
+		template string
+	}
+	cases := []struct {
+		name      string
+		templates []template
+		values    Values
+		expected  string
+	}{
+		{
+			"mistral",
+			[]template{
+				{"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
+				{"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
+				{"messages", `{{- range $index, $_ := .Messages }}
+{{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }}
+{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
+{{- end }}
+{{- end }}`},
+			},
+			Values{
+				Messages: []api.Message{
+					{Role: "user", Content: "Hello friend!"},
+					{Role: "assistant", Content: "Hello human!"},
+					{Role: "user", Content: "What is your name?"},
+				},
+			},
+			`[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
+		},
+		{
+			"mistral system",
+			[]template{
+				{"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
+				{"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
+				{"messages", `
+{{- range $index, $_ := .Messages }}
+{{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }}
+{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
+{{- end }}
+{{- end }}`},
+			},
+			Values{
+				Messages: []api.Message{
+					{Role: "system", Content: "You are a helpful assistant!"},
+					{Role: "user", Content: "Hello friend!"},
+					{Role: "assistant", Content: "Hello human!"},
+					{Role: "user", Content: "What is your name?"},
+				},
+			},
+			`[INST] Hello friend![/INST] Hello human![INST] You are a helpful assistant!
+
+What is your name?[/INST] `,
+		},
+		{
+			"chatml",
+			[]template{
+				// this does not have a "no response" test because it's impossible to render the same output
+				{"response", `{{ if .System }}<|im_start|>system
+{{ .System }}<|im_end|>
+{{ end }}{{ if .Prompt }}<|im_start|>user
+{{ .Prompt }}<|im_end|>
+{{ end }}<|im_start|>assistant
+{{ .Response }}<|im_end|>
+`},
+				{"messages", `
+{{- range $index, $_ := .Messages }}
+{{- if and (eq .Role "user") (eq (len (slice $.Messages $index)) 1) $.System }}<|im_start|>system
+{{ $.System }}<|im_end|>{{ "\n" }}
+{{- end }}<|im_start|>{{ .Role }}
+{{ .Content }}<|im_end|>{{ "\n" }}
+{{- end }}<|im_start|>assistant
+`},
+			},
+			Values{
+				Messages: []api.Message{
+					{Role: "system", Content: "You are a helpful assistant!"},
+					{Role: "user", Content: "Hello friend!"},
+					{Role: "assistant", Content: "Hello human!"},
+					{Role: "user", Content: "What is your name?"},
+				},
+			},
+			`<|im_start|>user
+Hello friend!<|im_end|>
+<|im_start|>assistant
+Hello human!<|im_end|>
+<|im_start|>system
+You are a helpful assistant!<|im_end|>
+<|im_start|>user
+What is your name?<|im_end|>
+<|im_start|>assistant
+`,
+		},
+		{
+			"moondream",
+			[]template{
+				// this does not have a "no response" test because it's impossible to render the same output
+				{"response", `{{ if .Prompt }}Question: {{ .Prompt }}
+
+{{ end }}Answer: {{ .Response }}
+
+`},
+				{"messages", `
+{{- range .Messages }}
+{{- if eq .Role "user" }}Question: {{ .Content }}{{ "\n\n" }}
+{{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ "\n\n" }}
+{{- end }}
+{{- end }}Answer: `},
+			},
+			Values{
+				Messages: []api.Message{
+					{Role: "user", Content: "What's in this image?", Images: []api.ImageData{[]byte("")}},
+					{Role: "assistant", Content: "It's a hot dog."},
+					{Role: "user", Content: "What's in _this_ image?"},
+					{Role: "user", Images: []api.ImageData{[]byte("")}},
+					{Role: "user", Content: "Is it a hot dog?"},
+				},
+			},
+			`Question: [img-0] What's in this image?
+
+Answer: It's a hot dog.
+
+Question: What's in _this_ image?
+
+[img-1]
+
+Is it a hot dog?
+
+Answer: `,
+		},
+	}
+
+	for _, tt := range cases {
+		t.Run(tt.name, func(t *testing.T) {
+			for _, ttt := range tt.templates {
+				t.Run(ttt.name, func(t *testing.T) {
+					tmpl, err := Parse(ttt.template)
+					if err != nil {
+						t.Fatal(err)
+					}
+
+					var b bytes.Buffer
+					if err := tmpl.Execute(&b, tt.values); err != nil {
+						t.Fatal(err)
+					}
+
+					if b.String() != tt.expected {
+						t.Errorf("expected\n%s,\ngot\n%s", tt.expected, b.String())
+					}
+				})
+			}
+		})
+	}
+}