cross repo mount

This commit is contained in:
Michael Yang 2023-08-14 15:07:00 -07:00
parent 2ab20095b3
commit f594c8eb91

View file

@ -13,6 +13,7 @@ import (
"log" "log"
"net/http" "net/http"
"os" "os"
"path"
"path/filepath" "path/filepath"
"reflect" "reflect"
"strconv" "strconv"
@ -94,6 +95,7 @@ type Layer struct {
MediaType string `json:"mediaType"` MediaType string `json:"mediaType"`
Digest string `json:"digest"` Digest string `json:"digest"`
Size int `json:"size"` Size int `json:"size"`
From string `json:"from,omitempty"`
} }
type LayerReader struct { type LayerReader struct {
@ -270,7 +272,8 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
case "model": case "model":
fn(api.ProgressResponse{Status: "looking for model"}) fn(api.ProgressResponse{Status: "looking for model"})
embed.model = c.Args embed.model = c.Args
mf, err := GetManifest(ParseModelPath(c.Args)) mp := ParseModelPath(c.Args)
mf, err := GetManifest(mp)
if err != nil { if err != nil {
modelFile, err := filenameWithPath(path, c.Args) modelFile, err := filenameWithPath(path, c.Args)
if err != nil { if err != nil {
@ -327,6 +330,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
if err != nil { if err != nil {
return err return err
} }
newLayer.From = mp.GetNamespaceRepository()
layers = append(layers, newLayer) layers = append(layers, newLayer)
} }
} }
@ -451,8 +455,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
} }
layers = append(layers, cfg) layers = append(layers, cfg)
err = SaveLayers(layers, fn, false) if err := SaveLayers(layers, fn, false); err != nil {
if err != nil {
return err return err
} }
@ -896,14 +899,24 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
Total: layer.Size, Total: layer.Size,
}) })
location, err := startUpload(ctx, mp, regOpts) location, err := startUpload(ctx, mp, layer, regOpts)
if err != nil { if err != nil {
log.Printf("couldn't start upload: %v", err) log.Printf("couldn't start upload: %v", err)
return err return err
} }
err = uploadBlobChunked(ctx, mp, location, layer, regOpts, fn) if strings.HasPrefix(path.Base(location), "sha256:") {
if err != nil { layer.Digest = path.Base(location)
fn(api.ProgressResponse{
Status: "using existing layer",
Digest: layer.Digest,
Total: layer.Size,
Completed: layer.Size,
})
continue
}
if err := uploadBlobChunked(ctx, mp, location, layer, regOpts, fn); err != nil {
log.Printf("error uploading blob: %v", err) log.Printf("error uploading blob: %v", err)
return err return err
} }
@ -1063,8 +1076,11 @@ func GetSHA256Digest(r io.Reader) (string, int) {
return fmt.Sprintf("sha256:%x", h.Sum(nil)), int(n) return fmt.Sprintf("sha256:%x", h.Sum(nil)), int(n)
} }
func startUpload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (string, error) { func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (string, error) {
url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", mp.Registry, mp.GetNamespaceRepository()) url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", mp.Registry, mp.GetNamespaceRepository())
if layer.From != "" {
url = fmt.Sprintf("%s/v2/%s/blobs/uploads/?mount=%s&from=%s", mp.Registry, mp.GetNamespaceRepository(), layer.Digest, layer.From)
}
resp, err := makeRequest(ctx, "POST", url, nil, nil, regOpts) resp, err := makeRequest(ctx, "POST", url, nil, nil, regOpts)
if err != nil { if err != nil {
@ -1074,7 +1090,7 @@ func startUpload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (s
defer resp.Body.Close() defer resp.Body.Close()
// Check for success // Check for success
if resp.StatusCode != http.StatusAccepted { if resp.StatusCode != http.StatusAccepted && resp.StatusCode != http.StatusCreated {
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body) return "", fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
} }