server: allow mixed-case model names on push, pull, cp, and create (#7676)
This change allows for mixed-case model names to be pushed, pulled, copied, and created, which was previously disallowed because the Ollama registry was backed by a Docker registry that enforced a naming convention that disallowed mixed-case names, which is no longer the case. This does not break existing, intended, behaviors. Also, make TestCase test a story of creating, updating, pulling, and copying a model with case variations, ensuring the model's manifest is updated correctly, and not duplicated across different files with different case variations.
This commit is contained in:
parent
e66c29261a
commit
4b8a2e341a
4 changed files with 169 additions and 84 deletions
|
@ -13,6 +13,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
@ -1071,6 +1072,21 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||||
return nil, errUnauthorized
|
return nil, errUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// testMakeRequestDialContext specifies the dial function for the http client in
|
||||||
|
// makeRequest. It can be used to resolve hosts in model names to local
|
||||||
|
// addresses for testing. For example, the model name ("example.com/my/model")
|
||||||
|
// can be directed to push/pull from "127.0.0.1:1234".
|
||||||
|
//
|
||||||
|
// This is not safe to set across goroutines. It should be set in
|
||||||
|
// the main test goroutine, and not by tests marked to run in parallel with
|
||||||
|
// t.Parallel().
|
||||||
|
//
|
||||||
|
// It should be cleared after use, otherwise it will affect other tests.
|
||||||
|
//
|
||||||
|
// Ideally we would have some set this up the stack, but the code is not
|
||||||
|
// structured in a way that makes this easy, so this will have to do for now.
|
||||||
|
var testMakeRequestDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||||
|
|
||||||
func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *registryOptions) (*http.Response, error) {
|
func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *registryOptions) (*http.Response, error) {
|
||||||
if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure {
|
if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure {
|
||||||
requestURL.Scheme = "http"
|
requestURL.Scheme = "http"
|
||||||
|
@ -1105,6 +1121,9 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := (&http.Client{
|
resp, err := (&http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
DialContext: testMakeRequestDialContext,
|
||||||
|
},
|
||||||
CheckRedirect: regOpts.CheckRedirect,
|
CheckRedirect: regOpts.CheckRedirect,
|
||||||
}).Do(req)
|
}).Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -540,7 +540,8 @@ func (s *Server) PullHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := checkNameExists(name); err != nil {
|
name, err = getExistingName(name)
|
||||||
|
if err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -621,19 +622,20 @@ func (s *Server) PushHandler(c *gin.Context) {
|
||||||
streamResponse(c, ch)
|
streamResponse(c, ch)
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkNameExists(name model.Name) error {
|
// getExistingName returns the original, on disk name if the input name is a
|
||||||
names, err := Manifests(true)
|
// case-insensitive match, otherwise it returns the input name.
|
||||||
|
func getExistingName(n model.Name) (model.Name, error) {
|
||||||
|
var zero model.Name
|
||||||
|
existing, err := Manifests(true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return zero, err
|
||||||
}
|
}
|
||||||
|
for e := range existing {
|
||||||
for n := range names {
|
if n.EqualFold(e) {
|
||||||
if strings.EqualFold(n.Filepath(), name.Filepath()) && n != name {
|
return e, nil
|
||||||
return errors.New("a model with that name already exists")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return n, nil
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) CreateHandler(c *gin.Context) {
|
func (s *Server) CreateHandler(c *gin.Context) {
|
||||||
|
@ -652,7 +654,8 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := checkNameExists(name); err != nil {
|
name, err := getExistingName(name)
|
||||||
|
if err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -958,14 +961,19 @@ func (s *Server) CopyHandler(c *gin.Context) {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("source %q is invalid", r.Source)})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("source %q is invalid", r.Source)})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
src, err := getExistingName(src)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
dst := model.ParseName(r.Destination)
|
dst := model.ParseName(r.Destination)
|
||||||
if !dst.IsValid() {
|
if !dst.IsValid() {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("destination %q is invalid", r.Destination)})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("destination %q is invalid", r.Destination)})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
dst, err = getExistingName(dst)
|
||||||
if err := checkNameExists(dst); err != nil {
|
if err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,13 +7,18 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"io/fs"
|
||||||
"math"
|
"math"
|
||||||
|
"math/rand/v2"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
|
@ -473,83 +478,129 @@ func Test_Routes(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCase(t *testing.T) {
|
func casingShuffle(s string) string {
|
||||||
|
rr := []rune(s)
|
||||||
|
for i := range rr {
|
||||||
|
if rand.N(2) == 0 {
|
||||||
|
rr[i] = unicode.ToUpper(rr[i])
|
||||||
|
} else {
|
||||||
|
rr[i] = unicode.ToLower(rr[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return string(rr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManifestCaseSensitivity(t *testing.T) {
|
||||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||||
|
|
||||||
cases := []string{
|
r := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
"mistral",
|
w.WriteHeader(http.StatusOK)
|
||||||
"llama3:latest",
|
io.WriteString(w, `{}`) //nolint:errcheck
|
||||||
"library/phi3:q4_0",
|
}))
|
||||||
"registry.ollama.ai/library/gemma:q5_K_M",
|
defer r.Close()
|
||||||
// TODO: host:port currently fails on windows (#4107)
|
|
||||||
// "localhost:5000/alice/bob:latest",
|
nameUsed := make(map[string]bool)
|
||||||
|
name := func() string {
|
||||||
|
const fqmn = "example/namespace/model:tag"
|
||||||
|
for {
|
||||||
|
v := casingShuffle(fqmn)
|
||||||
|
if nameUsed[v] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
nameUsed[v] = true
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
wantStableName := name()
|
||||||
|
|
||||||
|
// checkManifestList tests that there is strictly one manifest in the
|
||||||
|
// models directory, and that the manifest is for the model under test.
|
||||||
|
checkManifestList := func() {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
mandir := filepath.Join(os.Getenv("OLLAMA_MODELS"), "manifests/")
|
||||||
|
var entries []string
|
||||||
|
t.Logf("dir entries:")
|
||||||
|
fsys := os.DirFS(mandir)
|
||||||
|
err := fs.WalkDir(fsys, ".", func(path string, info fs.DirEntry, err error) error {
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
t.Logf(" %s", fs.FormatDirEntry(info))
|
||||||
|
if info.IsDir() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
path = strings.TrimPrefix(path, mandir)
|
||||||
|
entries = append(entries, path)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to walk directory: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Errorf("len(got) = %d, want 1", len(entries))
|
||||||
|
return // do not use Fatal so following steps run
|
||||||
|
}
|
||||||
|
|
||||||
|
g := entries[0] // raw path
|
||||||
|
g = filepath.ToSlash(g)
|
||||||
|
w := model.ParseName(wantStableName).Filepath()
|
||||||
|
w = filepath.ToSlash(w)
|
||||||
|
if g != w {
|
||||||
|
t.Errorf("\ngot: %s\nwant: %s", g, w)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
checkOK := func(w *httptest.ResponseRecorder) {
|
||||||
|
t.Helper()
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("code = %d, want 200", w.Code)
|
||||||
|
t.Logf("body: %s", w.Body.String())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var s Server
|
var s Server
|
||||||
for _, tt := range cases {
|
testMakeRequestDialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||||
t.Run(tt, func(t *testing.T) {
|
var d net.Dialer
|
||||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
return d.DialContext(ctx, "tcp", r.Listener.Addr().String())
|
||||||
Name: tt,
|
}
|
||||||
|
t.Cleanup(func() { testMakeRequestDialContext = nil })
|
||||||
|
|
||||||
|
t.Logf("creating")
|
||||||
|
checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||||
|
// Start with the stable name, and later use a case-shuffled
|
||||||
|
// version.
|
||||||
|
Name: wantStableName,
|
||||||
|
|
||||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
|
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
})
|
}))
|
||||||
|
checkManifestList()
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
t.Logf("creating (again)")
|
||||||
t.Fatalf("expected status 200 got %d", w.Code)
|
checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||||
}
|
Name: name(),
|
||||||
|
|
||||||
expect, err := json.Marshal(map[string]string{"error": "a model with that name already exists"})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("create", func(t *testing.T) {
|
|
||||||
w = createRequest(t, s.CreateHandler, api.CreateRequest{
|
|
||||||
Name: strings.ToUpper(tt),
|
|
||||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
|
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
})
|
}))
|
||||||
|
checkManifestList()
|
||||||
|
|
||||||
if w.Code != http.StatusBadRequest {
|
t.Logf("pulling")
|
||||||
t.Fatalf("expected status 500 got %d", w.Code)
|
checkOK(createRequest(t, s.PullHandler, api.PullRequest{
|
||||||
}
|
Name: name(),
|
||||||
|
|
||||||
if !bytes.Equal(w.Body.Bytes(), expect) {
|
|
||||||
t.Fatalf("expected error %s got %s", expect, w.Body.String())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("pull", func(t *testing.T) {
|
|
||||||
w := createRequest(t, s.PullHandler, api.PullRequest{
|
|
||||||
Name: strings.ToUpper(tt),
|
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
})
|
Insecure: true,
|
||||||
|
}))
|
||||||
|
checkManifestList()
|
||||||
|
|
||||||
if w.Code != http.StatusBadRequest {
|
t.Logf("copying")
|
||||||
t.Fatalf("expected status 500 got %d", w.Code)
|
checkOK(createRequest(t, s.CopyHandler, api.CopyRequest{
|
||||||
}
|
Source: name(),
|
||||||
|
Destination: name(),
|
||||||
if !bytes.Equal(w.Body.Bytes(), expect) {
|
}))
|
||||||
t.Fatalf("expected error %s got %s", expect, w.Body.String())
|
checkManifestList()
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("copy", func(t *testing.T) {
|
|
||||||
w := createRequest(t, s.CopyHandler, api.CopyRequest{
|
|
||||||
Source: tt,
|
|
||||||
Destination: strings.ToUpper(tt),
|
|
||||||
})
|
|
||||||
|
|
||||||
if w.Code != http.StatusBadRequest {
|
|
||||||
t.Fatalf("expected status 500 got %d", w.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !bytes.Equal(w.Body.Bytes(), expect) {
|
|
||||||
t.Fatalf("expected error %s got %s", expect, w.Body.String())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShow(t *testing.T) {
|
func TestShow(t *testing.T) {
|
||||||
|
|
|
@ -298,6 +298,13 @@ func (n Name) LogValue() slog.Value {
|
||||||
return slog.StringValue(n.String())
|
return slog.StringValue(n.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n Name) EqualFold(o Name) bool {
|
||||||
|
return strings.EqualFold(n.Host, o.Host) &&
|
||||||
|
strings.EqualFold(n.Namespace, o.Namespace) &&
|
||||||
|
strings.EqualFold(n.Model, o.Model) &&
|
||||||
|
strings.EqualFold(n.Tag, o.Tag)
|
||||||
|
}
|
||||||
|
|
||||||
func isValidLen(kind partKind, s string) bool {
|
func isValidLen(kind partKind, s string) bool {
|
||||||
switch kind {
|
switch kind {
|
||||||
case kindHost:
|
case kindHost:
|
||||||
|
|
Loading…
Reference in a new issue