diff --git a/server/download.go b/server/download.go
index d93cd3b4..8b5b577f 100644
--- a/server/download.go
+++ b/server/download.go
@@ -8,6 +8,7 @@ import (
 	"io"
 	"log/slog"
 	"math"
+	"math/rand/v2"
 	"net/http"
 	"net/url"
 	"os"
@@ -141,6 +142,41 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *regis
 	b.err = b.run(ctx, requestURL, opts)
 }
 
+// newBackoff returns a function that sleeps for a jittered, quadratically
+// growing delay each time it is called. The un-jittered delay is
+// n*n*10ms on the n-th call, capped at maxBackoff; jitter then scales it
+// by a factor in [0.5, 1.5), so a single wait can last up to
+// 1.5*maxBackoff. The returned function reports ctx.Err() if ctx is
+// already done or becomes done while waiting, and nil otherwise.
+//
+// The returned function keeps its attempt counter in a closure and is
+// not safe for concurrent use.
+func newBackoff(maxBackoff time.Duration) func(ctx context.Context) error {
+	var n int
+	return func(ctx context.Context) error {
+		if err := ctx.Err(); err != nil {
+			return err
+		}
+
+		n++
+
+		// n^2 backoff timer is a little smoother than the
+		// common choice of 2^n.
+		d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
+		// Randomize the delay to between 0.5x and 1.5x of d, in
+		// order to prevent accidental "thundering herd" problems.
+		d = time.Duration(float64(d) * (rand.Float64() + 0.5))
+		t := time.NewTimer(d)
+		defer t.Stop()
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-t.C:
+			return nil
+		}
+	}
+}
+
 func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
 	defer blobDownloadManager.Delete(b.Digest)
 	ctx, b.CancelFunc = context.WithCancel(ctx)
@@ -153,6 +189,52 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 
 	_ = file.Truncate(b.Total)
 
+	directURL, err := func() (*url.URL, error) {
+		ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
+		defer cancel()
+
+		backoff := newBackoff(10 * time.Second)
+		for {
+			// shallow clone opts to be used in the closure
+			// without affecting the outer opts.
+			newOpts := new(registryOptions)
+			*newOpts = *opts
+
+			newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
+				if len(via) > 10 {
+					return errors.New("maximum redirects exceeded (10) for directURL")
+				}
+
+				// if the hostname is the same, allow the redirect
+				if req.URL.Hostname() == requestURL.Hostname() {
+					return nil
+				}
+
+				// stop at the first redirect that is not
+				// the same hostname as the original
+				// request.
+				return http.ErrUseLastResponse
+			}
+
+			resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, nil, nil, newOpts)
+			if err != nil {
+				slog.Warn("failed to get direct URL; backing off and retrying", "err", err)
+				if err := backoff(ctx); err != nil {
+					return nil, err
+				}
+				continue
+			}
+			defer resp.Body.Close()
+			if resp.StatusCode != http.StatusTemporaryRedirect {
+				return nil, fmt.Errorf("unexpected status code %d", resp.StatusCode)
+			}
+			return resp.Location()
+		}
+	}()
+	if err != nil {
+		return err
+	}
+
 	g, inner := errgroup.WithContext(ctx)
 	g.SetLimit(numDownloadParts)
 	for i := range b.Parts {
@@ -165,7 +247,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 			var err error
 			for try := 0; try < maxRetries; try++ {
 				w := io.NewOffsetWriter(file, part.StartsAt())
-				err = b.downloadChunk(inner, requestURL, w, part, opts)
+				err = b.downloadChunk(inner, directURL, w, part, opts)
 				switch {
 				case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
 					// return immediately if the context is canceled or the device is out of space
diff --git a/server/images.go b/server/images.go
index 574dec19..836dbcc2 100644
--- a/server/images.go
+++ b/server/images.go
@@ -54,6 +54,8 @@ type registryOptions struct {
 	Username string
 	Password string
 	Token    string
+
+	CheckRedirect func(req *http.Request, via []*http.Request) error
 }
 
 type Model struct {
@@ -1131,7 +1133,9 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header
 		req.ContentLength = contentLength
 	}
 
-	resp, err := http.DefaultClient.Do(req)
+	resp, err := (&http.Client{
+		CheckRedirect: regOpts.CheckRedirect,
+	}).Do(req)
 	if err != nil {
 		return nil, err
 	}