implement ProgressWriter

This commit is contained in:
Michael Yang 2023-09-14 09:54:05 -07:00
parent ccc3e9ac6d
commit f0b398d17f

View file

@ -57,6 +57,12 @@ func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, r
// 95MB chunk size // 95MB chunk size
chunkSize := 95 * 1024 * 1024 chunkSize := 95 * 1024 * 1024
pw := ProgressWriter{
status: fmt.Sprintf("uploading %s", layer.Digest),
digest: layer.Digest,
total: layer.Size,
fn: fn,
}
for offset := int64(0); offset < int64(layer.Size); { for offset := int64(0); offset < int64(layer.Size); {
chunk := int64(layer.Size) - offset chunk := int64(layer.Size) - offset
@ -65,48 +71,16 @@ func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, r
} }
sectionReader := io.NewSectionReader(f, int64(offset), chunk) sectionReader := io.NewSectionReader(f, int64(offset), chunk)
var errStatus error
for try := 0; try < MaxRetries; try++ { for try := 0; try < MaxRetries; try++ {
ch := make(chan error, 1) errStatus = nil
r, w := io.Pipe()
defer r.Close()
go func() {
defer w.Close()
for chunked := int64(0); chunked < chunk; {
select {
case err := <-ch:
log.Printf("chunk interrupted: %v", err)
return
default:
n, err := io.CopyN(w, sectionReader, 1024*1024)
if err != nil && !errors.Is(err, io.EOF) {
fn(api.ProgressResponse{
Status: fmt.Sprintf("error reading chunk: %v", err),
Digest: layer.Digest,
Total: layer.Size,
Completed: int(offset),
})
return
}
chunked += n
fn(api.ProgressResponse{
Status: fmt.Sprintf("uploading %s", layer.Digest),
Digest: layer.Digest,
Total: layer.Size,
Completed: int(offset) + int(chunked),
})
}
}
}()
headers := make(http.Header) headers := make(http.Header)
headers.Set("Content-Type", "application/octet-stream") headers.Set("Content-Type", "application/octet-stream")
headers.Set("Content-Length", strconv.Itoa(int(chunk))) headers.Set("Content-Length", strconv.Itoa(int(chunk)))
headers.Set("Content-Range", fmt.Sprintf("%d-%d", offset, offset+sectionReader.Size()-1)) headers.Set("Content-Range", fmt.Sprintf("%d-%d", offset, offset+sectionReader.Size()-1))
resp, err := makeRequest(ctx, "PATCH", requestURL, headers, r, regOpts) resp, err := makeRequest(ctx, "PATCH", requestURL, headers, io.TeeReader(sectionReader, &pw), regOpts)
if err != nil && !errors.Is(err, io.EOF) { if err != nil && !errors.Is(err, io.EOF) {
fn(api.ProgressResponse{ fn(api.ProgressResponse{
Status: fmt.Sprintf("error uploading chunk: %v", err), Status: fmt.Sprintf("error uploading chunk: %v", err),
@ -121,7 +95,7 @@ func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, r
switch { switch {
case resp.StatusCode == http.StatusUnauthorized: case resp.StatusCode == http.StatusUnauthorized:
ch <- errors.New("unauthorized") errStatus = errors.New("unauthorized")
auth := resp.Header.Get("www-authenticate") auth := resp.Header.Get("www-authenticate")
authRedir := ParseAuthRedirectString(auth) authRedir := ParseAuthRedirectString(auth)
@ -131,7 +105,9 @@ func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, r
} }
regOpts.Token = token regOpts.Token = token
sectionReader = io.NewSectionReader(f, int64(offset), chunk)
pw.completed = int(offset)
sectionReader = io.NewSectionReader(f, offset, chunk)
continue continue
case resp.StatusCode >= http.StatusBadRequest: case resp.StatusCode >= http.StatusBadRequest:
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
@ -146,6 +122,10 @@ func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, r
break break
} }
if errStatus != nil {
return fmt.Errorf("max retries exceeded: %w", errStatus)
}
} }
values := requestURL.Query() values := requestURL.Query()
@ -170,3 +150,32 @@ func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, r
} }
return nil return nil
} }
type ProgressWriter struct {
status string
digest string
bucket int
completed int
total int
fn func(api.ProgressResponse)
}
func (pw *ProgressWriter) Write(b []byte) (int, error) {
n := len(b)
pw.bucket += n
pw.completed += n
// throttle status updates to not spam the client
if pw.bucket >= 1024*1024 || pw.completed >= pw.total {
pw.fn(api.ProgressResponse{
Status: pw.status,
Digest: pw.digest,
Total: pw.total,
Completed: pw.completed,
})
pw.bucket = 0
}
return n, nil
}