server: fix race conditions during download (#5994)
This fixes various data races scattered throughout the download/pull client where the client was accessing the download state concurrently. This commit is mostly a hot-fix and will be replaced by a new client one day soon. Also, remove the unnecessary opts argument from downloadChunk.
This commit is contained in:
parent
ec4c35fe99
commit
750c1c55f7
1 changed file with 36 additions and 23 deletions
|
@ -44,17 +44,19 @@ type blobDownload struct {
|
||||||
|
|
||||||
context.CancelFunc
|
context.CancelFunc
|
||||||
|
|
||||||
done bool
|
done chan struct{}
|
||||||
err error
|
err error
|
||||||
references atomic.Int32
|
references atomic.Int32
|
||||||
}
|
}
|
||||||
|
|
||||||
type blobDownloadPart struct {
|
type blobDownloadPart struct {
|
||||||
N int
|
N int
|
||||||
Offset int64
|
Offset int64
|
||||||
Size int64
|
Size int64
|
||||||
Completed int64
|
Completed atomic.Int64
|
||||||
lastUpdated time.Time
|
|
||||||
|
lastUpdatedMu sync.Mutex
|
||||||
|
lastUpdated time.Time
|
||||||
|
|
||||||
*blobDownload `json:"-"`
|
*blobDownload `json:"-"`
|
||||||
}
|
}
|
||||||
|
@ -72,7 +74,7 @@ func (p *blobDownloadPart) Name() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *blobDownloadPart) StartsAt() int64 {
|
func (p *blobDownloadPart) StartsAt() int64 {
|
||||||
return p.Offset + p.Completed
|
return p.Offset + p.Completed.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *blobDownloadPart) StopsAt() int64 {
|
func (p *blobDownloadPart) StopsAt() int64 {
|
||||||
|
@ -82,7 +84,9 @@ func (p *blobDownloadPart) StopsAt() int64 {
|
||||||
func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
|
func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
|
||||||
n = len(b)
|
n = len(b)
|
||||||
p.blobDownload.Completed.Add(int64(n))
|
p.blobDownload.Completed.Add(int64(n))
|
||||||
|
p.lastUpdatedMu.Lock()
|
||||||
p.lastUpdated = time.Now()
|
p.lastUpdated = time.Now()
|
||||||
|
p.lastUpdatedMu.Unlock()
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -92,6 +96,8 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
b.done = make(chan struct{})
|
||||||
|
|
||||||
for _, partFilePath := range partFilePaths {
|
for _, partFilePath := range partFilePaths {
|
||||||
part, err := b.readPart(partFilePath)
|
part, err := b.readPart(partFilePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -99,7 +105,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
|
||||||
}
|
}
|
||||||
|
|
||||||
b.Total += part.Size
|
b.Total += part.Size
|
||||||
b.Completed.Add(part.Completed)
|
b.Completed.Add(part.Completed.Load())
|
||||||
b.Parts = append(b.Parts, part)
|
b.Parts = append(b.Parts, part)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -139,6 +145,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) {
|
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) {
|
||||||
|
defer close(b.done)
|
||||||
b.err = b.run(ctx, requestURL, opts)
|
b.err = b.run(ctx, requestURL, opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -230,7 +237,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
|
||||||
g.SetLimit(numDownloadParts)
|
g.SetLimit(numDownloadParts)
|
||||||
for i := range b.Parts {
|
for i := range b.Parts {
|
||||||
part := b.Parts[i]
|
part := b.Parts[i]
|
||||||
if part.Completed == part.Size {
|
if part.Completed.Load() == part.Size {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -238,7 +245,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
|
||||||
var err error
|
var err error
|
||||||
for try := 0; try < maxRetries; try++ {
|
for try := 0; try < maxRetries; try++ {
|
||||||
w := io.NewOffsetWriter(file, part.StartsAt())
|
w := io.NewOffsetWriter(file, part.StartsAt())
|
||||||
err = b.downloadChunk(inner, directURL, w, part, opts)
|
err = b.downloadChunk(inner, directURL, w, part)
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
|
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
|
||||||
// return immediately if the context is canceled or the device is out of space
|
// return immediately if the context is canceled or the device is out of space
|
||||||
|
@ -279,29 +286,31 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
b.done = true
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *registryOptions) error {
|
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart) error {
|
||||||
g, ctx := errgroup.WithContext(ctx)
|
g, ctx := errgroup.WithContext(ctx)
|
||||||
g.Go(func() error {
|
g.Go(func() error {
|
||||||
headers := make(http.Header)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
|
||||||
headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
|
if err != nil {
|
||||||
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
|
return err
|
||||||
|
}
|
||||||
|
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed)
|
n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed.Load())
|
||||||
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||||
// rollback progress
|
// rollback progress
|
||||||
b.Completed.Add(-n)
|
b.Completed.Add(-n)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
part.Completed += n
|
part.Completed.Add(n)
|
||||||
if err := b.writePart(part.Name(), part); err != nil {
|
if err := b.writePart(part.Name(), part); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -315,15 +324,21 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
if part.Completed >= part.Size {
|
if part.Completed.Load() >= part.Size {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if !part.lastUpdated.IsZero() && time.Since(part.lastUpdated) > 5*time.Second {
|
part.lastUpdatedMu.Lock()
|
||||||
|
lastUpdated := part.lastUpdated
|
||||||
|
part.lastUpdatedMu.Unlock()
|
||||||
|
|
||||||
|
if !lastUpdated.IsZero() && time.Since(lastUpdated) > 5*time.Second {
|
||||||
const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection."
|
const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection."
|
||||||
slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N))
|
slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N))
|
||||||
// reset last updated
|
// reset last updated
|
||||||
|
part.lastUpdatedMu.Lock()
|
||||||
part.lastUpdated = time.Time{}
|
part.lastUpdated = time.Time{}
|
||||||
|
part.lastUpdatedMu.Unlock()
|
||||||
return errPartStalled
|
return errPartStalled
|
||||||
}
|
}
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
@ -388,6 +403,8 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
|
||||||
ticker := time.NewTicker(60 * time.Millisecond)
|
ticker := time.NewTicker(60 * time.Millisecond)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
case <-b.done:
|
||||||
|
return b.err
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
fn(api.ProgressResponse{
|
fn(api.ProgressResponse{
|
||||||
Status: fmt.Sprintf("pulling %s", b.Digest[7:19]),
|
Status: fmt.Sprintf("pulling %s", b.Digest[7:19]),
|
||||||
|
@ -395,10 +412,6 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
|
||||||
Total: b.Total,
|
Total: b.Total,
|
||||||
Completed: b.Completed.Load(),
|
Completed: b.Completed.Load(),
|
||||||
})
|
})
|
||||||
|
|
||||||
if b.done || b.err != nil {
|
|
||||||
return b.err
|
|
||||||
}
|
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue