diff --git a/cmd/cmd.go b/cmd/cmd.go index 99033614..43b186e8 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -48,7 +48,13 @@ func create(cmd *cobra.Command, args []string) error { } func RunRun(cmd *cobra.Command, args []string) error { - _, err := os.Stat(args[0]) + mp := server.ParseModelPath(args[0]) + fp, err := mp.GetManifestPath(false) + if err != nil { + return err + } + + _, err = os.Stat(fp) switch { case errors.Is(err, os.ErrNotExist): if err := pull(args[0]); err != nil { diff --git a/server/images.go b/server/images.go index 4a7aa206..2955586e 100644 --- a/server/images.go +++ b/server/images.go @@ -22,8 +22,6 @@ import ( "github.com/jmorganca/ollama/parser" ) -var DefaultRegistry string = "https://registry.ollama.ai" - type Model struct { Name string `json:"name"` ModelPath string @@ -61,27 +59,13 @@ type RootFS struct { DiffIDs []string `json:"diff_ids"` } -func modelsDir(part ...string) (string, error) { - home, err := os.UserHomeDir() - if err != nil { - return "", err - } - - path := filepath.Join(home, ".ollama", "models", filepath.Join(part...)) - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { - return "", err - } - - return path, nil -} - -func GetManifest(name string) (*ManifestV2, error) { - fp, err := modelsDir("manifests", name) +func GetManifest(mp ModelPath) (*ManifestV2, error) { + fp, err := mp.GetManifestPath(false) if err != nil { return nil, err } if _, err = os.Stat(fp); err != nil && !errors.Is(err, os.ErrNotExist) { - return nil, fmt.Errorf("couldn't find model '%s'", name) + return nil, fmt.Errorf("couldn't find model '%s'", mp.GetShortTagname()) } var manifest *ManifestV2 @@ -101,17 +85,19 @@ func GetManifest(name string) (*ManifestV2, error) { } func GetModel(name string) (*Model, error) { - manifest, err := GetManifest(name) + mp := ParseModelPath(name) + + manifest, err := GetManifest(mp) if err != nil { return nil, err } model := &Model{ - Name: name, + Name: mp.GetFullTagname(), } for _, layer := range manifest.Layers { - filename, err := modelsDir("blobs", layer.Digest) + filename, err := GetBlobsPath(layer.Digest) if err != nil { return nil, err } @@ -174,7 +160,7 @@ func CreateModel(name string, mf io.Reader, fn func(status string)) error { switch c.Name { case "model": fn("looking for model") - mf, err := GetManifest(c.Arg) + mf, err := GetManifest(ParseModelPath(c.Arg)) if err != nil { // if we couldn't read the manifest, try getting the bin file fp, err := getAbsPath(c.Arg) @@ -293,7 +279,7 @@ func removeLayerFromLayers(layers []*LayerWithBuffer, mediaType string) []*Layer func SaveLayers(layers []*LayerWithBuffer, fn func(status string), force bool) error { // Write each of the layers to disk for _, layer := range layers { - fp, err := modelsDir("blobs", layer.Digest) + fp, err := GetBlobsPath(layer.Digest) if err != nil { return err } @@ -321,6 +307,8 @@ func SaveLayers(layers []*LayerWithBuffer, fn func(status string), force bool) e } func CreateManifest(name string, cfg *LayerWithBuffer, layers []*Layer) error { + mp := ParseModelPath(name) + manifest := ManifestV2{ SchemaVersion: 2, MediaType: "application/vnd.docker.distribution.manifest.v2+json", @@ -337,7 +325,7 @@ func CreateManifest(name string, cfg *LayerWithBuffer, layers []*Layer) error { return err } - fp, err := modelsDir("manifests", name) + fp, err := mp.GetManifestPath(true) if err != nil { return err } @@ -345,7 +333,7 @@ func CreateManifest(name string, cfg *LayerWithBuffer, layers []*Layer) error { } func GetLayerWithBufferFromLayer(layer *Layer) (*LayerWithBuffer, error) { - fp, err := modelsDir("blobs", layer.Digest) + fp, err := GetBlobsPath(layer.Digest) if err != nil { return nil, err } @@ -456,28 +444,15 @@ func CreateLayer(f io.Reader) (*LayerWithBuffer, error) { } func PushModel(name, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error { + mp := ParseModelPath(name) + fn("retrieving manifest", "", 0, 0, 0) - manifest, err := GetManifest(name) + manifest, err := GetManifest(mp) if err != nil { fn("couldn't retrieve manifest", "", 0, 0, 0) return err } - var repoName string - var tag string - - comps := strings.Split(name, ":") - switch { - case len(comps) < 1 || len(comps) > 2: - return fmt.Errorf("repository name was invalid") - case len(comps) == 1: - repoName = comps[0] - tag = "latest" - case len(comps) == 2: - repoName = comps[0] - tag = comps[1] - } - var layers []*Layer var total int var completed int @@ -489,7 +464,7 @@ func PushModel(name, username, password string, fn func(status, digest string, T total += manifest.Config.Size for _, layer := range layers { - exists, err := checkBlobExistence(DefaultRegistry, repoName, layer.Digest, username, password) + exists, err := checkBlobExistence(mp, layer.Digest, username, password) if err != nil { return err } @@ -502,7 +477,7 @@ func PushModel(name, username, password string, fn func(status, digest string, T fn("starting upload", layer.Digest, total, completed, float64(completed)/float64(total)) - location, err := startUpload(DefaultRegistry, repoName, username, password) + location, err := startUpload(mp, username, password) if err != nil { log.Printf("couldn't start upload: %v", err) return err @@ -518,7 +493,7 @@ func PushModel(name, username, password string, fn func(status, digest string, T } fn("pushing manifest", "", total, completed, float64(completed/total)) - url := fmt.Sprintf("%s/v2/%s/manifests/%s", DefaultRegistry, repoName, tag) + url := fmt.Sprintf("%s://%s/v2/%s/manifests/%s", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository(), mp.Tag) headers := map[string]string{ "Content-Type": "application/vnd.docker.distribution.manifest.v2+json", } @@ -546,30 +521,15 @@ func PushModel(name, username, password string, fn func(status, digest string, T } func PullModel(name, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error { - var repoName string - var tag string - - comps := strings.Split(name, ":") - switch { - case len(comps) < 1 || len(comps) > 2: - return fmt.Errorf("repository name was invalid") - case len(comps) == 1: - repoName = comps[0] - tag = "latest" - case len(comps) == 2: - repoName = comps[0] - tag = comps[1] - } + mp := ParseModelPath(name) fn("pulling manifest", "", 0, 0, 0) - manifest, err := pullModelManifest(DefaultRegistry, repoName, tag, username, password) + manifest, err := pullModelManifest(mp, username, password) if err != nil { return fmt.Errorf("pull model manifest: %q", err) } - log.Printf("manifest = %#v", manifest) - var layers []*Layer var total int var completed int @@ -582,7 +542,7 @@ func PullModel(name, username, password string, fn func(status, digest string, T for _, layer := range layers { fn("starting download", layer.Digest, total, completed, float64(completed)/float64(total)) - if err := downloadBlob(DefaultRegistry, repoName, layer.Digest, username, password, fn); err != nil { + if err := downloadBlob(mp, layer.Digest, username, password, fn); err != nil { fn(fmt.Sprintf("error downloading: %v", err), layer.Digest, 0, 0, 0) return err } @@ -597,16 +557,11 @@ func PullModel(name, username, password string, fn func(status, digest string, T return err } - fp, err := modelsDir("manifests", name) + fp, err := mp.GetManifestPath(true) if err != nil { return err } - err = os.MkdirAll(path.Dir(fp), 0o700) - if err != nil { - return fmt.Errorf("make manifests directory: %w", err) - } - err = os.WriteFile(fp, manifestJSON, 0644) if err != nil { log.Printf("couldn't write to %s", fp) @@ -618,8 +573,8 @@ func PullModel(name, username, password string, fn func(status, digest string, T return nil } -func pullModelManifest(registryURL, repoName, tag, username, password string) (*ManifestV2, error) { - url := fmt.Sprintf("%s/v2/%s/manifests/%s", registryURL, repoName, tag) +func pullModelManifest(mp ModelPath, username, password string) (*ManifestV2, error) { + url := fmt.Sprintf("%s://%s/v2/%s/manifests/%s", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository(), mp.Tag) headers := map[string]string{ "Accept": "application/vnd.docker.distribution.manifest.v2+json", } @@ -682,8 +637,8 @@ func GetSHA256Digest(data *bytes.Buffer) (string, int) { return "sha256:" + hex.EncodeToString(hash[:]), len(layerBytes) } -func startUpload(registryURL string, repositoryName string, username string, password string) (string, error) { - url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", registryURL, repositoryName) +func startUpload(mp ModelPath, username string, password string) (string, error) { + url := fmt.Sprintf("%s://%s/v2/%s/blobs/uploads/", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository()) resp, err := makeRequest("POST", url, nil, nil, username, password) if err != nil { @@ -708,8 +663,8 @@ func startUpload(registryURL string, repositoryName string, username string, pas } // Function to check if a blob already exists in the Docker registry -func checkBlobExistence(registryURL string, repositoryName string, digest string, username string, password string) (bool, error) { - url := fmt.Sprintf("%s/v2/%s/blobs/%s", registryURL, repositoryName, digest) +func checkBlobExistence(mp ModelPath, digest string, username string, password string) (bool, error) { + url := fmt.Sprintf("%s://%s/v2/%s/blobs/%s", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository(), digest) resp, err := makeRequest("HEAD", url, nil, nil, username, password) if err != nil { @@ -735,7 +690,7 @@ func uploadBlob(location string, layer *Layer, username string, password string) // TODO allow canceling uploads via DELETE // TODO allow cross repo blob mount - fp, err := modelsDir("blobs", layer.Digest) + fp, err := GetBlobsPath(layer.Digest) if err != nil { return err } @@ -761,8 +716,8 @@ func uploadBlob(location string, layer *Layer, username string, password string) return nil } -func downloadBlob(registryURL, repoName, digest string, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error { - fp, err := modelsDir("blobs", digest) +func downloadBlob(mp ModelPath, digest string, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error { + fp, err := GetBlobsPath(digest) if err != nil { return err } @@ -786,7 +741,7 @@ func downloadBlob(registryURL, repoName, digest string, username, password strin size = fi.Size() } - url := fmt.Sprintf("%s/v2/%s/blobs/%s", registryURL, repoName, digest) + url := fmt.Sprintf("%s://%s/v2/%s/blobs/%s", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository(), digest) headers := map[string]string{ "Range": fmt.Sprintf("bytes=%d-", size), } diff --git a/server/modelpath.go b/server/modelpath.go new file mode 100644 index 00000000..d23c1933 --- /dev/null +++ b/server/modelpath.go @@ -0,0 +1,106 @@ +package server + +import ( + "fmt" + "os" + "path/filepath" + "strings" +) + +type ModelPath struct { + ProtocolScheme string + Registry string + Namespace string + Repository string + Tag string +} + +const ( + DefaultRegistry = "registry.ollama.ai" + DefaultNamespace = "library" + DefaultTag = "latest" + DefaultProtocolScheme = "https" +) + +func ParseModelPath(name string) ModelPath { + slashParts := strings.Split(name, "/") + var registry, namespace, repository, tag string + + switch len(slashParts) { + case 3: + registry = slashParts[0] + namespace = slashParts[1] + repository = strings.Split(slashParts[2], ":")[0] + case 2: + registry = DefaultRegistry + namespace = slashParts[0] + repository = strings.Split(slashParts[1], ":")[0] + case 1: + registry = DefaultRegistry + namespace = DefaultNamespace + repository = strings.Split(slashParts[0], ":")[0] + default: + fmt.Println("Invalid image format.") + return ModelPath{} + } + + colonParts := strings.Split(name, ":") + if len(colonParts) == 2 { + tag = colonParts[1] + } else { + tag = DefaultTag + } + + return ModelPath{ + ProtocolScheme: DefaultProtocolScheme, + Registry: registry, + Namespace: namespace, + Repository: repository, + Tag: tag, + } +} + +func (mp ModelPath) GetNamespaceRepository() string { + return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository) +} + +func (mp ModelPath) GetFullTagname() string { + return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag) +} + +func (mp ModelPath) GetShortTagname() string { + if mp.Registry == DefaultRegistry && mp.Namespace == DefaultNamespace { + return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag) + } + return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag) +} + +func (mp ModelPath) GetManifestPath(createDir bool) (string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + + path := filepath.Join(home, ".ollama", "models", "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag) + if createDir { + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return "", err + } + } + + return path, nil +} + +func GetBlobsPath(digest string) (string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + + path := filepath.Join(home, ".ollama", "models", "blobs") + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return "", err + } + + return filepath.Join(path, digest), nil +}