server: allow mixed-case model names on push, pull, cp, and create (#7676)

This change allows for mixed-case model names to be pushed, pulled,
copied, and created, which was previously disallowed because the Ollama
registry was backed by a Docker registry that enforced a naming
convention that disallowed mixed-case names, which is no longer the
case.

This does not break existing, intended, behaviors.

Also, make TestCase test a story of creating, updating, pulling, and
copying a model with case variations, ensuring the model's manifest is
updated correctly, and not duplicated across different files with
different case variations.
This commit is contained in:
Blake Mizerany 2024-11-19 15:05:57 -08:00 committed by GitHub
parent e66c29261a
commit 4b8a2e341a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 169 additions and 84 deletions

View file

@ -13,6 +13,7 @@ import (
"io" "io"
"log" "log"
"log/slog" "log/slog"
"net"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
@ -1071,6 +1072,21 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
return nil, errUnauthorized return nil, errUnauthorized
} }
// testMakeRequestDialContext specifies the dial function for the http client in
// makeRequest. It can be used to resolve hosts in model names to local
// addresses for testing. For example, the model name ("example.com/my/model")
// can be directed to push/pull from "127.0.0.1:1234".
//
// This is not safe to set across goroutines. It should be set in
// the main test goroutine, and not by tests marked to run in parallel with
// t.Parallel().
//
// It should be cleared after use, otherwise it will affect other tests.
//
// Ideally we would have some set this up the stack, but the code is not
// structured in a way that makes this easy, so this will have to do for now.
var testMakeRequestDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *registryOptions) (*http.Response, error) { func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *registryOptions) (*http.Response, error) {
if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure { if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure {
requestURL.Scheme = "http" requestURL.Scheme = "http"
@ -1105,6 +1121,9 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header
} }
resp, err := (&http.Client{ resp, err := (&http.Client{
Transport: &http.Transport{
DialContext: testMakeRequestDialContext,
},
CheckRedirect: regOpts.CheckRedirect, CheckRedirect: regOpts.CheckRedirect,
}).Do(req) }).Do(req)
if err != nil { if err != nil {

View file

@ -540,7 +540,8 @@ func (s *Server) PullHandler(c *gin.Context) {
return return
} }
if err := checkNameExists(name); err != nil { name, err = getExistingName(name)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
@ -621,19 +622,20 @@ func (s *Server) PushHandler(c *gin.Context) {
streamResponse(c, ch) streamResponse(c, ch)
} }
func checkNameExists(name model.Name) error { // getExistingName returns the original, on disk name if the input name is a
names, err := Manifests(true) // case-insensitive match, otherwise it returns the input name.
func getExistingName(n model.Name) (model.Name, error) {
var zero model.Name
existing, err := Manifests(true)
if err != nil { if err != nil {
return err return zero, err
} }
for e := range existing {
for n := range names { if n.EqualFold(e) {
if strings.EqualFold(n.Filepath(), name.Filepath()) && n != name { return e, nil
return errors.New("a model with that name already exists")
} }
} }
return n, nil
return nil
} }
func (s *Server) CreateHandler(c *gin.Context) { func (s *Server) CreateHandler(c *gin.Context) {
@ -652,7 +654,8 @@ func (s *Server) CreateHandler(c *gin.Context) {
return return
} }
if err := checkNameExists(name); err != nil { name, err := getExistingName(name)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
@ -958,14 +961,19 @@ func (s *Server) CopyHandler(c *gin.Context) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("source %q is invalid", r.Source)}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("source %q is invalid", r.Source)})
return return
} }
src, err := getExistingName(src)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
dst := model.ParseName(r.Destination) dst := model.ParseName(r.Destination)
if !dst.IsValid() { if !dst.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("destination %q is invalid", r.Destination)}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("destination %q is invalid", r.Destination)})
return return
} }
dst, err = getExistingName(dst)
if err := checkNameExists(dst); err != nil { if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }

View file

@ -7,13 +7,18 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"io/fs"
"math" "math"
"math/rand/v2"
"net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
"path/filepath"
"sort" "sort"
"strings" "strings"
"testing" "testing"
"unicode"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
@ -473,83 +478,129 @@ func Test_Routes(t *testing.T) {
} }
} }
func TestCase(t *testing.T) { func casingShuffle(s string) string {
rr := []rune(s)
for i := range rr {
if rand.N(2) == 0 {
rr[i] = unicode.ToUpper(rr[i])
} else {
rr[i] = unicode.ToLower(rr[i])
}
}
return string(rr)
}
func TestManifestCaseSensitivity(t *testing.T) {
t.Setenv("OLLAMA_MODELS", t.TempDir()) t.Setenv("OLLAMA_MODELS", t.TempDir())
cases := []string{ r := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
"mistral", w.WriteHeader(http.StatusOK)
"llama3:latest", io.WriteString(w, `{}`) //nolint:errcheck
"library/phi3:q4_0", }))
"registry.ollama.ai/library/gemma:q5_K_M", defer r.Close()
// TODO: host:port currently fails on windows (#4107)
// "localhost:5000/alice/bob:latest", nameUsed := make(map[string]bool)
name := func() string {
const fqmn = "example/namespace/model:tag"
for {
v := casingShuffle(fqmn)
if nameUsed[v] {
continue
}
nameUsed[v] = true
return v
}
}
wantStableName := name()
// checkManifestList tests that there is strictly one manifest in the
// models directory, and that the manifest is for the model under test.
checkManifestList := func() {
t.Helper()
mandir := filepath.Join(os.Getenv("OLLAMA_MODELS"), "manifests/")
var entries []string
t.Logf("dir entries:")
fsys := os.DirFS(mandir)
err := fs.WalkDir(fsys, ".", func(path string, info fs.DirEntry, err error) error {
if err != nil {
return err
}
t.Logf(" %s", fs.FormatDirEntry(info))
if info.IsDir() {
return nil
}
path = strings.TrimPrefix(path, mandir)
entries = append(entries, path)
return nil
})
if err != nil {
t.Fatalf("failed to walk directory: %v", err)
}
if len(entries) != 1 {
t.Errorf("len(got) = %d, want 1", len(entries))
return // do not use Fatal so following steps run
}
g := entries[0] // raw path
g = filepath.ToSlash(g)
w := model.ParseName(wantStableName).Filepath()
w = filepath.ToSlash(w)
if g != w {
t.Errorf("\ngot: %s\nwant: %s", g, w)
}
}
checkOK := func(w *httptest.ResponseRecorder) {
t.Helper()
if w.Code != http.StatusOK {
t.Errorf("code = %d, want 200", w.Code)
t.Logf("body: %s", w.Body.String())
}
} }
var s Server var s Server
for _, tt := range cases { testMakeRequestDialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
t.Run(tt, func(t *testing.T) { var d net.Dialer
w := createRequest(t, s.CreateHandler, api.CreateRequest{ return d.DialContext(ctx, "tcp", r.Listener.Addr().String())
Name: tt,
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200 got %d", w.Code)
}
expect, err := json.Marshal(map[string]string{"error": "a model with that name already exists"})
if err != nil {
t.Fatal(err)
}
t.Run("create", func(t *testing.T) {
w = createRequest(t, s.CreateHandler, api.CreateRequest{
Name: strings.ToUpper(tt),
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status 500 got %d", w.Code)
}
if !bytes.Equal(w.Body.Bytes(), expect) {
t.Fatalf("expected error %s got %s", expect, w.Body.String())
}
})
t.Run("pull", func(t *testing.T) {
w := createRequest(t, s.PullHandler, api.PullRequest{
Name: strings.ToUpper(tt),
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status 500 got %d", w.Code)
}
if !bytes.Equal(w.Body.Bytes(), expect) {
t.Fatalf("expected error %s got %s", expect, w.Body.String())
}
})
t.Run("copy", func(t *testing.T) {
w := createRequest(t, s.CopyHandler, api.CopyRequest{
Source: tt,
Destination: strings.ToUpper(tt),
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status 500 got %d", w.Code)
}
if !bytes.Equal(w.Body.Bytes(), expect) {
t.Fatalf("expected error %s got %s", expect, w.Body.String())
}
})
})
} }
t.Cleanup(func() { testMakeRequestDialContext = nil })
t.Logf("creating")
checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{
// Start with the stable name, and later use a case-shuffled
// version.
Name: wantStableName,
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
Stream: &stream,
}))
checkManifestList()
t.Logf("creating (again)")
checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{
Name: name(),
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
Stream: &stream,
}))
checkManifestList()
t.Logf("pulling")
checkOK(createRequest(t, s.PullHandler, api.PullRequest{
Name: name(),
Stream: &stream,
Insecure: true,
}))
checkManifestList()
t.Logf("copying")
checkOK(createRequest(t, s.CopyHandler, api.CopyRequest{
Source: name(),
Destination: name(),
}))
checkManifestList()
} }
func TestShow(t *testing.T) { func TestShow(t *testing.T) {

View file

@ -298,6 +298,13 @@ func (n Name) LogValue() slog.Value {
return slog.StringValue(n.String()) return slog.StringValue(n.String())
} }
func (n Name) EqualFold(o Name) bool {
return strings.EqualFold(n.Host, o.Host) &&
strings.EqualFold(n.Namespace, o.Namespace) &&
strings.EqualFold(n.Model, o.Model) &&
strings.EqualFold(n.Tag, o.Tag)
}
func isValidLen(kind partKind, s string) bool { func isValidLen(kind partKind, s string) bool {
switch kind { switch kind {
case kindHost: case kindHost: