From 4b8a2e341a9b4e713180b483f42316665c5faea3 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Tue, 19 Nov 2024 15:05:57 -0800 Subject: [PATCH] server: allow mixed-case model names on push, pull, cp, and create (#7676) This change allows for mixed-case model names to be pushed, pulled, copied, and created, which was previously disallowed because the Ollama registry was backed by a Docker registry that enforced a naming convention that disallowed mixed-case names, which is no longer the case. This does not break existing, intended, behaviors. Also, make TestCase test a story of creating, updating, pulling, and copying a model with case variations, ensuring the model's manifest is updated correctly, and not duplicated across different files with different case variations. --- server/images.go | 19 +++++ server/routes.go | 34 +++++--- server/routes_test.go | 193 ++++++++++++++++++++++++++---------------- types/model/name.go | 7 ++ 4 files changed, 169 insertions(+), 84 deletions(-) diff --git a/server/images.go b/server/images.go index 584b7b13..6a0e8ae3 100644 --- a/server/images.go +++ b/server/images.go @@ -13,6 +13,7 @@ import ( "io" "log" "log/slog" + "net" "net/http" "net/url" "os" @@ -1071,6 +1072,21 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR return nil, errUnauthorized } +// testMakeRequestDialContext specifies the dial function for the http client in +// makeRequest. It can be used to resolve hosts in model names to local +// addresses for testing. For example, the model name ("example.com/my/model") +// can be directed to push/pull from "127.0.0.1:1234". +// +// This is not safe to set across goroutines. It should be set in +// the main test goroutine, and not by tests marked to run in parallel with +// t.Parallel(). +// +// It should be cleared after use, otherwise it will affect other tests. +// +// Ideally we would have some set this up the stack, but the code is not +// structured in a way that makes this easy, so this will have to do for now. +var testMakeRequestDialContext func(ctx context.Context, network, addr string) (net.Conn, error) + func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *registryOptions) (*http.Response, error) { if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure { requestURL.Scheme = "http" @@ -1105,6 +1121,9 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header } resp, err := (&http.Client{ + Transport: &http.Transport{ + DialContext: testMakeRequestDialContext, + }, CheckRedirect: regOpts.CheckRedirect, }).Do(req) if err != nil { diff --git a/server/routes.go b/server/routes.go index c5fd3293..f5b05bb5 100644 --- a/server/routes.go +++ b/server/routes.go @@ -540,7 +540,8 @@ func (s *Server) PullHandler(c *gin.Context) { return } - if err := checkNameExists(name); err != nil { + name, err = getExistingName(name) + if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } @@ -621,19 +622,20 @@ func (s *Server) PushHandler(c *gin.Context) { streamResponse(c, ch) } -func checkNameExists(name model.Name) error { - names, err := Manifests(true) +// getExistingName returns the original, on disk name if the input name is a +// case-insensitive match, otherwise it returns the input name. +func getExistingName(n model.Name) (model.Name, error) { + var zero model.Name + existing, err := Manifests(true) if err != nil { - return err + return zero, err } - - for n := range names { - if strings.EqualFold(n.Filepath(), name.Filepath()) && n != name { - return errors.New("a model with that name already exists") + for e := range existing { + if n.EqualFold(e) { + return e, nil } } - - return nil + return n, nil } func (s *Server) CreateHandler(c *gin.Context) { @@ -652,7 +654,8 @@ func (s *Server) CreateHandler(c *gin.Context) { return } - if err := checkNameExists(name); err != nil { + name, err := getExistingName(name) + if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } @@ -958,14 +961,19 @@ func (s *Server) CopyHandler(c *gin.Context) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("source %q is invalid", r.Source)}) return } + src, err := getExistingName(src) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } dst := model.ParseName(r.Destination) if !dst.IsValid() { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("destination %q is invalid", r.Destination)}) return } - - if err := checkNameExists(dst); err != nil { + dst, err = getExistingName(dst) + if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } diff --git a/server/routes_test.go b/server/routes_test.go index bd5b56af..1daf36f1 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -7,13 +7,18 @@ import ( "encoding/json" "fmt" "io" + "io/fs" "math" + "math/rand/v2" + "net" "net/http" "net/http/httptest" "os" + "path/filepath" "sort" "strings" "testing" + "unicode" "github.com/ollama/ollama/api" "github.com/ollama/ollama/llm" @@ -473,83 +478,129 @@ func Test_Routes(t *testing.T) { } } -func TestCase(t *testing.T) { +func casingShuffle(s string) string { + rr := []rune(s) + for i := range rr { + if rand.N(2) == 0 { + rr[i] = unicode.ToUpper(rr[i]) + } else { + rr[i] = unicode.ToLower(rr[i]) + } + } + return string(rr) +} + +func TestManifestCaseSensitivity(t *testing.T) { t.Setenv("OLLAMA_MODELS", t.TempDir()) - cases := []string{ - "mistral", - "llama3:latest", - "library/phi3:q4_0", - "registry.ollama.ai/library/gemma:q5_K_M", - // TODO: host:port currently fails on windows (#4107) - // "localhost:5000/alice/bob:latest", + r := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + io.WriteString(w, `{}`) //nolint:errcheck + })) + defer r.Close() + + nameUsed := make(map[string]bool) + name := func() string { + const fqmn = "example/namespace/model:tag" + for { + v := casingShuffle(fqmn) + if nameUsed[v] { + continue + } + nameUsed[v] = true + return v + } + } + + wantStableName := name() + + // checkManifestList tests that there is strictly one manifest in the + // models directory, and that the manifest is for the model under test. + checkManifestList := func() { + t.Helper() + + mandir := filepath.Join(os.Getenv("OLLAMA_MODELS"), "manifests/") + var entries []string + t.Logf("dir entries:") + fsys := os.DirFS(mandir) + err := fs.WalkDir(fsys, ".", func(path string, info fs.DirEntry, err error) error { + if err != nil { + return err + } + t.Logf(" %s", fs.FormatDirEntry(info)) + if info.IsDir() { + return nil + } + path = strings.TrimPrefix(path, mandir) + entries = append(entries, path) + return nil + }) + if err != nil { + t.Fatalf("failed to walk directory: %v", err) + } + + if len(entries) != 1 { + t.Errorf("len(got) = %d, want 1", len(entries)) + return // do not use Fatal so following steps run + } + + g := entries[0] // raw path + g = filepath.ToSlash(g) + w := model.ParseName(wantStableName).Filepath() + w = filepath.ToSlash(w) + if g != w { + t.Errorf("\ngot: %s\nwant: %s", g, w) + } + } + + checkOK := func(w *httptest.ResponseRecorder) { + t.Helper() + if w.Code != http.StatusOK { + t.Errorf("code = %d, want 200", w.Code) + t.Logf("body: %s", w.Body.String()) + } } var s Server - for _, tt := range cases { - t.Run(tt, func(t *testing.T) { - w := createRequest(t, s.CreateHandler, api.CreateRequest{ - Name: tt, - Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)), - Stream: &stream, - }) - - if w.Code != http.StatusOK { - t.Fatalf("expected status 200 got %d", w.Code) - } - - expect, err := json.Marshal(map[string]string{"error": "a model with that name already exists"}) - if err != nil { - t.Fatal(err) - } - - t.Run("create", func(t *testing.T) { - w = createRequest(t, s.CreateHandler, api.CreateRequest{ - Name: strings.ToUpper(tt), - Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)), - Stream: &stream, - }) - - if w.Code != http.StatusBadRequest { - t.Fatalf("expected status 500 got %d", w.Code) - } - - if !bytes.Equal(w.Body.Bytes(), expect) { - t.Fatalf("expected error %s got %s", expect, w.Body.String()) - } - }) - - t.Run("pull", func(t *testing.T) { - w := createRequest(t, s.PullHandler, api.PullRequest{ - Name: strings.ToUpper(tt), - Stream: &stream, - }) - - if w.Code != http.StatusBadRequest { - t.Fatalf("expected status 500 got %d", w.Code) - } - - if !bytes.Equal(w.Body.Bytes(), expect) { - t.Fatalf("expected error %s got %s", expect, w.Body.String()) - } - }) - - t.Run("copy", func(t *testing.T) { - w := createRequest(t, s.CopyHandler, api.CopyRequest{ - Source: tt, - Destination: strings.ToUpper(tt), - }) - - if w.Code != http.StatusBadRequest { - t.Fatalf("expected status 500 got %d", w.Code) - } - - if !bytes.Equal(w.Body.Bytes(), expect) { - t.Fatalf("expected error %s got %s", expect, w.Body.String()) - } - }) - }) + testMakeRequestDialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "tcp", r.Listener.Addr().String()) } + t.Cleanup(func() { testMakeRequestDialContext = nil }) + + t.Logf("creating") + checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{ + // Start with the stable name, and later use a case-shuffled + // version. + Name: wantStableName, + + Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)), + Stream: &stream, + })) + checkManifestList() + + t.Logf("creating (again)") + checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{ + Name: name(), + Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)), + Stream: &stream, + })) + checkManifestList() + + t.Logf("pulling") + checkOK(createRequest(t, s.PullHandler, api.PullRequest{ + Name: name(), + Stream: &stream, + Insecure: true, + })) + checkManifestList() + + t.Logf("copying") + checkOK(createRequest(t, s.CopyHandler, api.CopyRequest{ + Source: name(), + Destination: name(), + })) + checkManifestList() } func TestShow(t *testing.T) { diff --git a/types/model/name.go b/types/model/name.go index 75b35ef7..9d819f10 100644 --- a/types/model/name.go +++ b/types/model/name.go @@ -298,6 +298,13 @@ func (n Name) LogValue() slog.Value { return slog.StringValue(n.String()) } +func (n Name) EqualFold(o Name) bool { + return strings.EqualFold(n.Host, o.Host) && + strings.EqualFold(n.Namespace, o.Namespace) && + strings.EqualFold(n.Model, o.Model) && + strings.EqualFold(n.Tag, o.Tag) +} + func isValidLen(kind partKind, s string) bool { switch kind { case kindHost: