Merge pull request #110 from jmorganca/fix-pull-0-bytes

fix pull 0 bytes on completed layer
This commit is contained in:
Michael Yang 2023-07-18 19:38:59 -07:00 committed by GitHub
commit a6d03dd510
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 80 additions and 63 deletions

View file

@ -160,11 +160,11 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn Generate
}) })
} }
type PullProgressFunc func(PullProgress) error type PullProgressFunc func(ProgressResponse) error
func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error { func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error { return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error {
var resp PullProgress var resp ProgressResponse
if err := json.Unmarshal(bts, &resp); err != nil { if err := json.Unmarshal(bts, &resp); err != nil {
return err return err
} }
@ -173,11 +173,11 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc
}) })
} }
type PushProgressFunc func(PushProgress) error type PushProgressFunc func(ProgressResponse) error
func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error { func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error {
return c.stream(ctx, http.MethodPost, "/api/push", req, func(bts []byte) error { return c.stream(ctx, http.MethodPost, "/api/push", req, func(bts []byte) error {
var resp PushProgress var resp ProgressResponse
if err := json.Unmarshal(bts, &resp); err != nil { if err := json.Unmarshal(bts, &resp); err != nil {
return err return err
} }

View file

@ -43,12 +43,11 @@ type PullRequest struct {
Password string `json:"password"` Password string `json:"password"`
} }
type PullProgress struct { type ProgressResponse struct {
Status string `json:"status"` Status string `json:"status"`
Digest string `json:"digest,omitempty"` Digest string `json:"digest,omitempty"`
Total int `json:"total,omitempty"` Total int `json:"total,omitempty"`
Completed int `json:"completed,omitempty"` Completed int `json:"completed,omitempty"`
Percent float64 `json:"percent,omitempty"`
} }
type PushRequest struct { type PushRequest struct {
@ -57,14 +56,6 @@ type PushRequest struct {
Password string `json:"password"` Password string `json:"password"`
} }
type PushProgress struct {
Status string `json:"status"`
Digest string `json:"digest,omitempty"`
Total int `json:"total,omitempty"`
Completed int `json:"completed,omitempty"`
Percent float64 `json:"percent,omitempty"`
}
type ListResponse struct { type ListResponse struct {
Models []ListResponseModel `json:"models"` Models []ListResponseModel `json:"models"`
} }

View file

@ -89,7 +89,7 @@ func push(cmd *cobra.Command, args []string) error {
client := api.NewClient() client := api.NewClient()
request := api.PushRequest{Name: args[0]} request := api.PushRequest{Name: args[0]}
fn := func(resp api.PushProgress) error { fn := func(resp api.ProgressResponse) error {
fmt.Println(resp.Status) fmt.Println(resp.Status)
return nil return nil
} }
@ -135,25 +135,23 @@ func RunPull(cmd *cobra.Command, args []string) error {
func pull(model string) error { func pull(model string) error {
client := api.NewClient() client := api.NewClient()
var currentDigest string
var bar *progressbar.ProgressBar var bar *progressbar.ProgressBar
currentLayer := ""
request := api.PullRequest{Name: model} request := api.PullRequest{Name: model}
fn := func(resp api.PullProgress) error { fn := func(resp api.ProgressResponse) error {
if resp.Digest != currentLayer && resp.Digest != "" { if resp.Digest != currentDigest && resp.Digest != "" {
if currentLayer != "" { currentDigest = resp.Digest
fmt.Println()
}
currentLayer = resp.Digest
layerStr := resp.Digest[7:23] + "..."
bar = progressbar.DefaultBytes( bar = progressbar.DefaultBytes(
int64(resp.Total), int64(resp.Total),
"pulling "+layerStr, fmt.Sprintf("pulling %s...", resp.Digest[7:19]),
) )
} else if resp.Digest == currentLayer && resp.Digest != "" {
bar.Set(resp.Completed)
} else if resp.Digest == currentDigest && resp.Digest != "" {
bar.Set(resp.Completed) bar.Set(resp.Completed)
} else { } else {
currentLayer = "" currentDigest = ""
fmt.Println(resp.Status) fmt.Println(resp.Status)
} }
return nil return nil

View file

@ -445,13 +445,14 @@ func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
return layer, nil return layer, nil
} }
func PushModel(name, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error { func PushModel(name, username, password string, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name) mp := ParseModelPath(name)
fn("retrieving manifest", "", 0, 0, 0) fn(api.ProgressResponse{Status: "retrieving manifest"})
manifest, err := GetManifest(mp) manifest, err := GetManifest(mp)
if err != nil { if err != nil {
fn("couldn't retrieve manifest", "", 0, 0, 0) fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
return err return err
} }
@ -473,11 +474,21 @@ func PushModel(name, username, password string, fn func(status, digest string, T
if exists { if exists {
completed += layer.Size completed += layer.Size
fn("using existing layer", layer.Digest, total, completed, float64(completed)/float64(total)) fn(api.ProgressResponse{
Status: "using existing layer",
Digest: layer.Digest,
Total: total,
Completed: completed,
})
continue continue
} }
fn("starting upload", layer.Digest, total, completed, float64(completed)/float64(total)) fn(api.ProgressResponse{
Status: "starting upload",
Digest: layer.Digest,
Total: total,
Completed: completed,
})
location, err := startUpload(mp, username, password) location, err := startUpload(mp, username, password)
if err != nil { if err != nil {
@ -491,10 +502,19 @@ func PushModel(name, username, password string, fn func(status, digest string, T
return err return err
} }
completed += layer.Size completed += layer.Size
fn("upload complete", layer.Digest, total, completed, float64(completed)/float64(total)) fn(api.ProgressResponse{
Status: "upload complete",
Digest: layer.Digest,
Total: total,
Completed: completed,
})
} }
fn("pushing manifest", "", total, completed, float64(completed/total)) fn(api.ProgressResponse{
Status: "pushing manifest",
Total: total,
Completed: completed,
})
url := fmt.Sprintf("%s://%s/v2/%s/manifests/%s", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository(), mp.Tag) url := fmt.Sprintf("%s://%s/v2/%s/manifests/%s", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
headers := map[string]string{ headers := map[string]string{
"Content-Type": "application/vnd.docker.distribution.manifest.v2+json", "Content-Type": "application/vnd.docker.distribution.manifest.v2+json",
@ -517,15 +537,19 @@ func PushModel(name, username, password string, fn func(status, digest string, T
return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body)) return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body))
} }
fn("success", "", total, completed, 1.0) fn(api.ProgressResponse{
Status: "success",
Total: total,
Completed: completed,
})
return nil return nil
} }
func PullModel(name, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error { func PullModel(name, username, password string, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name) mp := ParseModelPath(name)
fn("pulling manifest", "", 0, 0, 0) fn(api.ProgressResponse{Status: "pulling manifest"})
manifest, err := pullModelManifest(mp, username, password) manifest, err := pullModelManifest(mp, username, password)
if err != nil { if err != nil {
@ -543,16 +567,15 @@ func PullModel(name, username, password string, fn func(status, digest string, T
total += manifest.Config.Size total += manifest.Config.Size
for _, layer := range layers { for _, layer := range layers {
fn("starting download", layer.Digest, total, completed, float64(completed)/float64(total))
if err := downloadBlob(mp, layer.Digest, username, password, fn); err != nil { if err := downloadBlob(mp, layer.Digest, username, password, fn); err != nil {
fn(fmt.Sprintf("error downloading: %v", err), layer.Digest, 0, 0, 0) fn(api.ProgressResponse{Status: fmt.Sprintf("error downloading: %v", err), Digest: layer.Digest})
return err return err
} }
completed += layer.Size completed += layer.Size
fn("download complete", layer.Digest, total, completed, float64(completed)/float64(total))
} }
fn("writing manifest", "", total, completed, 1.0) fn(api.ProgressResponse{Status: "writing manifest"})
manifestJSON, err := json.Marshal(manifest) manifestJSON, err := json.Marshal(manifest)
if err != nil { if err != nil {
@ -570,7 +593,7 @@ func PullModel(name, username, password string, fn func(status, digest string, T
return err return err
} }
fn("success", "", total, completed, 1.0) fn(api.ProgressResponse{Status: "success"})
return nil return nil
} }
@ -722,16 +745,20 @@ func uploadBlob(location string, layer *Layer, username string, password string)
return nil return nil
} }
func downloadBlob(mp ModelPath, digest string, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error { func downloadBlob(mp ModelPath, digest string, username, password string, fn func(api.ProgressResponse)) error {
fp, err := GetBlobsPath(digest) fp, err := GetBlobsPath(digest)
if err != nil { if err != nil {
return err return err
} }
_, err = os.Stat(fp) if fi, _ := os.Stat(fp); fi != nil {
if !os.IsNotExist(err) {
// we already have the file, so return // we already have the file, so return
log.Printf("already have %s\n", digest) fn(api.ProgressResponse{
Digest: digest,
Total: int(fi.Size()),
Completed: int(fi.Size()),
})
return nil return nil
} }
@ -780,10 +807,21 @@ func downloadBlob(mp ModelPath, digest string, username, password string, fn fun
total := remaining + completed total := remaining + completed
for { for {
fn(fmt.Sprintf("Downloading %s", digest), digest, int(total), int(completed), float64(completed)/float64(total)) fn(api.ProgressResponse{
Status: fmt.Sprintf("downloading %s", digest),
Digest: digest,
Total: int(total),
Completed: int(completed),
})
if completed >= total { if completed >= total {
if err := os.Rename(fp+"-partial", fp); err != nil { if err := os.Rename(fp+"-partial", fp); err != nil {
fn(fmt.Sprintf("error renaming file: %v", err), digest, int(total), int(completed), 1) fn(api.ProgressResponse{
Status: fmt.Sprintf("error renaming file: %v", err),
Digest: digest,
Total: int(total),
Completed: int(completed),
})
return err return err
} }

View file

@ -101,15 +101,10 @@ func pull(c *gin.Context) {
ch := make(chan any) ch := make(chan any)
go func() { go func() {
defer close(ch) defer close(ch)
fn := func(status, digest string, total, completed int, percent float64) { fn := func(r api.ProgressResponse) {
ch <- api.PullProgress{ ch <- r
Status: status,
Digest: digest,
Total: total,
Completed: completed,
Percent: percent,
}
} }
if err := PullModel(req.Name, req.Username, req.Password, fn); err != nil { if err := PullModel(req.Name, req.Username, req.Password, fn); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
@ -129,15 +124,10 @@ func push(c *gin.Context) {
ch := make(chan any) ch := make(chan any)
go func() { go func() {
defer close(ch) defer close(ch)
fn := func(status, digest string, total, completed int, percent float64) { fn := func(r api.ProgressResponse) {
ch <- api.PushProgress{ ch <- r
Status: status,
Digest: digest,
Total: total,
Completed: completed,
Percent: percent,
}
} }
if err := PushModel(req.Name, req.Username, req.Password, fn); err != nil { if err := PushModel(req.Name, req.Username, req.Password, fn); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return