handle unexpected eofs

This commit is contained in:
Michael Yang 2023-10-02 13:34:07 -07:00
parent 5b84404c64
commit 090d08422b

View file

@ -45,8 +45,6 @@ type blobDownloadPart struct {
} }
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error { func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
b.done = make(chan struct{}, 1)
partFilePaths, err := filepath.Glob(b.Name + "-partial-*") partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
if err != nil { if err != nil {
return err return err
@ -109,6 +107,9 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis
b.Truncate(b.Total) b.Truncate(b.Total)
b.done = make(chan struct{}, 1)
defer close(b.done)
g, ctx := errgroup.WithContext(ctx) g, ctx := errgroup.WithContext(ctx)
g.SetLimit(64) g.SetLimit(64)
for i := range b.Parts { for i := range b.Parts {
@ -154,7 +155,6 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis
return err return err
} }
close(b.done)
return nil return nil
} }
@ -174,14 +174,19 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, i
defer resp.Body.Close() defer resp.Body.Close()
n, err := io.Copy(w, io.TeeReader(resp.Body, b)) n, err := io.Copy(w, io.TeeReader(resp.Body, b))
if err != nil && !errors.Is(err, io.EOF) { if err != nil && !errors.Is(err, context.Canceled) {
// rollback progress // rollback progress
b.Completed.Add(-n) b.Completed.Add(-n)
return err return err
} }
part.Completed += n part.Completed += n
return b.writePart(partName, part) if err := b.writePart(partName, part); err != nil {
return err
}
// return nil or context.Canceled
return err
} }
func (b *blobDownload) readPart(partName string) (*blobDownloadPart, error) { func (b *blobDownload) readPart(partName string) (*blobDownloadPart, error) {
@ -221,6 +226,10 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
ticker := time.NewTicker(60 * time.Millisecond) ticker := time.NewTicker(60 * time.Millisecond)
for { for {
select { select {
case <-b.done:
if b.Completed.Load() != b.Total {
return io.ErrUnexpectedEOF
}
case <-ticker.C: case <-ticker.C:
case <-ctx.Done(): case <-ctx.Done():
if b.refCount.Add(-1) == 0 { if b.refCount.Add(-1) == 0 {