Merge pull request #4413 from ollama/mxyng/name-check
check if name exists before create/pull/copy
This commit is contained in:
commit
96bc232b43
2 changed files with 135 additions and 30 deletions
|
@ -421,13 +421,14 @@ func (s *Server) PullModelHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var model string
|
name := model.ParseName(cmp.Or(req.Model, req.Name))
|
||||||
if req.Model != "" {
|
if !name.IsValid() {
|
||||||
model = req.Model
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
|
||||||
} else if req.Name != "" {
|
return
|
||||||
model = req.Name
|
}
|
||||||
} else {
|
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
if err := checkNameExists(name); err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -445,7 +446,7 @@ func (s *Server) PullModelHandler(c *gin.Context) {
|
||||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
if err := PullModel(ctx, model, regOpts, fn); err != nil {
|
if err := PullModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- gin.H{"error": err.Error()}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -507,6 +508,21 @@ func (s *Server) PushModelHandler(c *gin.Context) {
|
||||||
streamResponse(c, ch)
|
streamResponse(c, ch)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func checkNameExists(name model.Name) error {
|
||||||
|
names, err := Manifests()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for n := range names {
|
||||||
|
if strings.EqualFold(n.Filepath(), name.Filepath()) && n != name {
|
||||||
|
return fmt.Errorf("a model with that name already exists")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) CreateModelHandler(c *gin.Context) {
|
func (s *Server) CreateModelHandler(c *gin.Context) {
|
||||||
var req api.CreateRequest
|
var req api.CreateRequest
|
||||||
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||||
|
@ -523,6 +539,11 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := checkNameExists(name); err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if req.Path == "" && req.Modelfile == "" {
|
if req.Path == "" && req.Modelfile == "" {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
|
||||||
return
|
return
|
||||||
|
@ -771,6 +792,11 @@ func (s *Server) CopyModelHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := checkNameExists(dst); err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if err := CopyModel(src, dst); errors.Is(err, os.ErrNotExist) {
|
if err := CopyModel(src, dst); errors.Is(err, os.ErrNotExist) {
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)})
|
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)})
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
|
|
|
@ -21,16 +21,7 @@ import (
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_Routes(t *testing.T) {
|
func createTestFile(t *testing.T, name string) string {
|
||||||
type testCase struct {
|
|
||||||
Name string
|
|
||||||
Method string
|
|
||||||
Path string
|
|
||||||
Setup func(t *testing.T, req *http.Request)
|
|
||||||
Expected func(t *testing.T, resp *http.Response)
|
|
||||||
}
|
|
||||||
|
|
||||||
createTestFile := func(t *testing.T, name string) string {
|
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
f, err := os.CreateTemp(t.TempDir(), name)
|
f, err := os.CreateTemp(t.TempDir(), name)
|
||||||
|
@ -50,6 +41,15 @@ func Test_Routes(t *testing.T) {
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
return f.Name()
|
return f.Name()
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Routes(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
Name string
|
||||||
|
Method string
|
||||||
|
Path string
|
||||||
|
Setup func(t *testing.T, req *http.Request)
|
||||||
|
Expected func(t *testing.T, resp *http.Response)
|
||||||
}
|
}
|
||||||
|
|
||||||
createTestModel := func(t *testing.T, name string) {
|
createTestModel := func(t *testing.T, name string) {
|
||||||
|
@ -237,3 +237,82 @@ func Test_Routes(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCase(t *testing.T) {
|
||||||
|
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||||
|
|
||||||
|
cases := []string{
|
||||||
|
"mistral",
|
||||||
|
"llama3:latest",
|
||||||
|
"library/phi3:q4_0",
|
||||||
|
"registry.ollama.ai/library/gemma:q5_K_M",
|
||||||
|
// TODO: host:port currently fails on windows (#4107)
|
||||||
|
// "localhost:5000/alice/bob:latest",
|
||||||
|
}
|
||||||
|
|
||||||
|
var s Server
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt, func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Name: tt,
|
||||||
|
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200 got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
expect, err := json.Marshal(map[string]string{"error": "a model with that name already exists"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("create", func(t *testing.T) {
|
||||||
|
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Name: strings.ToUpper(tt),
|
||||||
|
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected status 500 got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(w.Body.Bytes(), expect) {
|
||||||
|
t.Fatalf("expected error %s got %s", expect, w.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("pull", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.PullModelHandler, api.PullRequest{
|
||||||
|
Name: strings.ToUpper(tt),
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected status 500 got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(w.Body.Bytes(), expect) {
|
||||||
|
t.Fatalf("expected error %s got %s", expect, w.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("copy", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.CopyModelHandler, api.CopyRequest{
|
||||||
|
Source: tt,
|
||||||
|
Destination: strings.ToUpper(tt),
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected status 500 got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(w.Body.Bytes(), expect) {
|
||||||
|
t.Fatalf("expected error %s got %s", expect, w.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue