From f594c8eb9162bfb6c74531831ceb4a99d6533fbe Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Mon, 14 Aug 2023 15:07:00 -0700 Subject: [PATCH] cross repo mount --- server/images.go | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/server/images.go b/server/images.go index 2177fe6d..f7a7bbe0 100644 --- a/server/images.go +++ b/server/images.go @@ -13,6 +13,7 @@ import ( "log" "net/http" "os" + "path" "path/filepath" "reflect" "strconv" @@ -94,6 +95,7 @@ type Layer struct { MediaType string `json:"mediaType"` Digest string `json:"digest"` Size int `json:"size"` + From string `json:"from,omitempty"` } type LayerReader struct { @@ -270,7 +272,8 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api case "model": fn(api.ProgressResponse{Status: "looking for model"}) embed.model = c.Args - mf, err := GetManifest(ParseModelPath(c.Args)) + mp := ParseModelPath(c.Args) + mf, err := GetManifest(mp) if err != nil { modelFile, err := filenameWithPath(path, c.Args) if err != nil { @@ -327,6 +330,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api if err != nil { return err } + newLayer.From = mp.GetNamespaceRepository() layers = append(layers, newLayer) } } @@ -451,8 +455,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api } layers = append(layers, cfg) - err = SaveLayers(layers, fn, false) - if err != nil { + if err := SaveLayers(layers, fn, false); err != nil { return err } @@ -896,14 +899,24 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu Total: layer.Size, }) - location, err := startUpload(ctx, mp, regOpts) + location, err := startUpload(ctx, mp, layer, regOpts) if err != nil { log.Printf("couldn't start upload: %v", err) return err } - err = uploadBlobChunked(ctx, mp, location, layer, regOpts, fn) - if err != nil { + if strings.HasPrefix(path.Base(location), "sha256:") { + layer.Digest = path.Base(location) + fn(api.ProgressResponse{ + Status: "using existing layer", + Digest: layer.Digest, + Total: layer.Size, + Completed: layer.Size, + }) + continue + } + + if err := uploadBlobChunked(ctx, mp, location, layer, regOpts, fn); err != nil { log.Printf("error uploading blob: %v", err) return err } @@ -1063,8 +1076,11 @@ func GetSHA256Digest(r io.Reader) (string, int) { return fmt.Sprintf("sha256:%x", h.Sum(nil)), int(n) } -func startUpload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (string, error) { +func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (string, error) { url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", mp.Registry, mp.GetNamespaceRepository()) + if layer.From != "" { + url = fmt.Sprintf("%s/v2/%s/blobs/uploads/?mount=%s&from=%s", mp.Registry, mp.GetNamespaceRepository(), layer.Digest, layer.From) + } resp, err := makeRequest(ctx, "POST", url, nil, nil, regOpts) if err != nil { @@ -1074,7 +1090,7 @@ func startUpload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (s defer resp.Body.Close() // Check for success - if resp.StatusCode != http.StatusAccepted { + if resp.StatusCode != http.StatusAccepted && resp.StatusCode != http.StatusCreated { body, _ := io.ReadAll(resp.Body) return "", fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body) }