Merge pull request #5207 from ollama/mxyng/suffix
add insert support to generate endpoint
This commit is contained in:
commit
cd0853f2d5
6 changed files with 155 additions and 27 deletions
|
@ -47,6 +47,9 @@ type GenerateRequest struct {
|
||||||
// Prompt is the textual prompt to send to the model.
|
// Prompt is the textual prompt to send to the model.
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
|
|
||||||
|
// Suffix is the text that comes after the inserted text.
|
||||||
|
Suffix string `json:"suffix"`
|
||||||
|
|
||||||
// System overrides the model's default system message/prompt.
|
// System overrides the model's default system message/prompt.
|
||||||
System string `json:"system"`
|
System string `json:"system"`
|
||||||
|
|
||||||
|
|
|
@ -34,13 +34,19 @@ import (
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errCapabilityCompletion = errors.New("completion")
|
var (
|
||||||
|
errCapabilities = errors.New("does not support")
|
||||||
|
errCapabilityCompletion = errors.New("completion")
|
||||||
|
errCapabilityTools = errors.New("tools")
|
||||||
|
errCapabilityInsert = errors.New("insert")
|
||||||
|
)
|
||||||
|
|
||||||
type Capability string
|
type Capability string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
CapabilityCompletion = Capability("completion")
|
CapabilityCompletion = Capability("completion")
|
||||||
CapabilityTools = Capability("tools")
|
CapabilityTools = Capability("tools")
|
||||||
|
CapabilityInsert = Capability("insert")
|
||||||
)
|
)
|
||||||
|
|
||||||
type registryOptions struct {
|
type registryOptions struct {
|
||||||
|
@ -93,7 +99,12 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
|
||||||
}
|
}
|
||||||
case CapabilityTools:
|
case CapabilityTools:
|
||||||
if !slices.Contains(m.Template.Vars(), "tools") {
|
if !slices.Contains(m.Template.Vars(), "tools") {
|
||||||
errs = append(errs, errors.New("tools"))
|
errs = append(errs, errCapabilityTools)
|
||||||
|
}
|
||||||
|
case CapabilityInsert:
|
||||||
|
vars := m.Template.Vars()
|
||||||
|
if !slices.Contains(vars, "suffix") {
|
||||||
|
errs = append(errs, errCapabilityInsert)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
slog.Error("unknown capability", "capability", cap)
|
slog.Error("unknown capability", "capability", cap)
|
||||||
|
@ -102,7 +113,7 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := errors.Join(errs...); err != nil {
|
if err := errors.Join(errs...); err != nil {
|
||||||
return fmt.Errorf("does not support %w", errors.Join(errs...))
|
return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -122,6 +122,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
caps := []Capability{CapabilityCompletion}
|
caps := []Capability{CapabilityCompletion}
|
||||||
|
if req.Suffix != "" {
|
||||||
|
caps = append(caps, CapabilityInsert)
|
||||||
|
}
|
||||||
|
|
||||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
|
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
|
||||||
if errors.Is(err, errCapabilityCompletion) {
|
if errors.Is(err, errCapabilityCompletion) {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
|
||||||
|
@ -150,19 +154,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
|
|
||||||
prompt := req.Prompt
|
prompt := req.Prompt
|
||||||
if !req.Raw {
|
if !req.Raw {
|
||||||
var msgs []api.Message
|
|
||||||
if req.System != "" {
|
|
||||||
msgs = append(msgs, api.Message{Role: "system", Content: req.System})
|
|
||||||
} else if m.System != "" {
|
|
||||||
msgs = append(msgs, api.Message{Role: "system", Content: m.System})
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, i := range images {
|
|
||||||
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
|
|
||||||
}
|
|
||||||
|
|
||||||
msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
|
|
||||||
|
|
||||||
tmpl := m.Template
|
tmpl := m.Template
|
||||||
if req.Template != "" {
|
if req.Template != "" {
|
||||||
tmpl, err = template.Parse(req.Template)
|
tmpl, err = template.Parse(req.Template)
|
||||||
|
@ -183,7 +174,26 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
b.WriteString(s)
|
b.WriteString(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tmpl.Execute(&b, template.Values{Messages: msgs}); err != nil {
|
var values template.Values
|
||||||
|
if req.Suffix != "" {
|
||||||
|
values.Prompt = prompt
|
||||||
|
values.Suffix = req.Suffix
|
||||||
|
} else {
|
||||||
|
var msgs []api.Message
|
||||||
|
if req.System != "" {
|
||||||
|
msgs = append(msgs, api.Message{Role: "system", Content: req.System})
|
||||||
|
} else if m.System != "" {
|
||||||
|
msgs = append(msgs, api.Message{Role: "system", Content: m.System})
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, i := range images {
|
||||||
|
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
|
||||||
|
}
|
||||||
|
|
||||||
|
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tmpl.Execute(&b, values); err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -1394,7 +1404,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||||
|
|
||||||
func handleScheduleError(c *gin.Context, name string, err error) {
|
func handleScheduleError(c *gin.Context, name string, err error) {
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, errRequired):
|
case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
case errors.Is(err, context.Canceled):
|
case errors.Is(err, context.Canceled):
|
||||||
c.JSON(499, gin.H{"error": "request canceled"})
|
c.JSON(499, gin.H{"error": "request canceled"})
|
||||||
|
|
|
@ -73,6 +73,8 @@ func TestGenerateChat(t *testing.T) {
|
||||||
getCpuFn: gpu.GetCPUInfo,
|
getCpuFn: gpu.GetCPUInfo,
|
||||||
reschedDelay: 250 * time.Millisecond,
|
reschedDelay: 250 * time.Millisecond,
|
||||||
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
|
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
|
||||||
|
// add 10ms delay to simulate loading
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
req.successCh <- &runnerRef{
|
req.successCh <- &runnerRef{
|
||||||
llama: &mock,
|
llama: &mock,
|
||||||
}
|
}
|
||||||
|
@ -83,7 +85,7 @@ func TestGenerateChat(t *testing.T) {
|
||||||
go s.sched.Run(context.TODO())
|
go s.sched.Run(context.TODO())
|
||||||
|
|
||||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
Name: "test",
|
Model: "test",
|
||||||
Modelfile: fmt.Sprintf(`FROM %s
|
Modelfile: fmt.Sprintf(`FROM %s
|
||||||
TEMPLATE """
|
TEMPLATE """
|
||||||
{{- if .System }}System: {{ .System }} {{ end }}
|
{{- if .System }}System: {{ .System }} {{ end }}
|
||||||
|
@ -141,9 +143,9 @@ func TestGenerateChat(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("missing capabilities", func(t *testing.T) {
|
t.Run("missing capabilities chat", func(t *testing.T) {
|
||||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
Name: "bert",
|
Model: "bert",
|
||||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
|
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
|
||||||
"general.architecture": "bert",
|
"general.architecture": "bert",
|
||||||
"bert.pooling_type": uint32(0),
|
"bert.pooling_type": uint32(0),
|
||||||
|
@ -243,7 +245,7 @@ func TestGenerateChat(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if actual.TotalDuration == 0 {
|
if actual.TotalDuration == 0 {
|
||||||
t.Errorf("expected load duration > 0, got 0")
|
t.Errorf("expected total duration > 0, got 0")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -379,7 +381,7 @@ func TestGenerate(t *testing.T) {
|
||||||
go s.sched.Run(context.TODO())
|
go s.sched.Run(context.TODO())
|
||||||
|
|
||||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
Name: "test",
|
Model: "test",
|
||||||
Modelfile: fmt.Sprintf(`FROM %s
|
Modelfile: fmt.Sprintf(`FROM %s
|
||||||
TEMPLATE """
|
TEMPLATE """
|
||||||
{{- if .System }}System: {{ .System }} {{ end }}
|
{{- if .System }}System: {{ .System }} {{ end }}
|
||||||
|
@ -437,9 +439,9 @@ func TestGenerate(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("missing capabilities", func(t *testing.T) {
|
t.Run("missing capabilities generate", func(t *testing.T) {
|
||||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
Name: "bert",
|
Model: "bert",
|
||||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
|
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
|
||||||
"general.architecture": "bert",
|
"general.architecture": "bert",
|
||||||
"bert.pooling_type": uint32(0),
|
"bert.pooling_type": uint32(0),
|
||||||
|
@ -464,6 +466,22 @@ func TestGenerate(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("missing capabilities suffix", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test",
|
||||||
|
Prompt: "def add(",
|
||||||
|
Suffix: " return c",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(w.Body.String(), `{"error":"test does not support insert"}`); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("load model", func(t *testing.T) {
|
t.Run("load model", func(t *testing.T) {
|
||||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
Model: "test",
|
Model: "test",
|
||||||
|
@ -540,7 +558,7 @@ func TestGenerate(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if actual.TotalDuration == 0 {
|
if actual.TotalDuration == 0 {
|
||||||
t.Errorf("expected load duration > 0, got 0")
|
t.Errorf("expected total duration > 0, got 0")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -632,6 +650,49 @@ func TestGenerate(t *testing.T) {
|
||||||
checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
|
checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Model: "test-suffix",
|
||||||
|
Modelfile: `FROM test
|
||||||
|
TEMPLATE """{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
|
||||||
|
{{- else }}{{ .Prompt }}
|
||||||
|
{{- end }}"""`,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("prompt with suffix", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test-suffix",
|
||||||
|
Prompt: "def add(",
|
||||||
|
Suffix: " return c",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("prompt without suffix", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test-suffix",
|
||||||
|
Prompt: "def add(",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("raw", func(t *testing.T) {
|
t.Run("raw", func(t *testing.T) {
|
||||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
Model: "test-system",
|
Model: "test-system",
|
||||||
|
|
|
@ -151,6 +151,8 @@ func (t *Template) Vars() []string {
|
||||||
type Values struct {
|
type Values struct {
|
||||||
Messages []api.Message
|
Messages []api.Message
|
||||||
Tools []api.Tool
|
Tools []api.Tool
|
||||||
|
Prompt string
|
||||||
|
Suffix string
|
||||||
|
|
||||||
// forceLegacy is a flag used to test compatibility with legacy templates
|
// forceLegacy is a flag used to test compatibility with legacy templates
|
||||||
forceLegacy bool
|
forceLegacy bool
|
||||||
|
@ -204,7 +206,13 @@ func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
|
||||||
|
|
||||||
func (t *Template) Execute(w io.Writer, v Values) error {
|
func (t *Template) Execute(w io.Writer, v Values) error {
|
||||||
system, messages := collate(v.Messages)
|
system, messages := collate(v.Messages)
|
||||||
if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
|
if v.Prompt != "" && v.Suffix != "" {
|
||||||
|
return t.Template.Execute(w, map[string]any{
|
||||||
|
"Prompt": v.Prompt,
|
||||||
|
"Suffix": v.Suffix,
|
||||||
|
"Response": "",
|
||||||
|
})
|
||||||
|
} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
|
||||||
return t.Template.Execute(w, map[string]any{
|
return t.Template.Execute(w, map[string]any{
|
||||||
"System": system,
|
"System": system,
|
||||||
"Messages": messages,
|
"Messages": messages,
|
||||||
|
|
|
@ -359,3 +359,38 @@ Answer: `,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExecuteWithSuffix(t *testing.T) {
|
||||||
|
tmpl, err := Parse(`{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
|
||||||
|
{{- else }}{{ .Prompt }}
|
||||||
|
{{- end }}`)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
values Values
|
||||||
|
expect string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"message", Values{Messages: []api.Message{{Role: "user", Content: "hello"}}}, "hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"prompt suffix", Values{Prompt: "def add(", Suffix: "return x"}, "<PRE> def add( <SUF>return x <MID>",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := tmpl.Execute(&b, tt.values); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(b.String(), tt.expect); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue