From 8228d166cef443cd23bad141537101c5acf20080 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Tue, 1 Aug 2023 15:34:52 -0400 Subject: [PATCH] pr comments --- server/download.go | 95 ++++++++++++++++++++++------------------------ 1 file changed, 46 insertions(+), 49 deletions(-) diff --git a/server/download.go b/server/download.go index d7fd0006..7aad599d 100644 --- a/server/download.go +++ b/server/download.go @@ -55,11 +55,7 @@ func downloadBlob(ctx context.Context, mp ModelPath, digest string, regOpts *Reg // this is another client requesting the server to download the same blob concurrently return monitorDownload(ctx, mp, regOpts, fileDownload, fn) } - resp, err := requestDownload(ctx, mp, regOpts, fileDownload) - if err != nil { - return err - } - return doDownload(ctx, fileDownload, resp, fn) + return doDownload(ctx, mp, regOpts, fileDownload, fn) } var downloadMu sync.Mutex // mutex to check to resume a download while monitoring @@ -68,49 +64,55 @@ var downloadMu sync.Mutex // mutex to check to resume a download while monitorin func monitorDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload, fn func(api.ProgressResponse)) error { tick := time.NewTicker(time.Second) for range tick.C { - downloadMu.Lock() - val, downloading := inProgress.Load(f.Digest) - if !downloading { - // check once again if the download is complete - if fi, _ := os.Stat(f.FilePath); fi != nil { - downloadMu.Unlock() - // successfull download while monitoring - fn(api.ProgressResponse{ - Digest: f.Digest, - Total: int(fi.Size()), - Completed: int(fi.Size()), - }) - return nil + done, resume, err := func() (bool, bool, error) { + downloadMu.Lock() + defer downloadMu.Unlock() + val, downloading := inProgress.Load(f.Digest) + if !downloading { + // check once again if the download is complete + if fi, _ := os.Stat(f.FilePath); fi != nil { + // successful download while monitoring + fn(api.ProgressResponse{ + Digest: f.Digest, + Total: int(fi.Size()), + Completed: int(fi.Size()), + }) + return true, false, nil + } + // resume the download + inProgress.Store(f.Digest, f) // store the file download again to claim the resume + return false, true, nil } - // resume the download - resp, err := requestDownload(ctx, mp, regOpts, f) - if err != nil { - downloadMu.Unlock() - return fmt.Errorf("resume: %w", err) + f, ok := val.(*FileDownload) + if !ok { + return false, false, fmt.Errorf("invalid type for in progress download: %T", val) } - inProgress.Store(f.Digest, f) - downloadMu.Unlock() - return doDownload(ctx, f, resp, fn) + fn(api.ProgressResponse{ + Status: fmt.Sprintf("downloading %s", f.Digest), + Digest: f.Digest, + Total: int(f.Total), + Completed: int(f.Completed), + }) + return false, false, nil + }() + if err != nil { + return err } - downloadMu.Unlock() - f, ok := val.(*FileDownload) - if !ok { - return fmt.Errorf("invalid type for in progress download: %T", val) + if done { + // done downloading + return nil + } + if resume { + return doDownload(ctx, mp, regOpts, f, fn) } - fn(api.ProgressResponse{ - Status: fmt.Sprintf("downloading %s", f.Digest), - Digest: f.Digest, - Total: int(f.Total), - Completed: int(f.Completed), - }) } return nil } var chunkSize = 1024 * 1024 // 1 MiB in bytes -// requestDownload requests a blob from the registry and returns the response -func requestDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload) (*http.Response, error) { +// doDownload downloads a blob from the registry and stores it in the blobs directory +func doDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload, fn func(api.ProgressResponse)) error { var size int64 fi, err := os.Stat(f.FilePath + "-partial") @@ -118,7 +120,7 @@ func requestDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions case errors.Is(err, os.ErrNotExist): // noop, file doesn't exist so create it case err != nil: - return nil, fmt.Errorf("stat: %w", err) + return fmt.Errorf("stat: %w", err) default: size = fi.Size() // Ensure the size is divisible by the chunk size by removing excess bytes @@ -126,7 +128,7 @@ func requestDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions err := os.Truncate(f.FilePath+"-partial", size) if err != nil { - return nil, fmt.Errorf("truncate: %w", err) + return fmt.Errorf("truncate: %w", err) } } @@ -138,18 +140,18 @@ func requestDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions resp, err := makeRequest("GET", url, headers, nil, regOpts) if err != nil { log.Printf("couldn't download blob: %v", err) - return nil, err + return err } - // resp MUST be closed by doDownload, which should follow this function + defer resp.Body.Close() if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent { body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("on download registry responded with code %d: %v", resp.StatusCode, string(body)) + return fmt.Errorf("on download registry responded with code %d: %v", resp.StatusCode, string(body)) } err = os.MkdirAll(path.Dir(f.FilePath), 0o700) if err != nil { - return nil, fmt.Errorf("make blobs directory: %w", err) + return fmt.Errorf("make blobs directory: %w", err) } remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) @@ -157,12 +159,7 @@ func requestDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions f.Total = remaining + f.Completed inProgress.Store(f.Digest, f) - return resp, nil -} -// doDownload downloads a blob from the registry and stores it in the blobs directory -func doDownload(ctx context.Context, f *FileDownload, resp *http.Response, fn func(api.ProgressResponse)) error { - defer resp.Body.Close() out, err := os.OpenFile(f.FilePath+"-partial", os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) if err != nil { return fmt.Errorf("open file: %w", err)