From f0b398d17f4c907ad268e0baf9c8cba0188000cb Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 14 Sep 2023 09:54:05 -0700 Subject: [PATCH] implement ProgressWriter --- server/upload.go | 85 ++++++++++++++++++++++++++---------------------- 1 file changed, 47 insertions(+), 38 deletions(-) diff --git a/server/upload.go b/server/upload.go index a977581f..ea994c22 100644 --- a/server/upload.go +++ b/server/upload.go @@ -57,6 +57,12 @@ func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, r // 95MB chunk size chunkSize := 95 * 1024 * 1024 + pw := ProgressWriter{ + status: fmt.Sprintf("uploading %s", layer.Digest), + digest: layer.Digest, + total: layer.Size, + fn: fn, + } for offset := int64(0); offset < int64(layer.Size); { chunk := int64(layer.Size) - offset @@ -65,48 +71,16 @@ func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, r } sectionReader := io.NewSectionReader(f, int64(offset), chunk) + + var errStatus error for try := 0; try < MaxRetries; try++ { - ch := make(chan error, 1) - - r, w := io.Pipe() - defer r.Close() - go func() { - defer w.Close() - - for chunked := int64(0); chunked < chunk; { - select { - case err := <-ch: - log.Printf("chunk interrupted: %v", err) - return - default: - n, err := io.CopyN(w, sectionReader, 1024*1024) - if err != nil && !errors.Is(err, io.EOF) { - fn(api.ProgressResponse{ - Status: fmt.Sprintf("error reading chunk: %v", err), - Digest: layer.Digest, - Total: layer.Size, - Completed: int(offset), - }) - - return - } - - chunked += n - fn(api.ProgressResponse{ - Status: fmt.Sprintf("uploading %s", layer.Digest), - Digest: layer.Digest, - Total: layer.Size, - Completed: int(offset) + int(chunked), - }) - } - } - }() + errStatus = nil headers := make(http.Header) headers.Set("Content-Type", "application/octet-stream") headers.Set("Content-Length", strconv.Itoa(int(chunk))) headers.Set("Content-Range", fmt.Sprintf("%d-%d", offset, offset+sectionReader.Size()-1)) - resp, err := makeRequest(ctx, "PATCH", requestURL, headers, r, regOpts) + resp, err := makeRequest(ctx, "PATCH", requestURL, headers, io.TeeReader(sectionReader, &pw), regOpts) if err != nil && !errors.Is(err, io.EOF) { fn(api.ProgressResponse{ Status: fmt.Sprintf("error uploading chunk: %v", err), @@ -121,7 +95,7 @@ func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, r switch { case resp.StatusCode == http.StatusUnauthorized: - ch <- errors.New("unauthorized") + errStatus = errors.New("unauthorized") auth := resp.Header.Get("www-authenticate") authRedir := ParseAuthRedirectString(auth) @@ -131,7 +105,9 @@ func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, r } regOpts.Token = token - sectionReader = io.NewSectionReader(f, int64(offset), chunk) + + pw.completed = int(offset) + sectionReader = io.NewSectionReader(f, offset, chunk) continue case resp.StatusCode >= http.StatusBadRequest: body, _ := io.ReadAll(resp.Body) @@ -146,6 +122,10 @@ func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, r break } + + if errStatus != nil { + return fmt.Errorf("max retries exceeded: %w", errStatus) + } } values := requestURL.Query() @@ -170,3 +150,32 @@ func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, r } return nil } + +type ProgressWriter struct { + status string + digest string + bucket int + completed int + total int + fn func(api.ProgressResponse) +} + +func (pw *ProgressWriter) Write(b []byte) (int, error) { + n := len(b) + pw.bucket += n + pw.completed += n + + // throttle status updates to not spam the client + if pw.bucket >= 1024*1024 || pw.completed >= pw.total { + pw.fn(api.ProgressResponse{ + Status: pw.status, + Digest: pw.digest, + Total: pw.total, + Completed: pw.completed, + }) + + pw.bucket = 0 + } + + return n, nil +}