From 85a57006d1fe8655151814ef29d44907925a0a3d Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Mon, 13 May 2024 15:27:51 -0700 Subject: [PATCH] check if name exists before create/pull/copy --- server/routes.go | 42 ++++++++++++--- server/routes_test.go | 123 ++++++++++++++++++++++++++++++++++-------- 2 files changed, 135 insertions(+), 30 deletions(-) diff --git a/server/routes.go b/server/routes.go index 14853feb..c2907644 100644 --- a/server/routes.go +++ b/server/routes.go @@ -420,13 +420,14 @@ func (s *Server) PullModelHandler(c *gin.Context) { return } - var model string - if req.Model != "" { - model = req.Model - } else if req.Name != "" { - model = req.Name - } else { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"}) + name := model.ParseName(cmp.Or(req.Model, req.Name)) + if !name.IsValid() { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"}) + return + } + + if err := checkNameExists(name); err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } @@ -444,7 +445,7 @@ func (s *Server) PullModelHandler(c *gin.Context) { ctx, cancel := context.WithCancel(c.Request.Context()) defer cancel() - if err := PullModel(ctx, model, regOpts, fn); err != nil { + if err := PullModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil { ch <- gin.H{"error": err.Error()} } }() @@ -506,6 +507,21 @@ func (s *Server) PushModelHandler(c *gin.Context) { streamResponse(c, ch) } +func checkNameExists(name model.Name) error { + names, err := Manifests() + if err != nil { + return err + } + + for n := range names { + if strings.EqualFold(n.Filepath(), name.Filepath()) && n != name { + return fmt.Errorf("a model with that name already exists") + } + } + + return nil +} + func (s *Server) CreateModelHandler(c *gin.Context) { var req api.CreateRequest if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { @@ -522,6 +538,11 @@ func (s *Server) CreateModelHandler(c *gin.Context) { return } + if err := checkNameExists(name); err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if req.Path == "" && req.Modelfile == "" { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"}) return @@ -770,6 +791,11 @@ func (s *Server) CopyModelHandler(c *gin.Context) { return } + if err := checkNameExists(dst); err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if err := CopyModel(src, dst); errors.Is(err, os.ErrNotExist) { c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)}) } else if err != nil { diff --git a/server/routes_test.go b/server/routes_test.go index e144c957..100db3a6 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -21,6 +21,28 @@ import ( "github.com/ollama/ollama/version" ) +func createTestFile(t *testing.T, name string) string { + t.Helper() + + f, err := os.CreateTemp(t.TempDir(), name) + assert.Nil(t, err) + defer f.Close() + + err = binary.Write(f, binary.LittleEndian, []byte("GGUF")) + assert.Nil(t, err) + + err = binary.Write(f, binary.LittleEndian, uint32(3)) + assert.Nil(t, err) + + err = binary.Write(f, binary.LittleEndian, uint64(0)) + assert.Nil(t, err) + + err = binary.Write(f, binary.LittleEndian, uint64(0)) + assert.Nil(t, err) + + return f.Name() +} + func Test_Routes(t *testing.T) { type testCase struct { Name string @@ -30,28 +52,6 @@ func Test_Routes(t *testing.T) { Expected func(t *testing.T, resp *http.Response) } - createTestFile := func(t *testing.T, name string) string { - t.Helper() - - f, err := os.CreateTemp(t.TempDir(), name) - assert.Nil(t, err) - defer f.Close() - - err = binary.Write(f, binary.LittleEndian, []byte("GGUF")) - assert.Nil(t, err) - - err = binary.Write(f, binary.LittleEndian, uint32(3)) - assert.Nil(t, err) - - err = binary.Write(f, binary.LittleEndian, uint64(0)) - assert.Nil(t, err) - - err = binary.Write(f, binary.LittleEndian, uint64(0)) - assert.Nil(t, err) - - return f.Name() - } - createTestModel := func(t *testing.T, name string) { fname := createTestFile(t, "ollama-model") @@ -237,3 +237,82 @@ func Test_Routes(t *testing.T) { }) } } + +func TestCase(t *testing.T) { + t.Setenv("OLLAMA_MODELS", t.TempDir()) + + cases := []string{ + "mistral", + "llama3:latest", + "library/phi3:q4_0", + "registry.ollama.ai/library/gemma:q5_K_M", + // TODO: host:port currently fails on windows (#4107) + // "localhost:5000/alice/bob:latest", + } + + var s Server + for _, tt := range cases { + t.Run(tt, func(t *testing.T) { + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: tt, + Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200 got %d", w.Code) + } + + expect, err := json.Marshal(map[string]string{"error": "a model with that name already exists"}) + if err != nil { + t.Fatal(err) + } + + t.Run("create", func(t *testing.T) { + w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: strings.ToUpper(tt), + Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)), + Stream: &stream, + }) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status 500 got %d", w.Code) + } + + if !bytes.Equal(w.Body.Bytes(), expect) { + t.Fatalf("expected error %s got %s", expect, w.Body.String()) + } + }) + + t.Run("pull", func(t *testing.T) { + w := createRequest(t, s.PullModelHandler, api.PullRequest{ + Name: strings.ToUpper(tt), + Stream: &stream, + }) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status 500 got %d", w.Code) + } + + if !bytes.Equal(w.Body.Bytes(), expect) { + t.Fatalf("expected error %s got %s", expect, w.Body.String()) + } + }) + + t.Run("copy", func(t *testing.T) { + w := createRequest(t, s.CopyModelHandler, api.CopyRequest{ + Source: tt, + Destination: strings.ToUpper(tt), + }) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status 500 got %d", w.Code) + } + + if !bytes.Equal(w.Body.Bytes(), expect) { + t.Fatalf("expected error %s got %s", expect, w.Body.String()) + } + }) + }) + } +}