From c8af3c2d969a99618eecf169bd75aa112573ac27 Mon Sep 17 00:00:00 2001
From: Blake Mizerany
Date: Thu, 25 Jul 2024 15:58:30 -0700
Subject: [PATCH] server: reuse original download URL for images (#5962)

This changes the registry client to reuse the original download URL
it gets on the first redirect response for all subsequent requests,
preventing thundering herd issues when hot new LLMs are released.
---
Reviewer note (not part of the commit message): relative to the patch as
originally posted, this revision documents newBackoff and the new
CheckRedirect field, fixes the "maxium" typo, makes the redirect limit
match its message (len(via) >= 10), and no longer fails when the registry
answers with 200 or with a redirect status other than 307. Hunk headers
and the diffstat were recomputed; the blob ids on the index lines are
those of the original patch.

 server/download.go | 91 +++++++++++++++++++++++++++++++++++++++++++++-
 server/images.go   |  8 ++++-
 2 files changed, 97 insertions(+), 2 deletions(-)

diff --git a/server/download.go b/server/download.go
index d93cd3b4..8b5b577f 100644
--- a/server/download.go
+++ b/server/download.go
@@ -8,6 +8,7 @@ import (
 	"io"
 	"log/slog"
 	"math"
+	"math/rand/v2"
 	"net/http"
 	"net/url"
 	"os"
@@ -141,6 +142,37 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *regis
 	b.err = b.run(ctx, requestURL, opts)
 }
 
+// newBackoff returns a function that blocks for a randomized, quadratically
+// growing delay each time it is called: the n-th call waits roughly
+// n*n*10ms, capped at maxBackoff and then scaled by a random factor in
+// [0.5, 1.5). The returned function yields ctx.Err() if ctx is already done
+// or is canceled while waiting, and nil once the delay has elapsed.
+func newBackoff(maxBackoff time.Duration) func(ctx context.Context) error {
+	var n int
+	return func(ctx context.Context) error {
+		if ctx.Err() != nil {
+			return ctx.Err()
+		}
+
+		n++
+
+		// n^2 backoff timer is a little smoother than the
+		// common choice of 2^n.
+		d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
+		// Scale the delay by a random factor in [0.5, 1.5), in order
+		// to prevent accidental "thundering herd" problems.
+		d = time.Duration(float64(d) * (rand.Float64() + 0.5))
+		t := time.NewTimer(d)
+		defer t.Stop()
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-t.C:
+			return nil
+		}
+	}
+}
+
 func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
 	defer blobDownloadManager.Delete(b.Digest)
 	ctx, b.CancelFunc = context.WithCancel(ctx)
@@ -153,6 +185,63 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 
 	_ = file.Truncate(b.Total)
 
+	// Resolve the URL the registry redirects blob downloads to once, so
+	// every chunk request below can go straight to it instead of asking
+	// the registry for a redirect per chunk.
+	directURL, err := func() (*url.URL, error) {
+		ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
+		defer cancel()
+
+		backoff := newBackoff(10 * time.Second)
+		for {
+			// shallow clone opts to be used in the closure
+			// without affecting the outer opts.
+			newOpts := new(registryOptions)
+			*newOpts = *opts
+
+			newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
+				if len(via) >= 10 {
+					return errors.New("maximum redirects exceeded (10) for directURL")
+				}
+
+				// if the hostname is the same, allow the redirect
+				if req.URL.Hostname() == requestURL.Hostname() {
+					return nil
+				}
+
+				// stop at the first redirect that is not
+				// the same hostname as the original
+				// request.
+				return http.ErrUseLastResponse
+			}
+
+			resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, nil, nil, newOpts)
+			if err != nil {
+				slog.Warn("failed to get direct URL; backing off and retrying", "err", err)
+				if err := backoff(ctx); err != nil {
+					return nil, err
+				}
+				continue
+			}
+			defer resp.Body.Close()
+
+			switch resp.StatusCode {
+			case http.StatusOK:
+				// No cross-host redirect was issued, so the
+				// original URL already serves the blob.
+				return requestURL, nil
+			case http.StatusMovedPermanently, http.StatusFound, http.StatusSeeOther,
+				http.StatusTemporaryRedirect, http.StatusPermanentRedirect:
+				return resp.Location()
+			default:
+				return nil, fmt.Errorf("unexpected status code %d", resp.StatusCode)
+			}
+		}
+	}()
+	if err != nil {
+		return err
+	}
+
 	g, inner := errgroup.WithContext(ctx)
 	g.SetLimit(numDownloadParts)
 	for i := range b.Parts {
@@ -165,7 +254,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 			var err error
 			for try := 0; try < maxRetries; try++ {
 				w := io.NewOffsetWriter(file, part.StartsAt())
-				err = b.downloadChunk(inner, requestURL, w, part, opts)
+				err = b.downloadChunk(inner, directURL, w, part, opts)
 				switch {
 				case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
 					// return immediately if the context is canceled or the device is out of space
diff --git a/server/images.go b/server/images.go
index 574dec19..836dbcc2 100644
--- a/server/images.go
+++ b/server/images.go
@@ -54,6 +54,10 @@ type registryOptions struct {
 	Username string
 	Password string
 	Token    string
+
+	// CheckRedirect, if non-nil, is installed as the redirect policy of
+	// the http.Client that makeRequest uses for the request.
+	CheckRedirect func(req *http.Request, via []*http.Request) error
 }
 
 type Model struct {
@@ -1131,7 +1135,9 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header
 		req.ContentLength = contentLength
 	}
 
-	resp, err := http.DefaultClient.Do(req)
+	resp, err := (&http.Client{
+		CheckRedirect: regOpts.CheckRedirect,
+	}).Do(req)
 	if err != nil {
 		return nil, err
 	}