server: allow mixed-case model names on push, pull, cp, and create (#7676)
This change allows for mixed-case model names to be pushed, pulled, copied, and created, which was previously disallowed because the Ollama registry was backed by a Docker registry that enforced a naming convention that disallowed mixed-case names, which is no longer the case. This does not break existing, intended, behaviors. Also, make TestCase test a story of creating, updating, pulling, and copying a model with case variations, ensuring the model's manifest is updated correctly, and not duplicated across different files with different case variations.
This commit is contained in:
parent
e66c29261a
commit
4b8a2e341a
4 changed files with 169 additions and 84 deletions
|
@ -13,6 +13,7 @@ import (
|
|||
"io"
|
||||
"log"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
|
@ -1071,6 +1072,21 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
|||
return nil, errUnauthorized
|
||||
}
|
||||
|
||||
// testMakeRequestDialContext specifies the dial function for the http client in
|
||||
// makeRequest. It can be used to resolve hosts in model names to local
|
||||
// addresses for testing. For example, the model name ("example.com/my/model")
|
||||
// can be directed to push/pull from "127.0.0.1:1234".
|
||||
//
|
||||
// This is not safe to set across goroutines. It should be set in
|
||||
// the main test goroutine, and not by tests marked to run in parallel with
|
||||
// t.Parallel().
|
||||
//
|
||||
// It should be cleared after use, otherwise it will affect other tests.
|
||||
//
|
||||
// Ideally we would have some set this up the stack, but the code is not
|
||||
// structured in a way that makes this easy, so this will have to do for now.
|
||||
var testMakeRequestDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
|
||||
func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *registryOptions) (*http.Response, error) {
|
||||
if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure {
|
||||
requestURL.Scheme = "http"
|
||||
|
@ -1105,6 +1121,9 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header
|
|||
}
|
||||
|
||||
resp, err := (&http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: testMakeRequestDialContext,
|
||||
},
|
||||
CheckRedirect: regOpts.CheckRedirect,
|
||||
}).Do(req)
|
||||
if err != nil {
|
||||
|
|
|
@ -540,7 +540,8 @@ func (s *Server) PullHandler(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
if err := checkNameExists(name); err != nil {
|
||||
name, err = getExistingName(name)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
@ -621,19 +622,20 @@ func (s *Server) PushHandler(c *gin.Context) {
|
|||
streamResponse(c, ch)
|
||||
}
|
||||
|
||||
func checkNameExists(name model.Name) error {
|
||||
names, err := Manifests(true)
|
||||
// getExistingName returns the original, on disk name if the input name is a
|
||||
// case-insensitive match, otherwise it returns the input name.
|
||||
func getExistingName(n model.Name) (model.Name, error) {
|
||||
var zero model.Name
|
||||
existing, err := Manifests(true)
|
||||
if err != nil {
|
||||
return err
|
||||
return zero, err
|
||||
}
|
||||
|
||||
for n := range names {
|
||||
if strings.EqualFold(n.Filepath(), name.Filepath()) && n != name {
|
||||
return errors.New("a model with that name already exists")
|
||||
for e := range existing {
|
||||
if n.EqualFold(e) {
|
||||
return e, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (s *Server) CreateHandler(c *gin.Context) {
|
||||
|
@ -652,7 +654,8 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
if err := checkNameExists(name); err != nil {
|
||||
name, err := getExistingName(name)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
@ -958,14 +961,19 @@ func (s *Server) CopyHandler(c *gin.Context) {
|
|||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("source %q is invalid", r.Source)})
|
||||
return
|
||||
}
|
||||
src, err := getExistingName(src)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
dst := model.ParseName(r.Destination)
|
||||
if !dst.IsValid() {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("destination %q is invalid", r.Destination)})
|
||||
return
|
||||
}
|
||||
|
||||
if err := checkNameExists(dst); err != nil {
|
||||
dst, err = getExistingName(dst)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
|
|
@ -7,13 +7,18 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llm"
|
||||
|
@ -473,83 +478,129 @@ func Test_Routes(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCase(t *testing.T) {
|
||||
func casingShuffle(s string) string {
|
||||
rr := []rune(s)
|
||||
for i := range rr {
|
||||
if rand.N(2) == 0 {
|
||||
rr[i] = unicode.ToUpper(rr[i])
|
||||
} else {
|
||||
rr[i] = unicode.ToLower(rr[i])
|
||||
}
|
||||
}
|
||||
return string(rr)
|
||||
}
|
||||
|
||||
func TestManifestCaseSensitivity(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||
|
||||
cases := []string{
|
||||
"mistral",
|
||||
"llama3:latest",
|
||||
"library/phi3:q4_0",
|
||||
"registry.ollama.ai/library/gemma:q5_K_M",
|
||||
// TODO: host:port currently fails on windows (#4107)
|
||||
// "localhost:5000/alice/bob:latest",
|
||||
r := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
io.WriteString(w, `{}`) //nolint:errcheck
|
||||
}))
|
||||
defer r.Close()
|
||||
|
||||
nameUsed := make(map[string]bool)
|
||||
name := func() string {
|
||||
const fqmn = "example/namespace/model:tag"
|
||||
for {
|
||||
v := casingShuffle(fqmn)
|
||||
if nameUsed[v] {
|
||||
continue
|
||||
}
|
||||
nameUsed[v] = true
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
wantStableName := name()
|
||||
|
||||
// checkManifestList tests that there is strictly one manifest in the
|
||||
// models directory, and that the manifest is for the model under test.
|
||||
checkManifestList := func() {
|
||||
t.Helper()
|
||||
|
||||
mandir := filepath.Join(os.Getenv("OLLAMA_MODELS"), "manifests/")
|
||||
var entries []string
|
||||
t.Logf("dir entries:")
|
||||
fsys := os.DirFS(mandir)
|
||||
err := fs.WalkDir(fsys, ".", func(path string, info fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.Logf(" %s", fs.FormatDirEntry(info))
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
path = strings.TrimPrefix(path, mandir)
|
||||
entries = append(entries, path)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to walk directory: %v", err)
|
||||
}
|
||||
|
||||
if len(entries) != 1 {
|
||||
t.Errorf("len(got) = %d, want 1", len(entries))
|
||||
return // do not use Fatal so following steps run
|
||||
}
|
||||
|
||||
g := entries[0] // raw path
|
||||
g = filepath.ToSlash(g)
|
||||
w := model.ParseName(wantStableName).Filepath()
|
||||
w = filepath.ToSlash(w)
|
||||
if g != w {
|
||||
t.Errorf("\ngot: %s\nwant: %s", g, w)
|
||||
}
|
||||
}
|
||||
|
||||
checkOK := func(w *httptest.ResponseRecorder) {
|
||||
t.Helper()
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("code = %d, want 200", w.Code)
|
||||
t.Logf("body: %s", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
var s Server
|
||||
for _, tt := range cases {
|
||||
t.Run(tt, func(t *testing.T) {
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Name: tt,
|
||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200 got %d", w.Code)
|
||||
}
|
||||
|
||||
expect, err := json.Marshal(map[string]string{"error": "a model with that name already exists"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("create", func(t *testing.T) {
|
||||
w = createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Name: strings.ToUpper(tt),
|
||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status 500 got %d", w.Code)
|
||||
}
|
||||
|
||||
if !bytes.Equal(w.Body.Bytes(), expect) {
|
||||
t.Fatalf("expected error %s got %s", expect, w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("pull", func(t *testing.T) {
|
||||
w := createRequest(t, s.PullHandler, api.PullRequest{
|
||||
Name: strings.ToUpper(tt),
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status 500 got %d", w.Code)
|
||||
}
|
||||
|
||||
if !bytes.Equal(w.Body.Bytes(), expect) {
|
||||
t.Fatalf("expected error %s got %s", expect, w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("copy", func(t *testing.T) {
|
||||
w := createRequest(t, s.CopyHandler, api.CopyRequest{
|
||||
Source: tt,
|
||||
Destination: strings.ToUpper(tt),
|
||||
})
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status 500 got %d", w.Code)
|
||||
}
|
||||
|
||||
if !bytes.Equal(w.Body.Bytes(), expect) {
|
||||
t.Fatalf("expected error %s got %s", expect, w.Body.String())
|
||||
}
|
||||
})
|
||||
})
|
||||
testMakeRequestDialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
return d.DialContext(ctx, "tcp", r.Listener.Addr().String())
|
||||
}
|
||||
t.Cleanup(func() { testMakeRequestDialContext = nil })
|
||||
|
||||
t.Logf("creating")
|
||||
checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
// Start with the stable name, and later use a case-shuffled
|
||||
// version.
|
||||
Name: wantStableName,
|
||||
|
||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
|
||||
Stream: &stream,
|
||||
}))
|
||||
checkManifestList()
|
||||
|
||||
t.Logf("creating (again)")
|
||||
checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Name: name(),
|
||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
|
||||
Stream: &stream,
|
||||
}))
|
||||
checkManifestList()
|
||||
|
||||
t.Logf("pulling")
|
||||
checkOK(createRequest(t, s.PullHandler, api.PullRequest{
|
||||
Name: name(),
|
||||
Stream: &stream,
|
||||
Insecure: true,
|
||||
}))
|
||||
checkManifestList()
|
||||
|
||||
t.Logf("copying")
|
||||
checkOK(createRequest(t, s.CopyHandler, api.CopyRequest{
|
||||
Source: name(),
|
||||
Destination: name(),
|
||||
}))
|
||||
checkManifestList()
|
||||
}
|
||||
|
||||
func TestShow(t *testing.T) {
|
||||
|
|
|
@ -298,6 +298,13 @@ func (n Name) LogValue() slog.Value {
|
|||
return slog.StringValue(n.String())
|
||||
}
|
||||
|
||||
func (n Name) EqualFold(o Name) bool {
|
||||
return strings.EqualFold(n.Host, o.Host) &&
|
||||
strings.EqualFold(n.Namespace, o.Namespace) &&
|
||||
strings.EqualFold(n.Model, o.Model) &&
|
||||
strings.EqualFold(n.Tag, o.Tag)
|
||||
}
|
||||
|
||||
func isValidLen(kind partKind, s string) bool {
|
||||
switch kind {
|
||||
case kindHost:
|
||||
|
|
Loading…
Add table
Reference in a new issue