Merge pull request #110 from jmorganca/fix-pull-0-bytes

fix pull 0 bytes on completed layer
This commit is contained in:
Michael Yang 2023-07-18 19:38:59 -07:00 committed by GitHub
commit a6d03dd510
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 80 additions and 63 deletions

View file

@ -160,11 +160,11 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn Generate
})
}
type PullProgressFunc func(PullProgress) error
type PullProgressFunc func(ProgressResponse) error
func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error {
var resp PullProgress
var resp ProgressResponse
if err := json.Unmarshal(bts, &resp); err != nil {
return err
}
@ -173,11 +173,11 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc
})
}
type PushProgressFunc func(PushProgress) error
type PushProgressFunc func(ProgressResponse) error
func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error {
return c.stream(ctx, http.MethodPost, "/api/push", req, func(bts []byte) error {
var resp PushProgress
var resp ProgressResponse
if err := json.Unmarshal(bts, &resp); err != nil {
return err
}

View file

@ -43,12 +43,11 @@ type PullRequest struct {
Password string `json:"password"`
}
type PullProgress struct {
type ProgressResponse struct {
Status string `json:"status"`
Digest string `json:"digest,omitempty"`
Total int `json:"total,omitempty"`
Completed int `json:"completed,omitempty"`
Percent float64 `json:"percent,omitempty"`
}
type PushRequest struct {
@ -57,14 +56,6 @@ type PushRequest struct {
Password string `json:"password"`
}
type PushProgress struct {
Status string `json:"status"`
Digest string `json:"digest,omitempty"`
Total int `json:"total,omitempty"`
Completed int `json:"completed,omitempty"`
Percent float64 `json:"percent,omitempty"`
}
type ListResponse struct {
Models []ListResponseModel `json:"models"`
}

View file

@ -89,7 +89,7 @@ func push(cmd *cobra.Command, args []string) error {
client := api.NewClient()
request := api.PushRequest{Name: args[0]}
fn := func(resp api.PushProgress) error {
fn := func(resp api.ProgressResponse) error {
fmt.Println(resp.Status)
return nil
}
@ -135,25 +135,23 @@ func RunPull(cmd *cobra.Command, args []string) error {
func pull(model string) error {
client := api.NewClient()
var currentDigest string
var bar *progressbar.ProgressBar
currentLayer := ""
request := api.PullRequest{Name: model}
fn := func(resp api.PullProgress) error {
if resp.Digest != currentLayer && resp.Digest != "" {
if currentLayer != "" {
fmt.Println()
}
currentLayer = resp.Digest
layerStr := resp.Digest[7:23] + "..."
fn := func(resp api.ProgressResponse) error {
if resp.Digest != currentDigest && resp.Digest != "" {
currentDigest = resp.Digest
bar = progressbar.DefaultBytes(
int64(resp.Total),
"pulling "+layerStr,
fmt.Sprintf("pulling %s...", resp.Digest[7:19]),
)
} else if resp.Digest == currentLayer && resp.Digest != "" {
bar.Set(resp.Completed)
} else if resp.Digest == currentDigest && resp.Digest != "" {
bar.Set(resp.Completed)
} else {
currentLayer = ""
currentDigest = ""
fmt.Println(resp.Status)
}
return nil

View file

@ -445,13 +445,14 @@ func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
return layer, nil
}
func PushModel(name, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error {
func PushModel(name, username, password string, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
fn("retrieving manifest", "", 0, 0, 0)
fn(api.ProgressResponse{Status: "retrieving manifest"})
manifest, err := GetManifest(mp)
if err != nil {
fn("couldn't retrieve manifest", "", 0, 0, 0)
fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
return err
}
@ -473,11 +474,21 @@ func PushModel(name, username, password string, fn func(status, digest string, T
if exists {
completed += layer.Size
fn("using existing layer", layer.Digest, total, completed, float64(completed)/float64(total))
fn(api.ProgressResponse{
Status: "using existing layer",
Digest: layer.Digest,
Total: total,
Completed: completed,
})
continue
}
fn("starting upload", layer.Digest, total, completed, float64(completed)/float64(total))
fn(api.ProgressResponse{
Status: "starting upload",
Digest: layer.Digest,
Total: total,
Completed: completed,
})
location, err := startUpload(mp, username, password)
if err != nil {
@ -491,10 +502,19 @@ func PushModel(name, username, password string, fn func(status, digest string, T
return err
}
completed += layer.Size
fn("upload complete", layer.Digest, total, completed, float64(completed)/float64(total))
fn(api.ProgressResponse{
Status: "upload complete",
Digest: layer.Digest,
Total: total,
Completed: completed,
})
}
fn("pushing manifest", "", total, completed, float64(completed/total))
fn(api.ProgressResponse{
Status: "pushing manifest",
Total: total,
Completed: completed,
})
url := fmt.Sprintf("%s://%s/v2/%s/manifests/%s", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
headers := map[string]string{
"Content-Type": "application/vnd.docker.distribution.manifest.v2+json",
@ -517,15 +537,19 @@ func PushModel(name, username, password string, fn func(status, digest string, T
return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body))
}
fn("success", "", total, completed, 1.0)
fn(api.ProgressResponse{
Status: "success",
Total: total,
Completed: completed,
})
return nil
}
func PullModel(name, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error {
func PullModel(name, username, password string, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
fn("pulling manifest", "", 0, 0, 0)
fn(api.ProgressResponse{Status: "pulling manifest"})
manifest, err := pullModelManifest(mp, username, password)
if err != nil {
@ -543,16 +567,15 @@ func PullModel(name, username, password string, fn func(status, digest string, T
total += manifest.Config.Size
for _, layer := range layers {
fn("starting download", layer.Digest, total, completed, float64(completed)/float64(total))
if err := downloadBlob(mp, layer.Digest, username, password, fn); err != nil {
fn(fmt.Sprintf("error downloading: %v", err), layer.Digest, 0, 0, 0)
fn(api.ProgressResponse{Status: fmt.Sprintf("error downloading: %v", err), Digest: layer.Digest})
return err
}
completed += layer.Size
fn("download complete", layer.Digest, total, completed, float64(completed)/float64(total))
}
fn("writing manifest", "", total, completed, 1.0)
fn(api.ProgressResponse{Status: "writing manifest"})
manifestJSON, err := json.Marshal(manifest)
if err != nil {
@ -570,7 +593,7 @@ func PullModel(name, username, password string, fn func(status, digest string, T
return err
}
fn("success", "", total, completed, 1.0)
fn(api.ProgressResponse{Status: "success"})
return nil
}
@ -722,16 +745,20 @@ func uploadBlob(location string, layer *Layer, username string, password string)
return nil
}
func downloadBlob(mp ModelPath, digest string, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error {
func downloadBlob(mp ModelPath, digest string, username, password string, fn func(api.ProgressResponse)) error {
fp, err := GetBlobsPath(digest)
if err != nil {
return err
}
_, err = os.Stat(fp)
if !os.IsNotExist(err) {
if fi, _ := os.Stat(fp); fi != nil {
// we already have the file, so return
log.Printf("already have %s\n", digest)
fn(api.ProgressResponse{
Digest: digest,
Total: int(fi.Size()),
Completed: int(fi.Size()),
})
return nil
}
@ -780,10 +807,21 @@ func downloadBlob(mp ModelPath, digest string, username, password string, fn fun
total := remaining + completed
for {
fn(fmt.Sprintf("Downloading %s", digest), digest, int(total), int(completed), float64(completed)/float64(total))
fn(api.ProgressResponse{
Status: fmt.Sprintf("downloading %s", digest),
Digest: digest,
Total: int(total),
Completed: int(completed),
})
if completed >= total {
if err := os.Rename(fp+"-partial", fp); err != nil {
fn(fmt.Sprintf("error renaming file: %v", err), digest, int(total), int(completed), 1)
fn(api.ProgressResponse{
Status: fmt.Sprintf("error renaming file: %v", err),
Digest: digest,
Total: int(total),
Completed: int(completed),
})
return err
}

View file

@ -101,15 +101,10 @@ func pull(c *gin.Context) {
ch := make(chan any)
go func() {
defer close(ch)
fn := func(status, digest string, total, completed int, percent float64) {
ch <- api.PullProgress{
Status: status,
Digest: digest,
Total: total,
Completed: completed,
Percent: percent,
}
fn := func(r api.ProgressResponse) {
ch <- r
}
if err := PullModel(req.Name, req.Username, req.Password, fn); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@ -129,15 +124,10 @@ func push(c *gin.Context) {
ch := make(chan any)
go func() {
defer close(ch)
fn := func(status, digest string, total, completed int, percent float64) {
ch <- api.PushProgress{
Status: status,
Digest: digest,
Total: total,
Completed: completed,
Percent: percent,
}
fn := func(r api.ProgressResponse) {
ch <- r
}
if err := PushModel(req.Name, req.Username, req.Password, fn); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return