replace done channel with file check
This commit is contained in:
parent
288814d3e4
commit
10199c5987
1 changed files with 24 additions and 29 deletions
|
@ -31,10 +31,8 @@ type blobDownload struct {
|
||||||
Total int64
|
Total int64
|
||||||
Completed atomic.Int64
|
Completed atomic.Int64
|
||||||
|
|
||||||
*os.File
|
|
||||||
Parts []*blobDownloadPart
|
Parts []*blobDownloadPart
|
||||||
|
|
||||||
done chan struct{}
|
|
||||||
context.CancelFunc
|
context.CancelFunc
|
||||||
references atomic.Int32
|
references atomic.Int32
|
||||||
}
|
}
|
||||||
|
@ -54,6 +52,14 @@ func (p *blobDownloadPart) Name() string {
|
||||||
}, "-")
|
}, "-")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *blobDownloadPart) StartsAt() int64 {
|
||||||
|
return p.Offset + p.Completed
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *blobDownloadPart) StopsAt() int64 {
|
||||||
|
return p.Offset + p.Size
|
||||||
|
}
|
||||||
|
|
||||||
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
|
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
|
||||||
partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
|
partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -110,18 +116,16 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis
|
||||||
|
|
||||||
ctx, b.CancelFunc = context.WithCancel(ctx)
|
ctx, b.CancelFunc = context.WithCancel(ctx)
|
||||||
|
|
||||||
b.File, err = os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0644)
|
file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer b.Close()
|
defer file.Close()
|
||||||
|
|
||||||
b.Truncate(b.Total)
|
file.Truncate(b.Total)
|
||||||
|
|
||||||
b.done = make(chan struct{}, 1)
|
|
||||||
defer close(b.done)
|
|
||||||
|
|
||||||
g, ctx := errgroup.WithContext(ctx)
|
g, ctx := errgroup.WithContext(ctx)
|
||||||
|
// TODO(mxyng): download concurrency should be configurable
|
||||||
g.SetLimit(64)
|
g.SetLimit(64)
|
||||||
for i := range b.Parts {
|
for i := range b.Parts {
|
||||||
part := b.Parts[i]
|
part := b.Parts[i]
|
||||||
|
@ -132,7 +136,8 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis
|
||||||
i := i
|
i := i
|
||||||
g.Go(func() error {
|
g.Go(func() error {
|
||||||
for try := 0; try < maxRetries; try++ {
|
for try := 0; try < maxRetries; try++ {
|
||||||
err := b.downloadChunk(ctx, requestURL, i, opts)
|
w := io.NewOffsetWriter(file, part.StartsAt())
|
||||||
|
err := b.downloadChunk(ctx, requestURL, w, part, opts)
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, context.Canceled):
|
case errors.Is(err, context.Canceled):
|
||||||
return err
|
return err
|
||||||
|
@ -152,31 +157,23 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := b.Close(); err != nil {
|
// explicitly close the file so we can rename it
|
||||||
|
if err := file.Close(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range b.Parts {
|
for i := range b.Parts {
|
||||||
if err := os.Remove(b.File.Name() + "-" + strconv.Itoa(i)); err != nil {
|
if err := os.Remove(file.Name() + "-" + strconv.Itoa(i)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := os.Rename(b.File.Name(), b.Name); err != nil {
|
return os.Rename(file.Name(), b.Name)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
|
||||||
}
|
|
||||||
|
|
||||||
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, i int, opts *RegistryOptions) error {
|
|
||||||
part := b.Parts[i]
|
|
||||||
|
|
||||||
offset := part.Offset + part.Completed
|
|
||||||
w := io.NewOffsetWriter(b.File, offset)
|
|
||||||
|
|
||||||
headers := make(http.Header)
|
headers := make(http.Header)
|
||||||
headers.Set("Range", fmt.Sprintf("bytes=%d-%d", offset, part.Offset+part.Size-1))
|
headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
|
||||||
resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, opts)
|
resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -258,10 +255,6 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
|
||||||
ticker := time.NewTicker(60 * time.Millisecond)
|
ticker := time.NewTicker(60 * time.Millisecond)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-b.done:
|
|
||||||
if b.Completed.Load() != b.Total {
|
|
||||||
return io.ErrUnexpectedEOF
|
|
||||||
}
|
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
|
@ -275,11 +268,13 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
|
||||||
})
|
})
|
||||||
|
|
||||||
if b.Completed.Load() >= b.Total {
|
if b.Completed.Load() >= b.Total {
|
||||||
<-b.done
|
// wait for the file to get renamed
|
||||||
|
if _, err := os.Stat(b.Name); err == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type downloadOpts struct {
|
type downloadOpts struct {
|
||||||
mp ModelPath
|
mp ModelPath
|
||||||
|
|
Loading…
Reference in a new issue