add modelpaths (#96)

This commit is contained in:
Patrick Devine 2023-07-17 22:44:21 -07:00 committed by GitHub
parent 1f45f7bb52
commit 4a28a2f093
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 147 additions and 80 deletions

View file

@ -48,7 +48,13 @@ func create(cmd *cobra.Command, args []string) error {
}
func RunRun(cmd *cobra.Command, args []string) error {
_, err := os.Stat(args[0])
mp := server.ParseModelPath(args[0])
fp, err := mp.GetManifestPath(false)
if err != nil {
return err
}
_, err = os.Stat(fp)
switch {
case errors.Is(err, os.ErrNotExist):
if err := pull(args[0]); err != nil {

View file

@ -22,8 +22,6 @@ import (
"github.com/jmorganca/ollama/parser"
)
var DefaultRegistry string = "https://registry.ollama.ai"
type Model struct {
Name string `json:"name"`
ModelPath string
@ -61,27 +59,13 @@ type RootFS struct {
DiffIDs []string `json:"diff_ids"`
}
func modelsDir(part ...string) (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
path := filepath.Join(home, ".ollama", "models", filepath.Join(part...))
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return "", err
}
return path, nil
}
func GetManifest(name string) (*ManifestV2, error) {
fp, err := modelsDir("manifests", name)
func GetManifest(mp ModelPath) (*ManifestV2, error) {
fp, err := mp.GetManifestPath(false)
if err != nil {
return nil, err
}
if _, err = os.Stat(fp); err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("couldn't find model '%s'", name)
return nil, fmt.Errorf("couldn't find model '%s'", mp.GetShortTagname())
}
var manifest *ManifestV2
@ -101,17 +85,19 @@ func GetManifest(name string) (*ManifestV2, error) {
}
func GetModel(name string) (*Model, error) {
manifest, err := GetManifest(name)
mp := ParseModelPath(name)
manifest, err := GetManifest(mp)
if err != nil {
return nil, err
}
model := &Model{
Name: name,
Name: mp.GetFullTagname(),
}
for _, layer := range manifest.Layers {
filename, err := modelsDir("blobs", layer.Digest)
filename, err := GetBlobsPath(layer.Digest)
if err != nil {
return nil, err
}
@ -174,7 +160,7 @@ func CreateModel(name string, mf io.Reader, fn func(status string)) error {
switch c.Name {
case "model":
fn("looking for model")
mf, err := GetManifest(c.Arg)
mf, err := GetManifest(ParseModelPath(c.Arg))
if err != nil {
// if we couldn't read the manifest, try getting the bin file
fp, err := getAbsPath(c.Arg)
@ -293,7 +279,7 @@ func removeLayerFromLayers(layers []*LayerWithBuffer, mediaType string) []*Layer
func SaveLayers(layers []*LayerWithBuffer, fn func(status string), force bool) error {
// Write each of the layers to disk
for _, layer := range layers {
fp, err := modelsDir("blobs", layer.Digest)
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
@ -321,6 +307,8 @@ func SaveLayers(layers []*LayerWithBuffer, fn func(status string), force bool) e
}
func CreateManifest(name string, cfg *LayerWithBuffer, layers []*Layer) error {
mp := ParseModelPath(name)
manifest := ManifestV2{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
@ -337,7 +325,7 @@ func CreateManifest(name string, cfg *LayerWithBuffer, layers []*Layer) error {
return err
}
fp, err := modelsDir("manifests", name)
fp, err := mp.GetManifestPath(true)
if err != nil {
return err
}
@ -345,7 +333,7 @@ func CreateManifest(name string, cfg *LayerWithBuffer, layers []*Layer) error {
}
func GetLayerWithBufferFromLayer(layer *Layer) (*LayerWithBuffer, error) {
fp, err := modelsDir("blobs", layer.Digest)
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return nil, err
}
@ -456,28 +444,15 @@ func CreateLayer(f io.Reader) (*LayerWithBuffer, error) {
}
func PushModel(name, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error {
mp := ParseModelPath(name)
fn("retrieving manifest", "", 0, 0, 0)
manifest, err := GetManifest(name)
manifest, err := GetManifest(mp)
if err != nil {
fn("couldn't retrieve manifest", "", 0, 0, 0)
return err
}
var repoName string
var tag string
comps := strings.Split(name, ":")
switch {
case len(comps) < 1 || len(comps) > 2:
return fmt.Errorf("repository name was invalid")
case len(comps) == 1:
repoName = comps[0]
tag = "latest"
case len(comps) == 2:
repoName = comps[0]
tag = comps[1]
}
var layers []*Layer
var total int
var completed int
@ -489,7 +464,7 @@ func PushModel(name, username, password string, fn func(status, digest string, T
total += manifest.Config.Size
for _, layer := range layers {
exists, err := checkBlobExistence(DefaultRegistry, repoName, layer.Digest, username, password)
exists, err := checkBlobExistence(mp, layer.Digest, username, password)
if err != nil {
return err
}
@ -502,7 +477,7 @@ func PushModel(name, username, password string, fn func(status, digest string, T
fn("starting upload", layer.Digest, total, completed, float64(completed)/float64(total))
location, err := startUpload(DefaultRegistry, repoName, username, password)
location, err := startUpload(mp, username, password)
if err != nil {
log.Printf("couldn't start upload: %v", err)
return err
@ -518,7 +493,7 @@ func PushModel(name, username, password string, fn func(status, digest string, T
}
fn("pushing manifest", "", total, completed, float64(completed/total))
url := fmt.Sprintf("%s/v2/%s/manifests/%s", DefaultRegistry, repoName, tag)
url := fmt.Sprintf("%s://%s/v2/%s/manifests/%s", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
headers := map[string]string{
"Content-Type": "application/vnd.docker.distribution.manifest.v2+json",
}
@ -546,30 +521,15 @@ func PushModel(name, username, password string, fn func(status, digest string, T
}
func PullModel(name, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error {
var repoName string
var tag string
comps := strings.Split(name, ":")
switch {
case len(comps) < 1 || len(comps) > 2:
return fmt.Errorf("repository name was invalid")
case len(comps) == 1:
repoName = comps[0]
tag = "latest"
case len(comps) == 2:
repoName = comps[0]
tag = comps[1]
}
mp := ParseModelPath(name)
fn("pulling manifest", "", 0, 0, 0)
manifest, err := pullModelManifest(DefaultRegistry, repoName, tag, username, password)
manifest, err := pullModelManifest(mp, username, password)
if err != nil {
return fmt.Errorf("pull model manifest: %q", err)
}
log.Printf("manifest = %#v", manifest)
var layers []*Layer
var total int
var completed int
@ -582,7 +542,7 @@ func PullModel(name, username, password string, fn func(status, digest string, T
for _, layer := range layers {
fn("starting download", layer.Digest, total, completed, float64(completed)/float64(total))
if err := downloadBlob(DefaultRegistry, repoName, layer.Digest, username, password, fn); err != nil {
if err := downloadBlob(mp, layer.Digest, username, password, fn); err != nil {
fn(fmt.Sprintf("error downloading: %v", err), layer.Digest, 0, 0, 0)
return err
}
@ -597,16 +557,11 @@ func PullModel(name, username, password string, fn func(status, digest string, T
return err
}
fp, err := modelsDir("manifests", name)
fp, err := mp.GetManifestPath(true)
if err != nil {
return err
}
err = os.MkdirAll(path.Dir(fp), 0o700)
if err != nil {
return fmt.Errorf("make manifests directory: %w", err)
}
err = os.WriteFile(fp, manifestJSON, 0644)
if err != nil {
log.Printf("couldn't write to %s", fp)
@ -618,8 +573,8 @@ func PullModel(name, username, password string, fn func(status, digest string, T
return nil
}
func pullModelManifest(registryURL, repoName, tag, username, password string) (*ManifestV2, error) {
url := fmt.Sprintf("%s/v2/%s/manifests/%s", registryURL, repoName, tag)
func pullModelManifest(mp ModelPath, username, password string) (*ManifestV2, error) {
url := fmt.Sprintf("%s://%s/v2/%s/manifests/%s", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
headers := map[string]string{
"Accept": "application/vnd.docker.distribution.manifest.v2+json",
}
@ -682,8 +637,8 @@ func GetSHA256Digest(data *bytes.Buffer) (string, int) {
return "sha256:" + hex.EncodeToString(hash[:]), len(layerBytes)
}
func startUpload(registryURL string, repositoryName string, username string, password string) (string, error) {
url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", registryURL, repositoryName)
func startUpload(mp ModelPath, username string, password string) (string, error) {
url := fmt.Sprintf("%s://%s/v2/%s/blobs/uploads/", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository())
resp, err := makeRequest("POST", url, nil, nil, username, password)
if err != nil {
@ -708,8 +663,8 @@ func startUpload(registryURL string, repositoryName string, username string, pas
}
// Function to check if a blob already exists in the Docker registry
func checkBlobExistence(registryURL string, repositoryName string, digest string, username string, password string) (bool, error) {
url := fmt.Sprintf("%s/v2/%s/blobs/%s", registryURL, repositoryName, digest)
func checkBlobExistence(mp ModelPath, digest string, username string, password string) (bool, error) {
url := fmt.Sprintf("%s://%s/v2/%s/blobs/%s", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository(), digest)
resp, err := makeRequest("HEAD", url, nil, nil, username, password)
if err != nil {
@ -735,7 +690,7 @@ func uploadBlob(location string, layer *Layer, username string, password string)
// TODO allow canceling uploads via DELETE
// TODO allow cross repo blob mount
fp, err := modelsDir("blobs", layer.Digest)
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
@ -761,8 +716,8 @@ func uploadBlob(location string, layer *Layer, username string, password string)
return nil
}
func downloadBlob(registryURL, repoName, digest string, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error {
fp, err := modelsDir("blobs", digest)
func downloadBlob(mp ModelPath, digest string, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error {
fp, err := GetBlobsPath(digest)
if err != nil {
return err
}
@ -786,7 +741,7 @@ func downloadBlob(registryURL, repoName, digest string, username, password strin
size = fi.Size()
}
url := fmt.Sprintf("%s/v2/%s/blobs/%s", registryURL, repoName, digest)
url := fmt.Sprintf("%s://%s/v2/%s/blobs/%s", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository(), digest)
headers := map[string]string{
"Range": fmt.Sprintf("bytes=%d-", size),
}

106
server/modelpath.go Normal file
View file

@ -0,0 +1,106 @@
package server
import (
"fmt"
"os"
"path/filepath"
"strings"
)
type ModelPath struct {
ProtocolScheme string
Registry string
Namespace string
Repository string
Tag string
}
const (
DefaultRegistry = "registry.ollama.ai"
DefaultNamespace = "library"
DefaultTag = "latest"
DefaultProtocolScheme = "https"
)
func ParseModelPath(name string) ModelPath {
slashParts := strings.Split(name, "/")
var registry, namespace, repository, tag string
switch len(slashParts) {
case 3:
registry = slashParts[0]
namespace = slashParts[1]
repository = strings.Split(slashParts[2], ":")[0]
case 2:
registry = DefaultRegistry
namespace = slashParts[0]
repository = strings.Split(slashParts[1], ":")[0]
case 1:
registry = DefaultRegistry
namespace = DefaultNamespace
repository = strings.Split(slashParts[0], ":")[0]
default:
fmt.Println("Invalid image format.")
return ModelPath{}
}
colonParts := strings.Split(name, ":")
if len(colonParts) == 2 {
tag = colonParts[1]
} else {
tag = DefaultTag
}
return ModelPath{
ProtocolScheme: DefaultProtocolScheme,
Registry: registry,
Namespace: namespace,
Repository: repository,
Tag: tag,
}
}
func (mp ModelPath) GetNamespaceRepository() string {
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
}
func (mp ModelPath) GetFullTagname() string {
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
}
func (mp ModelPath) GetShortTagname() string {
if mp.Registry == DefaultRegistry && mp.Namespace == DefaultNamespace {
return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag)
}
return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag)
}
func (mp ModelPath) GetManifestPath(createDir bool) (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
path := filepath.Join(home, ".ollama", "models", "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
if createDir {
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return "", err
}
}
return path, nil
}
func GetBlobsPath(digest string) (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
path := filepath.Join(home, ".ollama", "models", "blobs")
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return "", err
}
return filepath.Join(path, digest), nil
}