From d290e87513664be8ca3120348614d124991ccb86 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 20 Jun 2024 19:13:36 -0700 Subject: [PATCH] add suffix support to generate endpoint this change is triggered by the presence of "suffix", particularly useful for code completion tasks --- api/types.go | 3 ++ server/images.go | 17 ++++++-- server/routes.go | 40 +++++++++++------- server/routes_generate_test.go | 77 ++++++++++++++++++++++++++++++---- template/template.go | 10 ++++- template/template_test.go | 35 ++++++++++++++++ 6 files changed, 155 insertions(+), 27 deletions(-) diff --git a/api/types.go b/api/types.go index e670d114..3029fca8 100644 --- a/api/types.go +++ b/api/types.go @@ -47,6 +47,9 @@ type GenerateRequest struct { // Prompt is the textual prompt to send to the model. Prompt string `json:"prompt"` + // Suffix is the text that comes after the inserted text. + Suffix string `json:"suffix"` + // System overrides the model's default system message/prompt. System string `json:"system"` diff --git a/server/images.go b/server/images.go index 1b87888e..5e4e8858 100644 --- a/server/images.go +++ b/server/images.go @@ -34,13 +34,19 @@ import ( "github.com/ollama/ollama/version" ) -var errCapabilityCompletion = errors.New("completion") +var ( + errCapabilities = errors.New("does not support") + errCapabilityCompletion = errors.New("completion") + errCapabilityTools = errors.New("tools") + errCapabilityInsert = errors.New("insert") +) type Capability string const ( CapabilityCompletion = Capability("completion") CapabilityTools = Capability("tools") + CapabilityInsert = Capability("insert") ) type registryOptions struct { @@ -93,7 +99,12 @@ func (m *Model) CheckCapabilities(caps ...Capability) error { } case CapabilityTools: if !slices.Contains(m.Template.Vars(), "tools") { - errs = append(errs, errors.New("tools")) + errs = append(errs, errCapabilityTools) + } + case CapabilityInsert: + vars := m.Template.Vars() + if !slices.Contains(vars, "suffix") { + errs = append(errs, errCapabilityInsert) } default: slog.Error("unknown capability", "capability", cap) @@ -102,7 +113,7 @@ func (m *Model) CheckCapabilities(caps ...Capability) error { } if err := errors.Join(errs...); err != nil { - return fmt.Errorf("does not support %w", errors.Join(errs...)) + return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...)) } return nil diff --git a/server/routes.go b/server/routes.go index d22a099a..c7f74fa4 100644 --- a/server/routes.go +++ b/server/routes.go @@ -122,6 +122,10 @@ func (s *Server) GenerateHandler(c *gin.Context) { } caps := []Capability{CapabilityCompletion} + if req.Suffix != "" { + caps = append(caps, CapabilityInsert) + } + r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) if errors.Is(err, errCapabilityCompletion) { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)}) @@ -150,19 +154,6 @@ func (s *Server) GenerateHandler(c *gin.Context) { prompt := req.Prompt if !req.Raw { - var msgs []api.Message - if req.System != "" { - msgs = append(msgs, api.Message{Role: "system", Content: req.System}) - } else if m.System != "" { - msgs = append(msgs, api.Message{Role: "system", Content: m.System}) - } - - for _, i := range images { - msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)}) - } - - msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt}) - tmpl := m.Template if req.Template != "" { tmpl, err = template.Parse(req.Template) @@ -183,7 +174,26 @@ func (s *Server) GenerateHandler(c *gin.Context) { b.WriteString(s) } - if err := tmpl.Execute(&b, template.Values{Messages: msgs}); err != nil { + var values template.Values + if req.Suffix != "" { + values.Prompt = prompt + values.Suffix = req.Suffix + } else { + var msgs []api.Message + if req.System != "" { + msgs = append(msgs, api.Message{Role: "system", Content: req.System}) + } else if m.System != "" { + msgs = append(msgs, api.Message{Role: "system", Content: m.System}) + } + + for _, i := range images { + msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)}) + } + + values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt}) + } + + if err := tmpl.Execute(&b, values); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -1394,7 +1404,7 @@ func (s *Server) ChatHandler(c *gin.Context) { func handleScheduleError(c *gin.Context, name string, err error) { switch { - case errors.Is(err, errRequired): + case errors.Is(err, errCapabilities), errors.Is(err, errRequired): c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) case errors.Is(err, context.Canceled): c.JSON(499, gin.H{"error": "request canceled"}) diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go index 9d899328..c914b300 100644 --- a/server/routes_generate_test.go +++ b/server/routes_generate_test.go @@ -73,6 +73,8 @@ func TestGenerateChat(t *testing.T) { getCpuFn: gpu.GetCPUInfo, reschedDelay: 250 * time.Millisecond, loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) { + // add 10ms delay to simulate loading + time.Sleep(10 * time.Millisecond) req.successCh <- &runnerRef{ llama: &mock, } @@ -83,7 +85,7 @@ func TestGenerateChat(t *testing.T) { go s.sched.Run(context.TODO()) w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ - Name: "test", + Model: "test", Modelfile: fmt.Sprintf(`FROM %s TEMPLATE """ {{- if .System }}System: {{ .System }} {{ end }} @@ -141,9 +143,9 @@ func TestGenerateChat(t *testing.T) { } }) - t.Run("missing capabilities", func(t *testing.T) { + t.Run("missing capabilities chat", func(t *testing.T) { w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ - Name: "bert", + Model: "bert", Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{ "general.architecture": "bert", "bert.pooling_type": uint32(0), @@ -243,7 +245,7 @@ func TestGenerateChat(t *testing.T) { } if actual.TotalDuration == 0 { - t.Errorf("expected load duration > 0, got 0") + t.Errorf("expected total duration > 0, got 0") } } @@ -379,7 +381,7 @@ func TestGenerate(t *testing.T) { go s.sched.Run(context.TODO()) w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ - Name: "test", + Model: "test", Modelfile: fmt.Sprintf(`FROM %s TEMPLATE """ {{- if .System }}System: {{ .System }} {{ end }} @@ -437,9 +439,9 @@ func TestGenerate(t *testing.T) { } }) - t.Run("missing capabilities", func(t *testing.T) { + t.Run("missing capabilities generate", func(t *testing.T) { w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ - Name: "bert", + Model: "bert", Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{ "general.architecture": "bert", "bert.pooling_type": uint32(0), @@ -464,6 +466,22 @@ func TestGenerate(t *testing.T) { } }) + t.Run("missing capabilities suffix", func(t *testing.T) { + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test", + Prompt: "def add(", + Suffix: " return c", + }) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } + + if diff := cmp.Diff(w.Body.String(), `{"error":"test does not support insert"}`); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + t.Run("load model", func(t *testing.T) { w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ Model: "test", @@ -540,7 +558,7 @@ func TestGenerate(t *testing.T) { } if actual.TotalDuration == 0 { - t.Errorf("expected load duration > 0, got 0") + t.Errorf("expected total duration > 0, got 0") } } @@ -632,6 +650,49 @@ func TestGenerate(t *testing.T) { checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!") }) + w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Model: "test-suffix", + Modelfile: `FROM test +TEMPLATE """{{- if .Suffix }}
 {{ .Prompt }} {{ .Suffix }} 
+{{- else }}{{ .Prompt }}
+{{- end }}"""`,
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status 200, got %d", w.Code)
+	}
+
+	t.Run("prompt with suffix", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test-suffix",
+			Prompt: "def add(",
+			Suffix: "    return c",
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "
 def add(     return c "); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("prompt without suffix", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test-suffix",
+			Prompt: "def add(",
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
 	t.Run("raw", func(t *testing.T) {
 		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
 			Model:  "test-system",
diff --git a/template/template.go b/template/template.go
index 7cdb30ef..5330c0fa 100644
--- a/template/template.go
+++ b/template/template.go
@@ -151,6 +151,8 @@ func (t *Template) Vars() []string {
 type Values struct {
 	Messages []api.Message
 	Tools    []api.Tool
+	Prompt   string
+	Suffix   string
 
 	// forceLegacy is a flag used to test compatibility with legacy templates
 	forceLegacy bool
@@ -204,7 +206,13 @@ func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
 
 func (t *Template) Execute(w io.Writer, v Values) error {
 	system, messages := collate(v.Messages)
-	if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
+	if v.Prompt != "" && v.Suffix != "" {
+		return t.Template.Execute(w, map[string]any{
+			"Prompt":   v.Prompt,
+			"Suffix":   v.Suffix,
+			"Response": "",
+		})
+	} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
 		return t.Template.Execute(w, map[string]any{
 			"System":   system,
 			"Messages": messages,
diff --git a/template/template_test.go b/template/template_test.go
index c678f1b1..ae0db80b 100644
--- a/template/template_test.go
+++ b/template/template_test.go
@@ -359,3 +359,38 @@ Answer: `,
 		})
 	}
 }
+
+func TestExecuteWithSuffix(t *testing.T) {
+	tmpl, err := Parse(`{{- if .Suffix }}
 {{ .Prompt }} {{ .Suffix }} 
+{{- else }}{{ .Prompt }}
+{{- end }}`)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	cases := []struct {
+		name   string
+		values Values
+		expect string
+	}{
+		{
+			"message", Values{Messages: []api.Message{{Role: "user", Content: "hello"}}}, "hello",
+		},
+		{
+			"prompt suffix", Values{Prompt: "def add(", Suffix: "return x"}, "
 def add( return x ",
+		},
+	}
+
+	for _, tt := range cases {
+		t.Run(tt.name, func(t *testing.T) {
+			var b bytes.Buffer
+			if err := tmpl.Execute(&b, tt.values); err != nil {
+				t.Fatal(err)
+			}
+
+			if diff := cmp.Diff(b.String(), tt.expect); diff != "" {
+				t.Errorf("mismatch (-got +want):\n%s", diff)
+			}
+		})
+	}
+}