Merge pull request #1229 from jmorganca/mxyng/calculate-as-you-go

revert checksum calculation to calculate-as-you-go
This commit is contained in:
Michael Yang 2023-11-30 10:54:38 -08:00 committed by GitHub
commit b56e92470a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@@ -5,6 +5,7 @@ import (
"crypto/md5" "crypto/md5"
"errors" "errors"
"fmt" "fmt"
"hash"
"io" "io"
"log" "log"
"math" "math"
@@ -102,7 +103,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *Reg
} }
// set part.N to the current number of parts // set part.N to the current number of parts
b.Parts = append(b.Parts, blobUploadPart{blobUpload: b, N: len(b.Parts), Offset: offset, Size: size}) b.Parts = append(b.Parts, blobUploadPart{N: len(b.Parts), Offset: offset, Size: size})
offset += size offset += size
} }
@@ -147,14 +148,13 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
g.Go(func() error { g.Go(func() error {
var err error var err error
for try := 0; try < maxRetries; try++ { for try := 0; try < maxRetries; try++ {
err = b.uploadChunk(inner, http.MethodPatch, requestURL, part, opts) err = b.uploadPart(inner, http.MethodPatch, requestURL, part, opts)
switch { switch {
case errors.Is(err, context.Canceled): case errors.Is(err, context.Canceled):
return err return err
case errors.Is(err, errMaxRetriesExceeded): case errors.Is(err, errMaxRetriesExceeded):
return err return err
case err != nil: case err != nil:
part.Reset()
sleep := time.Second * time.Duration(math.Pow(2, float64(try))) sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
log.Printf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep) log.Printf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep)
time.Sleep(sleep) time.Sleep(sleep)
@@ -176,17 +176,10 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
requestURL := <-b.nextURL requestURL := <-b.nextURL
var sb strings.Builder
// calculate md5 checksum and add it to the commit request // calculate md5 checksum and add it to the commit request
var sb strings.Builder
for _, part := range b.Parts { for _, part := range b.Parts {
hash := md5.New() sb.Write(part.Sum(nil))
if _, err := io.Copy(hash, io.NewSectionReader(b.file, part.Offset, part.Size)); err != nil {
b.err = err
return
}
sb.Write(hash.Sum(nil))
} }
md5sum := md5.Sum([]byte(sb.String())) md5sum := md5.Sum([]byte(sb.String()))
@@ -201,27 +194,25 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
headers.Set("Content-Length", "0") headers.Set("Content-Length", "0")
for try := 0; try < maxRetries; try++ { for try := 0; try < maxRetries; try++ {
resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, nil, opts) var resp *http.Response
if err != nil { resp, err = makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, nil, opts)
b.err = err
if errors.Is(err, context.Canceled) { if errors.Is(err, context.Canceled) {
return break
} } else if err != nil {
sleep := time.Second * time.Duration(math.Pow(2, float64(try))) sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
log.Printf("%s complete upload attempt %d failed: %v, retrying in %s", b.Digest[7:19], try, err, sleep) log.Printf("%s complete upload attempt %d failed: %v, retrying in %s", b.Digest[7:19], try, err, sleep)
time.Sleep(sleep) time.Sleep(sleep)
continue continue
} }
defer resp.Body.Close() defer resp.Body.Close()
break
b.err = nil
b.done = true
return
} }
b.err = err
b.done = true
} }
func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *RegistryOptions) error { func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *RegistryOptions) error {
headers := make(http.Header) headers := make(http.Header)
headers.Set("Content-Type", "application/octet-stream") headers.Set("Content-Type", "application/octet-stream")
headers.Set("Content-Length", fmt.Sprintf("%d", part.Size)) headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
@@ -232,8 +223,13 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
} }
sr := io.NewSectionReader(b.file, part.Offset, part.Size) sr := io.NewSectionReader(b.file, part.Offset, part.Size)
resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, part), opts)
md5sum := md5.New()
w := &progressWriter{blobUpload: b}
resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, io.MultiWriter(w, md5sum)), opts)
if err != nil { if err != nil {
w.Rollback()
return err return err
} }
defer resp.Body.Close() defer resp.Body.Close()
@@ -245,11 +241,13 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
nextURL, err := url.Parse(location) nextURL, err := url.Parse(location)
if err != nil { if err != nil {
w.Rollback()
return err return err
} }
switch { switch {
case resp.StatusCode == http.StatusTemporaryRedirect: case resp.StatusCode == http.StatusTemporaryRedirect:
w.Rollback()
b.nextURL <- nextURL b.nextURL <- nextURL
redirectURL, err := resp.Location() redirectURL, err := resp.Location()
@@ -259,14 +257,13 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
// retry uploading to the redirect URL // retry uploading to the redirect URL
for try := 0; try < maxRetries; try++ { for try := 0; try < maxRetries; try++ {
err = b.uploadChunk(ctx, http.MethodPut, redirectURL, part, nil) err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, nil)
switch { switch {
case errors.Is(err, context.Canceled): case errors.Is(err, context.Canceled):
return err return err
case errors.Is(err, errMaxRetriesExceeded): case errors.Is(err, errMaxRetriesExceeded):
return err return err
case err != nil: case err != nil:
part.Reset()
sleep := time.Second * time.Duration(math.Pow(2, float64(try))) sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
log.Printf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep) log.Printf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep)
time.Sleep(sleep) time.Sleep(sleep)
@@ -279,6 +276,7 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err) return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
case resp.StatusCode == http.StatusUnauthorized: case resp.StatusCode == http.StatusUnauthorized:
w.Rollback()
auth := resp.Header.Get("www-authenticate") auth := resp.Header.Get("www-authenticate")
authRedir := ParseAuthRedirectString(auth) authRedir := ParseAuthRedirectString(auth)
token, err := getAuthToken(ctx, authRedir) token, err := getAuthToken(ctx, authRedir)
@@ -289,6 +287,7 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
opts.Token = token opts.Token = token
fallthrough fallthrough
case resp.StatusCode >= http.StatusBadRequest: case resp.StatusCode >= http.StatusBadRequest:
w.Rollback()
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return err return err
@@ -301,6 +300,7 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
b.nextURL <- nextURL b.nextURL <- nextURL
} }
part.Hash = md5sum
return nil return nil
} }
@@ -344,19 +344,23 @@ type blobUploadPart struct {
N int N int
Offset int64 Offset int64
Size int64 Size int64
hash.Hash
}
type progressWriter struct {
written int64 written int64
*blobUpload *blobUpload
} }
func (p *blobUploadPart) Write(b []byte) (n int, err error) { func (p *progressWriter) Write(b []byte) (n int, err error) {
n = len(b) n = len(b)
p.written += int64(n) p.written += int64(n)
p.Completed.Add(int64(n)) p.Completed.Add(int64(n))
return n, nil return n, nil
} }
func (p *blobUploadPart) Reset() { func (p *progressWriter) Rollback() {
p.Completed.Add(-int64(p.written)) p.Completed.Add(-p.written)
p.written = 0 p.written = 0
} }