comments
This commit is contained in:
parent
269ed6e6a2
commit
2c3fe1fd97
5 changed files with 224 additions and 113 deletions
|
@ -11,8 +11,13 @@ import (
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
)
|
)
|
||||||
|
|
||||||
func chatPrompt(ctx context.Context, r *runnerRef, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
|
type tokenizeFunc func(context.Context, string) ([]int, error)
|
||||||
// extract system messages which should always be included
|
|
||||||
|
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
|
||||||
|
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
|
||||||
|
// latest message and 2) system messages
|
||||||
|
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
|
||||||
|
// pull out any system messages which should always be included in the prompt
|
||||||
var system []api.Message
|
var system []api.Message
|
||||||
msgs = slices.DeleteFunc(msgs, func(m api.Message) bool {
|
msgs = slices.DeleteFunc(msgs, func(m api.Message) bool {
|
||||||
if m.Role == "system" {
|
if m.Role == "system" {
|
||||||
|
@ -23,32 +28,35 @@ func chatPrompt(ctx context.Context, r *runnerRef, msgs []api.Message) (prompt s
|
||||||
return false
|
return false
|
||||||
})
|
})
|
||||||
|
|
||||||
if len(system) == 0 && r.model.System != "" {
|
if len(system) == 0 && m.System != "" {
|
||||||
// add model system prompt since it wasn't provided
|
// add model system prompt since it wasn't provided
|
||||||
system = append(system, api.Message{Role: "system", Content: r.model.System})
|
system = append(system, api.Message{Role: "system", Content: m.System})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// always include the last message
|
||||||
n := len(msgs) - 1
|
n := len(msgs) - 1
|
||||||
|
// in reverse, find all messages that fit into context window
|
||||||
for i := n - 1; i >= 0; i-- {
|
for i := n - 1; i >= 0; i-- {
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
if err := r.model.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
|
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
s, err := r.llama.Tokenize(ctx, b.String())
|
s, err := tokenize(ctx, b.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
c := len(s)
|
c := len(s)
|
||||||
if r.model.ProjectorPaths != nil {
|
if m.ProjectorPaths != nil {
|
||||||
for _, m := range msgs[i:] {
|
for _, m := range msgs[i:] {
|
||||||
// TODO: get image embedding length from project metadata
|
// images are represented as 768 sized embeddings
|
||||||
|
// TODO: get embedding length from project metadata
|
||||||
c += 768 * len(m.Images)
|
c += 768 * len(m.Images)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if c > r.NumCtx {
|
if c > opts.NumCtx {
|
||||||
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
|
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
|
||||||
break
|
break
|
||||||
} else {
|
} else {
|
||||||
|
@ -56,8 +64,9 @@ func chatPrompt(ctx context.Context, r *runnerRef, msgs []api.Message) (prompt s
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// truncate any messages that do not fit into the context window
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
if err := r.model.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
|
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,15 +7,10 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/llm"
|
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mock struct {
|
func tokenize(_ context.Context, s string) (tokens []int, err error) {
|
||||||
llm.LlamaServer
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m mock) Tokenize(_ context.Context, s string) (tokens []int, err error) {
|
|
||||||
for range strings.Fields(s) {
|
for range strings.Fields(s) {
|
||||||
tokens = append(tokens, len(tokens))
|
tokens = append(tokens, len(tokens))
|
||||||
}
|
}
|
||||||
|
@ -48,7 +43,7 @@ func TestChatPrompt(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "truncate messages",
|
name: "truncate messages",
|
||||||
limit: 1,
|
limit: 1,
|
||||||
msgs: []api.Message{
|
msgs: []api.Message{
|
||||||
{Role: "user", Content: "You're a test, Harry!"},
|
{Role: "user", Content: "You're a test, Harry!"},
|
||||||
|
@ -60,7 +55,7 @@ func TestChatPrompt(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "truncate messages with image",
|
name: "truncate messages with image",
|
||||||
limit: 64,
|
limit: 64,
|
||||||
msgs: []api.Message{
|
msgs: []api.Message{
|
||||||
{Role: "user", Content: "You're a test, Harry!"},
|
{Role: "user", Content: "You're a test, Harry!"},
|
||||||
|
@ -75,7 +70,7 @@ func TestChatPrompt(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "truncate messages with images",
|
name: "truncate messages with images",
|
||||||
limit: 64,
|
limit: 64,
|
||||||
msgs: []api.Message{
|
msgs: []api.Message{
|
||||||
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
|
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
|
||||||
|
@ -90,7 +85,7 @@ func TestChatPrompt(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "messages with images",
|
name: "messages with images",
|
||||||
limit: 2048,
|
limit: 2048,
|
||||||
msgs: []api.Message{
|
msgs: []api.Message{
|
||||||
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
|
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
|
||||||
|
@ -106,7 +101,7 @@ func TestChatPrompt(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "message with image tag",
|
name: "message with image tag",
|
||||||
limit: 2048,
|
limit: 2048,
|
||||||
msgs: []api.Message{
|
msgs: []api.Message{
|
||||||
{Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
|
{Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
|
||||||
|
@ -122,7 +117,7 @@ func TestChatPrompt(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "messages with interleaved images",
|
name: "messages with interleaved images",
|
||||||
limit: 2048,
|
limit: 2048,
|
||||||
msgs: []api.Message{
|
msgs: []api.Message{
|
||||||
{Role: "user", Content: "You're a test, Harry!"},
|
{Role: "user", Content: "You're a test, Harry!"},
|
||||||
|
@ -140,7 +135,7 @@ func TestChatPrompt(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "truncate message with interleaved images",
|
name: "truncate message with interleaved images",
|
||||||
limit: 1024,
|
limit: 1024,
|
||||||
msgs: []api.Message{
|
msgs: []api.Message{
|
||||||
{Role: "user", Content: "You're a test, Harry!"},
|
{Role: "user", Content: "You're a test, Harry!"},
|
||||||
|
@ -157,7 +152,7 @@ func TestChatPrompt(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "message with system prompt",
|
name: "message with system prompt",
|
||||||
limit: 2048,
|
limit: 2048,
|
||||||
msgs: []api.Message{
|
msgs: []api.Message{
|
||||||
{Role: "system", Content: "You are the Test Who Lived."},
|
{Role: "system", Content: "You are the Test Who Lived."},
|
||||||
|
@ -181,14 +176,9 @@ func TestChatPrompt(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range cases {
|
for _, tt := range cases {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
r := runnerRef{
|
model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
|
||||||
llama: mock{},
|
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
|
||||||
model: &Model{Template: tmpl, ProjectorPaths: []string{"vision"}},
|
prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs)
|
||||||
Options: &api.Options{},
|
|
||||||
}
|
|
||||||
|
|
||||||
r.NumCtx = tt.limit
|
|
||||||
prompt, images, err := chatPrompt(context.TODO(), &r, tt.msgs)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,6 +54,8 @@ func init() {
|
||||||
gin.SetMode(mode)
|
gin.SetMode(mode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var errRequired = errors.New("is required")
|
||||||
|
|
||||||
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
|
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
|
||||||
opts := api.DefaultOptions()
|
opts := api.DefaultOptions()
|
||||||
if err := opts.FromMap(model.Options); err != nil {
|
if err := opts.FromMap(model.Options); err != nil {
|
||||||
|
@ -69,7 +71,7 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
|
||||||
|
|
||||||
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (*runnerRef, error) {
|
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (*runnerRef, error) {
|
||||||
if name == "" {
|
if name == "" {
|
||||||
return nil, errors.New("model is required")
|
return nil, fmt.Errorf("model %w", errRequired)
|
||||||
}
|
}
|
||||||
|
|
||||||
model, err := GetModel(name)
|
model, err := GetModel(name)
|
||||||
|
@ -121,7 +123,17 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
|
||||||
return
|
return
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
handleScheduleError(c, err)
|
handleScheduleError(c, req.Model, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Prompt == "" {
|
||||||
|
c.JSON(http.StatusOK, api.GenerateResponse{
|
||||||
|
Model: req.Model,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
Done: true,
|
||||||
|
DoneReason: "load",
|
||||||
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -139,23 +151,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
msgs = append(msgs, api.Message{Role: "system", Content: r.model.System})
|
msgs = append(msgs, api.Message{Role: "system", Content: r.model.System})
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Prompt != "" {
|
for _, i := range images {
|
||||||
for _, i := range images {
|
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
|
||||||
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
|
|
||||||
}
|
|
||||||
|
|
||||||
msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(msgs) == 0 {
|
msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
|
||||||
c.JSON(http.StatusOK, api.GenerateResponse{
|
|
||||||
Model: req.Model,
|
|
||||||
CreatedAt: time.Now().UTC(),
|
|
||||||
Done: true,
|
|
||||||
DoneReason: "load",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
tmpl := r.model.Template
|
tmpl := r.model.Template
|
||||||
if req.Template != "" {
|
if req.Template != "" {
|
||||||
|
@ -256,7 +256,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||||
|
|
||||||
r, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
|
r, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleScheduleError(c, err)
|
handleScheduleError(c, req.Model, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1135,7 +1135,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
|
||||||
return
|
return
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
handleScheduleError(c, err)
|
handleScheduleError(c, req.Model, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1150,7 +1150,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
prompt, images, err := chatPrompt(c.Request.Context(), r, req.Messages)
|
prompt, images, err := chatPrompt(c.Request.Context(), r.model, r.llama.Tokenize, r.Options, req.Messages)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
|
@ -1215,12 +1215,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||||
streamResponse(c, ch)
|
streamResponse(c, ch)
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleScheduleError(c *gin.Context, err error) {
|
func handleScheduleError(c *gin.Context, name string, err error) {
|
||||||
switch {
|
switch {
|
||||||
|
case errors.Is(err, errRequired):
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
case errors.Is(err, context.Canceled):
|
case errors.Is(err, context.Canceled):
|
||||||
c.JSON(499, gin.H{"error": "request canceled"})
|
c.JSON(499, gin.H{"error": "request canceled"})
|
||||||
case errors.Is(err, ErrMaxQueue):
|
case errors.Is(err, ErrMaxQueue):
|
||||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
|
c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
|
||||||
|
case errors.Is(err, os.ErrNotExist):
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found, try pulling it first", name)})
|
||||||
default:
|
default:
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
}
|
}
|
||||||
|
|
|
@ -83,6 +83,7 @@ type Template struct {
|
||||||
raw string
|
raw string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// response is a template node that can be added to templates that don't already have one
|
||||||
var response = parse.ActionNode{
|
var response = parse.ActionNode{
|
||||||
NodeType: parse.NodeAction,
|
NodeType: parse.NodeAction,
|
||||||
Pipe: &parse.PipeNode{
|
Pipe: &parse.PipeNode{
|
||||||
|
@ -101,28 +102,25 @@ var response = parse.ActionNode{
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var funcs = template.FuncMap{
|
||||||
|
"toJson": func(v any) string {
|
||||||
|
b, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(b)
|
||||||
|
},
|
||||||
|
"add": func(a, b int) int {
|
||||||
|
return a + b
|
||||||
|
},
|
||||||
|
"sub": func(a, b int) int {
|
||||||
|
return a - b
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
func Parse(s string) (*Template, error) {
|
func Parse(s string) (*Template, error) {
|
||||||
tmpl := template.New("").Option("missingkey=zero").Funcs(template.FuncMap{
|
tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
|
||||||
"toJson": func(v any) string {
|
|
||||||
b, err := json.Marshal(v)
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
return string(b)
|
|
||||||
},
|
|
||||||
"isLastMessage": func(s []*api.Message, m *api.Message) bool {
|
|
||||||
for i := len(s) - 1; i >= 0; i-- {
|
|
||||||
if m.Role != s[i].Role {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
return m == s[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
tmpl, err := tmpl.Parse(s)
|
tmpl, err := tmpl.Parse(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -218,7 +216,13 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func collate(msgs []api.Message) (system string, collated []*api.Message) {
|
type messages []*api.Message
|
||||||
|
|
||||||
|
// collate messages based on role. consecutive messages of the same role are merged
|
||||||
|
// into a single message. collate also pulls out and merges messages with Role == "system"
|
||||||
|
// which are templated separately. As a side effect, it mangles message content adding image
|
||||||
|
// tags ([img-%d]) as needed
|
||||||
|
func collate(msgs []api.Message) (system string, collated messages) {
|
||||||
var n int
|
var n int
|
||||||
for i := range msgs {
|
for i := range msgs {
|
||||||
msg := msgs[i]
|
msg := msgs[i]
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
|
@ -15,6 +16,98 @@ import (
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestFuncs(t *testing.T) {
|
||||||
|
t.Run("toJson", func(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
input any
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{nil, "null"},
|
||||||
|
{true, "true"},
|
||||||
|
{false, "false"},
|
||||||
|
{0, "0"},
|
||||||
|
{1, "1"},
|
||||||
|
{1.0, "1"},
|
||||||
|
{1.1, "1.1"},
|
||||||
|
{"", `""`},
|
||||||
|
{"hello", `"hello"`},
|
||||||
|
{[]int{1, 2, 3}, "[1,2,3]"},
|
||||||
|
{[]string{"a", "b", "c"}, `["a","b","c"]`},
|
||||||
|
{map[string]int{"a": 1, "b": 2}, `{"a":1,"b":2}`},
|
||||||
|
{map[string]string{"a": "b", "c": "d"}, `{"a":"b","c":"d"}`},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.expected, func(t *testing.T) {
|
||||||
|
toJson, ok := funcs["toJson"].(func(any) string)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("toJson is not a function")
|
||||||
|
}
|
||||||
|
|
||||||
|
if s := toJson(tt.input); s != tt.expected {
|
||||||
|
t.Errorf("expected %q, got %q", tt.expected, s)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("add", func(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
a, b int
|
||||||
|
expected int
|
||||||
|
}{
|
||||||
|
{0, 0, 0},
|
||||||
|
{0, 1, 1},
|
||||||
|
{1, 0, 1},
|
||||||
|
{1, 1, 2},
|
||||||
|
{1, -1, 0},
|
||||||
|
{-1, 1, 0},
|
||||||
|
{-1, -1, -2},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(strconv.Itoa(tt.expected), func(t *testing.T) {
|
||||||
|
add, ok := funcs["add"].(func(int, int) int)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("add is not a function")
|
||||||
|
}
|
||||||
|
|
||||||
|
if n := add(tt.a, tt.b); n != tt.expected {
|
||||||
|
t.Errorf("expected %d, got %d", tt.expected, n)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("sub", func(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
a, b int
|
||||||
|
expected int
|
||||||
|
}{
|
||||||
|
{0, 0, 0},
|
||||||
|
{0, 1, -1},
|
||||||
|
{1, 0, 1},
|
||||||
|
{1, 1, 0},
|
||||||
|
{1, -1, 2},
|
||||||
|
{-1, 1, -2},
|
||||||
|
{-1, -1, 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(strconv.Itoa(tt.expected), func(t *testing.T) {
|
||||||
|
sub, ok := funcs["sub"].(func(int, int) int)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("sub is not a function")
|
||||||
|
}
|
||||||
|
|
||||||
|
if n := sub(tt.a, tt.b); n != tt.expected {
|
||||||
|
t.Errorf("expected %d, got %d", tt.expected, n)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestNamed(t *testing.T) {
|
func TestNamed(t *testing.T) {
|
||||||
f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
|
f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -89,77 +182,86 @@ func TestParse(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExecuteWithMessages(t *testing.T) {
|
func TestExecuteWithMessages(t *testing.T) {
|
||||||
|
type template struct {
|
||||||
|
name string
|
||||||
|
template string
|
||||||
|
}
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
templates []string
|
name string
|
||||||
|
templates []template
|
||||||
values Values
|
values Values
|
||||||
expected string
|
expected string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
[]string{
|
"mistral",
|
||||||
`[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `,
|
[]template{
|
||||||
`[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`,
|
{"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
|
||||||
`{{- range .Messages }}
|
{"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
|
||||||
{{- if eq .Role "user" }}[INST] {{ if and (isLastMessage $.Messages .) $.System }}{{ $.System }}{{ print "\n\n" }}
|
{"messages", `{{- range .Messages }}
|
||||||
|
{{- if eq .Role "user" }}[INST] {{ if and (eq (index $.Messages (sub (len $.Messages) 1)) .) $.System }}{{ $.System }}{{ "\n\n" }}
|
||||||
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
|
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
|
||||||
{{- end }}
|
{{- end }}
|
||||||
{{- end }}`,
|
{{- end }}`},
|
||||||
},
|
},
|
||||||
Values{
|
Values{
|
||||||
Messages: []api.Message{
|
Messages: []api.Message{
|
||||||
{Role: "user", Content: "Hello friend!"},
|
{Role: "user", Content: "Hello friend!"},
|
||||||
{Role: "assistant", Content: "Hello human!"},
|
{Role: "assistant", Content: "Hello human!"},
|
||||||
{Role: "user", Content: "Yay!"},
|
{Role: "user", Content: "What is your name?"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
`[INST] Hello friend![/INST] Hello human![INST] Yay![/INST] `,
|
`[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
[]string{
|
"mistral system",
|
||||||
`[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `,
|
[]template{
|
||||||
`[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`,
|
{"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
|
||||||
`
|
{"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
|
||||||
|
{"messages", `
|
||||||
{{- range .Messages }}
|
{{- range .Messages }}
|
||||||
{{- if eq .Role "user" }}[INST] {{ if and (isLastMessage $.Messages .) $.System }}{{ $.System }}{{ print "\n\n" }}
|
{{- if eq .Role "user" }}[INST] {{ if and (eq (index $.Messages (sub (len $.Messages) 1)) .) $.System }}{{ $.System }}{{ "\n\n" }}
|
||||||
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
|
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
|
||||||
{{- end }}
|
{{- end }}
|
||||||
{{- end }}`,
|
{{- end }}`},
|
||||||
},
|
},
|
||||||
Values{
|
Values{
|
||||||
Messages: []api.Message{
|
Messages: []api.Message{
|
||||||
{Role: "system", Content: "You are a helpful assistant!"},
|
{Role: "system", Content: "You are a helpful assistant!"},
|
||||||
{Role: "user", Content: "Hello friend!"},
|
{Role: "user", Content: "Hello friend!"},
|
||||||
{Role: "assistant", Content: "Hello human!"},
|
{Role: "assistant", Content: "Hello human!"},
|
||||||
{Role: "user", Content: "Yay!"},
|
{Role: "user", Content: "What is your name?"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
`[INST] Hello friend![/INST] Hello human![INST] You are a helpful assistant!
|
`[INST] Hello friend![/INST] Hello human![INST] You are a helpful assistant!
|
||||||
|
|
||||||
Yay![/INST] `,
|
What is your name?[/INST] `,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
[]string{
|
"chatml",
|
||||||
`{{ if .System }}<|im_start|>system
|
[]template{
|
||||||
|
// this does not have a "no response" test because it's impossible to render the same output
|
||||||
|
{"response", `{{ if .System }}<|im_start|>system
|
||||||
{{ .System }}<|im_end|>
|
{{ .System }}<|im_end|>
|
||||||
{{ end }}{{ if .Prompt }}<|im_start|>user
|
{{ end }}{{ if .Prompt }}<|im_start|>user
|
||||||
{{ .Prompt }}<|im_end|>
|
{{ .Prompt }}<|im_end|>
|
||||||
{{ end }}<|im_start|>assistant
|
{{ end }}<|im_start|>assistant
|
||||||
{{ .Response }}<|im_end|>
|
{{ .Response }}<|im_end|>
|
||||||
`,
|
`},
|
||||||
`
|
{"messages", `
|
||||||
{{- range .Messages }}
|
{{- range .Messages }}
|
||||||
{{- if and (eq .Role "user") (isLastMessage $.Messages .) $.System }}<|im_start|>system
|
{{- if and (eq .Role "user") (eq (index $.Messages (sub (len $.Messages) 1)) .) $.System }}<|im_start|>system
|
||||||
{{ $.System }}<|im_end|>{{ print "\n" }}
|
{{ $.System }}<|im_end|>{{ "\n" }}
|
||||||
{{- end }}<|im_start|>{{ .Role }}
|
{{- end }}<|im_start|>{{ .Role }}
|
||||||
{{ .Content }}<|im_end|>{{ print "\n" }}
|
{{ .Content }}<|im_end|>{{ "\n" }}
|
||||||
{{- end }}<|im_start|>assistant
|
{{- end }}<|im_start|>assistant
|
||||||
`,
|
`},
|
||||||
},
|
},
|
||||||
Values{
|
Values{
|
||||||
Messages: []api.Message{
|
Messages: []api.Message{
|
||||||
{Role: "system", Content: "You are a helpful assistant!"},
|
{Role: "system", Content: "You are a helpful assistant!"},
|
||||||
{Role: "user", Content: "Hello friend!"},
|
{Role: "user", Content: "Hello friend!"},
|
||||||
{Role: "assistant", Content: "Hello human!"},
|
{Role: "assistant", Content: "Hello human!"},
|
||||||
{Role: "user", Content: "Yay!"},
|
{Role: "user", Content: "What is your name?"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
`<|im_start|>user
|
`<|im_start|>user
|
||||||
|
@ -169,23 +271,25 @@ Hello human!<|im_end|>
|
||||||
<|im_start|>system
|
<|im_start|>system
|
||||||
You are a helpful assistant!<|im_end|>
|
You are a helpful assistant!<|im_end|>
|
||||||
<|im_start|>user
|
<|im_start|>user
|
||||||
Yay!<|im_end|>
|
What is your name?<|im_end|>
|
||||||
<|im_start|>assistant
|
<|im_start|>assistant
|
||||||
`,
|
`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
[]string{
|
"moondream",
|
||||||
`{{ if .Prompt }}Question: {{ .Prompt }}
|
[]template{
|
||||||
|
// this does not have a "no response" test because it's impossible to render the same output
|
||||||
|
{"response", `{{ if .Prompt }}Question: {{ .Prompt }}
|
||||||
|
|
||||||
{{ end }}Answer: {{ .Response }}
|
{{ end }}Answer: {{ .Response }}
|
||||||
|
|
||||||
`,
|
`},
|
||||||
`
|
{"messages", `
|
||||||
{{- range .Messages }}
|
{{- range .Messages }}
|
||||||
{{- if eq .Role "user" }}Question: {{ .Content }}{{ print "\n\n" }}
|
{{- if eq .Role "user" }}Question: {{ .Content }}{{ "\n\n" }}
|
||||||
{{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ print "\n\n" }}
|
{{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ "\n\n" }}
|
||||||
{{- end }}
|
{{- end }}
|
||||||
{{- end }}Answer: `,
|
{{- end }}Answer: `},
|
||||||
},
|
},
|
||||||
Values{
|
Values{
|
||||||
Messages: []api.Message{
|
Messages: []api.Message{
|
||||||
|
@ -211,10 +315,10 @@ Answer: `,
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range cases {
|
for _, tt := range cases {
|
||||||
t.Run("", func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
for _, tmpl := range tt.templates {
|
for _, ttt := range tt.templates {
|
||||||
t.Run("", func(t *testing.T) {
|
t.Run(ttt.name, func(t *testing.T) {
|
||||||
tmpl, err := Parse(tmpl)
|
tmpl, err := Parse(ttt.template)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue