diff --git a/api/client.go b/api/client.go index ea61eb79..8078c428 100644 --- a/api/client.go +++ b/api/client.go @@ -160,11 +160,11 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn Generate }) } -type PullProgressFunc func(PullProgress) error +type PullProgressFunc func(ProgressResponse) error func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error { return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error { - var resp PullProgress + var resp ProgressResponse if err := json.Unmarshal(bts, &resp); err != nil { return err } @@ -173,11 +173,11 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc }) } -type PushProgressFunc func(PushProgress) error +type PushProgressFunc func(ProgressResponse) error func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error { return c.stream(ctx, http.MethodPost, "/api/push", req, func(bts []byte) error { - var resp PushProgress + var resp ProgressResponse if err := json.Unmarshal(bts, &resp); err != nil { return err } diff --git a/api/types.go b/api/types.go index f9b48aa5..b14d6811 100644 --- a/api/types.go +++ b/api/types.go @@ -43,12 +43,11 @@ type PullRequest struct { Password string `json:"password"` } -type PullProgress struct { +type ProgressResponse struct { Status string `json:"status"` Digest string `json:"digest,omitempty"` Total int `json:"total,omitempty"` Completed int `json:"completed,omitempty"` - Percent float64 `json:"percent,omitempty"` } type PushRequest struct { @@ -57,14 +56,6 @@ type PushRequest struct { Password string `json:"password"` } -type PushProgress struct { - Status string `json:"status"` - Digest string `json:"digest,omitempty"` - Total int `json:"total,omitempty"` - Completed int `json:"completed,omitempty"` - Percent float64 `json:"percent,omitempty"` -} - type ListResponse struct { Models []ListResponseModel `json:"models"` } diff --git a/cmd/cmd.go b/cmd/cmd.go index 9478acc9..0552e1ba 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -83,7 +83,7 @@ func push(cmd *cobra.Command, args []string) error { client := api.NewClient() request := api.PushRequest{Name: args[0]} - fn := func(resp api.PushProgress) error { + fn := func(resp api.ProgressResponse) error { fmt.Println(resp.Status) return nil } @@ -129,25 +129,23 @@ func RunPull(cmd *cobra.Command, args []string) error { func pull(model string) error { client := api.NewClient() + var currentDigest string var bar *progressbar.ProgressBar - currentLayer := "" request := api.PullRequest{Name: model} - fn := func(resp api.PullProgress) error { - if resp.Digest != currentLayer && resp.Digest != "" { - if currentLayer != "" { - fmt.Println() - } - currentLayer = resp.Digest - layerStr := resp.Digest[7:23] + "..." + fn := func(resp api.ProgressResponse) error { + if resp.Digest != currentDigest && resp.Digest != "" { + currentDigest = resp.Digest bar = progressbar.DefaultBytes( int64(resp.Total), - "pulling "+layerStr, + fmt.Sprintf("pulling %s...", resp.Digest[7:19]), ) - } else if resp.Digest == currentLayer && resp.Digest != "" { + + bar.Set(resp.Completed) + } else if resp.Digest == currentDigest && resp.Digest != "" { bar.Set(resp.Completed) } else { - currentLayer = "" + currentDigest = "" fmt.Println(resp.Status) } return nil diff --git a/server/images.go b/server/images.go index d176ca57..3e461301 100644 --- a/server/images.go +++ b/server/images.go @@ -445,13 +445,14 @@ func CreateLayer(f io.ReadSeeker) (*LayerReader, error) { return layer, nil } -func PushModel(name, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error { +func PushModel(name, username, password string, fn func(api.ProgressResponse)) error { mp := ParseModelPath(name) - fn("retrieving manifest", "", 0, 0, 0) + fn(api.ProgressResponse{Status: "retrieving manifest"}) + manifest, err := GetManifest(mp) if err != nil { - fn("couldn't retrieve manifest", "", 0, 0, 0) + fn(api.ProgressResponse{Status: "couldn't retrieve manifest"}) return err } @@ -473,11 +474,21 @@ func PushModel(name, username, password string, fn func(status, digest string, T if exists { completed += layer.Size - fn("using existing layer", layer.Digest, total, completed, float64(completed)/float64(total)) + fn(api.ProgressResponse{ + Status: "using existing layer", + Digest: layer.Digest, + Total: total, + Completed: completed, + }) continue } - fn("starting upload", layer.Digest, total, completed, float64(completed)/float64(total)) + fn(api.ProgressResponse{ + Status: "starting upload", + Digest: layer.Digest, + Total: total, + Completed: completed, + }) location, err := startUpload(mp, username, password) if err != nil { @@ -491,10 +502,19 @@ func PushModel(name, username, password string, fn func(status, digest string, T return err } completed += layer.Size - fn("upload complete", layer.Digest, total, completed, float64(completed)/float64(total)) + fn(api.ProgressResponse{ + Status: "upload complete", + Digest: layer.Digest, + Total: total, + Completed: completed, + }) } - fn("pushing manifest", "", total, completed, float64(completed/total)) + fn(api.ProgressResponse{ + Status: "pushing manifest", + Total: total, + Completed: completed, + }) url := fmt.Sprintf("%s://%s/v2/%s/manifests/%s", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository(), mp.Tag) headers := map[string]string{ "Content-Type": "application/vnd.docker.distribution.manifest.v2+json", @@ -517,15 +537,19 @@ func PushModel(name, username, password string, fn func(status, digest string, T return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body)) } - fn("success", "", total, completed, 1.0) + fn(api.ProgressResponse{ + Status: "success", + Total: total, + Completed: completed, + }) return nil } -func PullModel(name, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error { +func PullModel(name, username, password string, fn func(api.ProgressResponse)) error { mp := ParseModelPath(name) - fn("pulling manifest", "", 0, 0, 0) + fn(api.ProgressResponse{Status: "pulling manifest"}) manifest, err := pullModelManifest(mp, username, password) if err != nil { @@ -543,16 +567,15 @@ func PullModel(name, username, password string, fn func(status, digest string, T total += manifest.Config.Size for _, layer := range layers { - fn("starting download", layer.Digest, total, completed, float64(completed)/float64(total)) if err := downloadBlob(mp, layer.Digest, username, password, fn); err != nil { - fn(fmt.Sprintf("error downloading: %v", err), layer.Digest, 0, 0, 0) + fn(api.ProgressResponse{Status: fmt.Sprintf("error downloading: %v", err), Digest: layer.Digest}) return err } + completed += layer.Size - fn("download complete", layer.Digest, total, completed, float64(completed)/float64(total)) } - fn("writing manifest", "", total, completed, 1.0) + fn(api.ProgressResponse{Status: "writing manifest"}) manifestJSON, err := json.Marshal(manifest) if err != nil { @@ -570,7 +593,7 @@ func PullModel(name, username, password string, fn func(status, digest string, T return err } - fn("success", "", total, completed, 1.0) + fn(api.ProgressResponse{Status: "success"}) return nil } @@ -722,16 +745,20 @@ func uploadBlob(location string, layer *Layer, username string, password string) return nil } -func downloadBlob(mp ModelPath, digest string, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error { +func downloadBlob(mp ModelPath, digest string, username, password string, fn func(api.ProgressResponse)) error { fp, err := GetBlobsPath(digest) if err != nil { return err } - _, err = os.Stat(fp) - if !os.IsNotExist(err) { + if fi, _ := os.Stat(fp); fi != nil { // we already have the file, so return - log.Printf("already have %s\n", digest) + fn(api.ProgressResponse{ + Digest: digest, + Total: int(fi.Size()), + Completed: int(fi.Size()), + }) + return nil } @@ -780,10 +807,21 @@ func downloadBlob(mp ModelPath, digest string, username, password string, fn fun total := remaining + completed for { - fn(fmt.Sprintf("Downloading %s", digest), digest, int(total), int(completed), float64(completed)/float64(total)) + fn(api.ProgressResponse{ + Status: fmt.Sprintf("downloading %s", digest), + Digest: digest, + Total: int(total), + Completed: int(completed), + }) + if completed >= total { if err := os.Rename(fp+"-partial", fp); err != nil { - fn(fmt.Sprintf("error renaming file: %v", err), digest, int(total), int(completed), 1) + fn(api.ProgressResponse{ + Status: fmt.Sprintf("error renaming file: %v", err), + Digest: digest, + Total: int(total), + Completed: int(completed), + }) return err } diff --git a/server/routes.go b/server/routes.go index 48060d4e..b75ae1eb 100644 --- a/server/routes.go +++ b/server/routes.go @@ -101,15 +101,10 @@ func pull(c *gin.Context) { ch := make(chan any) go func() { defer close(ch) - fn := func(status, digest string, total, completed int, percent float64) { - ch <- api.PullProgress{ - Status: status, - Digest: digest, - Total: total, - Completed: completed, - Percent: percent, - } + fn := func(r api.ProgressResponse) { + ch <- r } + if err := PullModel(req.Name, req.Username, req.Password, fn); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -129,15 +124,10 @@ func push(c *gin.Context) { ch := make(chan any) go func() { defer close(ch) - fn := func(status, digest string, total, completed int, percent float64) { - ch <- api.PushProgress{ - Status: status, - Digest: digest, - Total: total, - Completed: completed, - Percent: percent, - } + fn := func(r api.ProgressResponse) { + ch <- r } + if err := PushModel(req.Name, req.Username, req.Password, fn); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return