server: fix race conditions during download (#5994)

This fixes various data races scattered throughout the download/pull
client where the client was accessing the download state concurrently.

This commit is mostly a hot-fix and will be replaced by a new client one
day soon.

Also, remove the unnecessary opts argument from downloadChunk.
This commit is contained in:
Blake Mizerany 2024-07-26 14:24:24 -07:00 committed by GitHub
parent ec4c35fe99
commit 750c1c55f7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -44,7 +44,7 @@ type blobDownload struct {
context.CancelFunc context.CancelFunc
done bool done chan struct{}
err error err error
references atomic.Int32 references atomic.Int32
} }
@ -53,7 +53,9 @@ type blobDownloadPart struct {
N int N int
Offset int64 Offset int64
Size int64 Size int64
Completed int64 Completed atomic.Int64
lastUpdatedMu sync.Mutex
lastUpdated time.Time lastUpdated time.Time
*blobDownload `json:"-"` *blobDownload `json:"-"`
@ -72,7 +74,7 @@ func (p *blobDownloadPart) Name() string {
} }
func (p *blobDownloadPart) StartsAt() int64 { func (p *blobDownloadPart) StartsAt() int64 {
return p.Offset + p.Completed return p.Offset + p.Completed.Load()
} }
func (p *blobDownloadPart) StopsAt() int64 { func (p *blobDownloadPart) StopsAt() int64 {
@ -82,7 +84,9 @@ func (p *blobDownloadPart) StopsAt() int64 {
func (p *blobDownloadPart) Write(b []byte) (n int, err error) { func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
n = len(b) n = len(b)
p.blobDownload.Completed.Add(int64(n)) p.blobDownload.Completed.Add(int64(n))
p.lastUpdatedMu.Lock()
p.lastUpdated = time.Now() p.lastUpdated = time.Now()
p.lastUpdatedMu.Unlock()
return n, nil return n, nil
} }
@ -92,6 +96,8 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
return err return err
} }
b.done = make(chan struct{})
for _, partFilePath := range partFilePaths { for _, partFilePath := range partFilePaths {
part, err := b.readPart(partFilePath) part, err := b.readPart(partFilePath)
if err != nil { if err != nil {
@ -99,7 +105,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
} }
b.Total += part.Size b.Total += part.Size
b.Completed.Add(part.Completed) b.Completed.Add(part.Completed.Load())
b.Parts = append(b.Parts, part) b.Parts = append(b.Parts, part)
} }
@ -139,6 +145,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
} }
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) { func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) {
defer close(b.done)
b.err = b.run(ctx, requestURL, opts) b.err = b.run(ctx, requestURL, opts)
} }
@ -230,7 +237,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
g.SetLimit(numDownloadParts) g.SetLimit(numDownloadParts)
for i := range b.Parts { for i := range b.Parts {
part := b.Parts[i] part := b.Parts[i]
if part.Completed == part.Size { if part.Completed.Load() == part.Size {
continue continue
} }
@ -238,7 +245,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
var err error var err error
for try := 0; try < maxRetries; try++ { for try := 0; try < maxRetries; try++ {
w := io.NewOffsetWriter(file, part.StartsAt()) w := io.NewOffsetWriter(file, part.StartsAt())
err = b.downloadChunk(inner, directURL, w, part, opts) err = b.downloadChunk(inner, directURL, w, part)
switch { switch {
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC): case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
// return immediately if the context is canceled or the device is out of space // return immediately if the context is canceled or the device is out of space
@ -279,29 +286,31 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
return err return err
} }
b.done = true
return nil return nil
} }
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *registryOptions) error { func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart) error {
g, ctx := errgroup.WithContext(ctx) g, ctx := errgroup.WithContext(ctx)
g.Go(func() error { g.Go(func() error {
headers := make(http.Header) req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1)) if err != nil {
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts) return err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
return err return err
} }
defer resp.Body.Close() defer resp.Body.Close()
n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed) n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed.Load())
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) { if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
// rollback progress // rollback progress
b.Completed.Add(-n) b.Completed.Add(-n)
return err return err
} }
part.Completed += n part.Completed.Add(n)
if err := b.writePart(part.Name(), part); err != nil { if err := b.writePart(part.Name(), part); err != nil {
return err return err
} }
@ -315,15 +324,21 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
if part.Completed >= part.Size { if part.Completed.Load() >= part.Size {
return nil return nil
} }
if !part.lastUpdated.IsZero() && time.Since(part.lastUpdated) > 5*time.Second { part.lastUpdatedMu.Lock()
lastUpdated := part.lastUpdated
part.lastUpdatedMu.Unlock()
if !lastUpdated.IsZero() && time.Since(lastUpdated) > 5*time.Second {
const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection." const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection."
slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N)) slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N))
// reset last updated // reset last updated
part.lastUpdatedMu.Lock()
part.lastUpdated = time.Time{} part.lastUpdated = time.Time{}
part.lastUpdatedMu.Unlock()
return errPartStalled return errPartStalled
} }
case <-ctx.Done(): case <-ctx.Done():
@ -388,6 +403,8 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
ticker := time.NewTicker(60 * time.Millisecond) ticker := time.NewTicker(60 * time.Millisecond)
for { for {
select { select {
case <-b.done:
return b.err
case <-ticker.C: case <-ticker.C:
fn(api.ProgressResponse{ fn(api.ProgressResponse{
Status: fmt.Sprintf("pulling %s", b.Digest[7:19]), Status: fmt.Sprintf("pulling %s", b.Digest[7:19]),
@ -395,10 +412,6 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
Total: b.Total, Total: b.Total,
Completed: b.Completed.Load(), Completed: b.Completed.Load(),
}) })
if b.done || b.err != nil {
return b.err
}
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
} }