From e8b954c646544d40d84be50aae9cd909fcbd8f41 Mon Sep 17 00:00:00 2001 From: Josh <76125168+joshyan1@users.noreply.github.com> Date: Fri, 19 Jul 2024 15:24:29 -0700 Subject: [PATCH] server: validate template (#5734) add template validation to modelfile --- server/images.go | 6 ++++++ server/routes.go | 14 +++++++++++--- server/routes_create_test.go | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 3 deletions(-) diff --git a/server/images.go b/server/images.go index 5e4e8858..574dec19 100644 --- a/server/images.go +++ b/server/images.go @@ -492,6 +492,12 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio layers = append(layers, baseLayer.Layer) } case "license", "template", "system": + if c.Name == "template" { + if _, err := template.Parse(c.Args); err != nil { + return fmt.Errorf("%w: %s", errBadTemplate, err) + } + } + if c.Name != "license" { // replace layers = slices.DeleteFunc(layers, func(layer *Layer) bool { diff --git a/server/routes.go b/server/routes.go index c33b7195..85db7924 100644 --- a/server/routes.go +++ b/server/routes.go @@ -56,6 +56,7 @@ func init() { } var errRequired = errors.New("is required") +var errBadTemplate = errors.New("template error") func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) { opts := api.DefaultOptions() @@ -609,8 +610,11 @@ func (s *Server) CreateModelHandler(c *gin.Context) { quantization := cmp.Or(r.Quantize, r.Quantization) if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil { + if errors.Is(err, errBadTemplate) { + ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest} + } ch <- gin.H{"error": err.Error()} - } + } }() if r.Stream != nil && !*r.Stream { @@ -1196,11 +1200,15 @@ func waitForStream(c *gin.Context, ch chan interface{}) { return } case gin.H: + status, ok := r["status"].(int) + if !ok { + status = http.StatusInternalServerError + } if errorMsg, ok := r["error"].(string); ok { - c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg}) + c.JSON(status, gin.H{"error": errorMsg}) return } else { - c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in progress response"}) + c.JSON(status, gin.H{"error": "unexpected error format in progress response"}) return } default: diff --git a/server/routes_create_test.go b/server/routes_create_test.go index cb548ebd..3234ea5e 100644 --- a/server/routes_create_test.go +++ b/server/routes_create_test.go @@ -491,6 +491,42 @@ func TestCreateTemplateSystem(t *testing.T) { if string(system) != "Say bye!" { t.Errorf("expected \"Say bye!\", actual %s", system) } + + t.Run("incomplete template", func(t *testing.T) { + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt", createBinFile(t, nil, nil)), + Stream: &stream, + }) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status code 400, actual %d", w.Code) + } + }) + + t.Run("template with unclosed if", func(t *testing.T) { + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ if .Prompt }}", createBinFile(t, nil, nil)), + Stream: &stream, + }) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status code 400, actual %d", w.Code) + } + }) + + t.Run("template with undefined function", func(t *testing.T) { + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ Prompt }}", createBinFile(t, nil, nil)), + Stream: &stream, + }) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status code 400, actual %d", w.Code) + } + }) } func TestCreateLicenses(t *testing.T) {