diff --git a/server/download.go b/server/download.go index 72aa0d77..246d300d 100644 --- a/server/download.go +++ b/server/download.go @@ -25,17 +25,27 @@ type FileDownload struct { var inProgress sync.Map // map of digests currently being downloaded to their current download progress +type downloadOpts struct { + mp ModelPath + digest string + regOpts *RegistryOptions + fn func(api.ProgressResponse) + retry int // track the number of retries on this download +} + +const maxRetry = 3 + // downloadBlob downloads a blob from the registry and stores it in the blobs directory -func downloadBlob(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { - fp, err := GetBlobsPath(digest) +func downloadBlob(ctx context.Context, opts downloadOpts) error { + fp, err := GetBlobsPath(opts.digest) if err != nil { return err } if fi, _ := os.Stat(fp); fi != nil { // we already have the file, so return - fn(api.ProgressResponse{ - Digest: digest, + opts.fn(api.ProgressResponse{ + Digest: opts.digest, Total: int(fi.Size()), Completed: int(fi.Size()), }) @@ -44,24 +54,33 @@ func downloadBlob(ctx context.Context, mp ModelPath, digest string, regOpts *Reg } fileDownload := &FileDownload{ - Digest: digest, + Digest: opts.digest, FilePath: fp, Total: 1, // dummy value to indicate that we don't know the total size yet Completed: 0, } - _, downloading := inProgress.LoadOrStore(digest, fileDownload) + _, downloading := inProgress.LoadOrStore(opts.digest, fileDownload) if downloading { // this is another client requesting the server to download the same blob concurrently - return monitorDownload(ctx, mp, regOpts, fileDownload, fn) + return monitorDownload(ctx, opts, fileDownload) } - return doDownload(ctx, mp, regOpts, fileDownload, fn) + if err := doDownload(ctx, opts, fileDownload); err != nil { + if errors.Is(err, errDownload) && opts.retry < maxRetry { + opts.retry++ + log.Print(err) + log.Printf("retrying download of %s", opts.digest) + return downloadBlob(ctx, opts) + } + return err + } + return nil } var downloadMu sync.Mutex // mutex to check to resume a download while monitoring // monitorDownload monitors the download progress of a blob and resumes it if it is interrupted -func monitorDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload, fn func(api.ProgressResponse)) error { +func monitorDownload(ctx context.Context, opts downloadOpts, f *FileDownload) error { tick := time.NewTicker(time.Second) for range tick.C { done, resume, err := func() (bool, bool, error) { @@ -72,7 +91,7 @@ func monitorDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions // check once again if the download is complete if fi, _ := os.Stat(f.FilePath); fi != nil { // successful download while monitoring - fn(api.ProgressResponse{ + opts.fn(api.ProgressResponse{ Digest: f.Digest, Total: int(fi.Size()), Completed: int(fi.Size()), @@ -87,7 +106,7 @@ func monitorDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions if !ok { return false, false, fmt.Errorf("invalid type for in progress download: %T", val) } - fn(api.ProgressResponse{ + opts.fn(api.ProgressResponse{ Status: fmt.Sprintf("downloading %s", f.Digest), Digest: f.Digest, Total: int(f.Total), @@ -103,16 +122,19 @@ func monitorDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions return nil } if resume { - return doDownload(ctx, mp, regOpts, f, fn) + return doDownload(ctx, opts, f) } } return nil } -var chunkSize = 1024 * 1024 // 1 MiB in bytes +var ( + chunkSize = 1024 * 1024 // 1 MiB in bytes + errDownload = fmt.Errorf("download failed") +) // doDownload downloads a blob from the registry and stores it in the blobs directory -func doDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload, fn func(api.ProgressResponse)) error { +func doDownload(ctx context.Context, opts downloadOpts, f *FileDownload) error { defer inProgress.Delete(f.Digest) var size int64 @@ -133,21 +155,21 @@ func doDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f * } } - url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), f.Digest) + url := fmt.Sprintf("%s/v2/%s/blobs/%s", opts.mp.Registry, opts.mp.GetNamespaceRepository(), f.Digest) headers := map[string]string{ "Range": fmt.Sprintf("bytes=%d-", size), } - resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts) + resp, err := makeRequest(ctx, "GET", url, headers, nil, opts.regOpts) if err != nil { log.Printf("couldn't download blob: %v", err) - return err + return fmt.Errorf("%w: %w", errDownload, err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent { body, _ := io.ReadAll(resp.Body) - return fmt.Errorf("on download registry responded with code %d: %v", resp.StatusCode, string(body)) + return fmt.Errorf("%w: on download registry responded with code %d: %v", errDownload, resp.StatusCode, string(body)) } err = os.MkdirAll(path.Dir(f.FilePath), 0o700) @@ -174,7 +196,7 @@ outerLoop: inProgress.Delete(f.Digest) return nil default: - fn(api.ProgressResponse{ + opts.fn(api.ProgressResponse{ Status: fmt.Sprintf("downloading %s", f.Digest), Digest: f.Digest, Total: int(f.Total), @@ -187,7 +209,7 @@ outerLoop: } if err := os.Rename(f.FilePath+"-partial", f.FilePath); err != nil { - fn(api.ProgressResponse{ + opts.fn(api.ProgressResponse{ Status: fmt.Sprintf("error renaming file: %v", err), Digest: f.Digest, Total: int(f.Total), @@ -202,7 +224,7 @@ outerLoop: n, err := io.CopyN(out, resp.Body, int64(chunkSize)) if err != nil && !errors.Is(err, io.EOF) { - return err + return fmt.Errorf("%w: %w", errDownload, err) } f.Completed += n diff --git a/server/images.go b/server/images.go index b21451d9..73efc1c5 100644 --- a/server/images.go +++ b/server/images.go @@ -995,7 +995,14 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu layers = append(layers, &manifest.Config) for _, layer := range layers { - if err := downloadBlob(ctx, mp, layer.Digest, regOpts, fn); err != nil { + if err := downloadBlob( + ctx, + downloadOpts{ + mp: mp, + digest: layer.Digest, + regOpts: regOpts, + fn: fn, + }); err != nil { return err } }