server: fix race conditions during download (#5994)
This fixes various data races scattered throughout the download/pull client where the client was accessing the download state concurrently. This commit is mostly a hot-fix and will be replaced by a new client one day soon. Also, remove the unnecessary opts argument from downloadChunk.
This commit is contained in:
parent
ec4c35fe99
commit
750c1c55f7
1 changed files with 36 additions and 23 deletions
|
@ -44,17 +44,19 @@ type blobDownload struct {
|
|||
|
||||
context.CancelFunc
|
||||
|
||||
done bool
|
||||
done chan struct{}
|
||||
err error
|
||||
references atomic.Int32
|
||||
}
|
||||
|
||||
type blobDownloadPart struct {
|
||||
N int
|
||||
Offset int64
|
||||
Size int64
|
||||
Completed int64
|
||||
lastUpdated time.Time
|
||||
N int
|
||||
Offset int64
|
||||
Size int64
|
||||
Completed atomic.Int64
|
||||
|
||||
lastUpdatedMu sync.Mutex
|
||||
lastUpdated time.Time
|
||||
|
||||
*blobDownload `json:"-"`
|
||||
}
|
||||
|
@ -72,7 +74,7 @@ func (p *blobDownloadPart) Name() string {
|
|||
}
|
||||
|
||||
func (p *blobDownloadPart) StartsAt() int64 {
|
||||
return p.Offset + p.Completed
|
||||
return p.Offset + p.Completed.Load()
|
||||
}
|
||||
|
||||
func (p *blobDownloadPart) StopsAt() int64 {
|
||||
|
@ -82,7 +84,9 @@ func (p *blobDownloadPart) StopsAt() int64 {
|
|||
func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
|
||||
n = len(b)
|
||||
p.blobDownload.Completed.Add(int64(n))
|
||||
p.lastUpdatedMu.Lock()
|
||||
p.lastUpdated = time.Now()
|
||||
p.lastUpdatedMu.Unlock()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
|
@ -92,6 +96,8 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
|
|||
return err
|
||||
}
|
||||
|
||||
b.done = make(chan struct{})
|
||||
|
||||
for _, partFilePath := range partFilePaths {
|
||||
part, err := b.readPart(partFilePath)
|
||||
if err != nil {
|
||||
|
@ -99,7 +105,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
|
|||
}
|
||||
|
||||
b.Total += part.Size
|
||||
b.Completed.Add(part.Completed)
|
||||
b.Completed.Add(part.Completed.Load())
|
||||
b.Parts = append(b.Parts, part)
|
||||
}
|
||||
|
||||
|
@ -139,6 +145,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
|
|||
}
|
||||
|
||||
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) {
|
||||
defer close(b.done)
|
||||
b.err = b.run(ctx, requestURL, opts)
|
||||
}
|
||||
|
||||
|
@ -230,7 +237,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
|
|||
g.SetLimit(numDownloadParts)
|
||||
for i := range b.Parts {
|
||||
part := b.Parts[i]
|
||||
if part.Completed == part.Size {
|
||||
if part.Completed.Load() == part.Size {
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -238,7 +245,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
|
|||
var err error
|
||||
for try := 0; try < maxRetries; try++ {
|
||||
w := io.NewOffsetWriter(file, part.StartsAt())
|
||||
err = b.downloadChunk(inner, directURL, w, part, opts)
|
||||
err = b.downloadChunk(inner, directURL, w, part)
|
||||
switch {
|
||||
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
|
||||
// return immediately if the context is canceled or the device is out of space
|
||||
|
@ -279,29 +286,31 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
|
|||
return err
|
||||
}
|
||||
|
||||
b.done = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *registryOptions) error {
|
||||
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart) error {
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
g.Go(func() error {
|
||||
headers := make(http.Header)
|
||||
headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
|
||||
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed)
|
||||
n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed.Load())
|
||||
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
// rollback progress
|
||||
b.Completed.Add(-n)
|
||||
return err
|
||||
}
|
||||
|
||||
part.Completed += n
|
||||
part.Completed.Add(n)
|
||||
if err := b.writePart(part.Name(), part); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -315,15 +324,21 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w
|
|||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if part.Completed >= part.Size {
|
||||
if part.Completed.Load() >= part.Size {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !part.lastUpdated.IsZero() && time.Since(part.lastUpdated) > 5*time.Second {
|
||||
part.lastUpdatedMu.Lock()
|
||||
lastUpdated := part.lastUpdated
|
||||
part.lastUpdatedMu.Unlock()
|
||||
|
||||
if !lastUpdated.IsZero() && time.Since(lastUpdated) > 5*time.Second {
|
||||
const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection."
|
||||
slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N))
|
||||
// reset last updated
|
||||
part.lastUpdatedMu.Lock()
|
||||
part.lastUpdated = time.Time{}
|
||||
part.lastUpdatedMu.Unlock()
|
||||
return errPartStalled
|
||||
}
|
||||
case <-ctx.Done():
|
||||
|
@ -388,6 +403,8 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
|
|||
ticker := time.NewTicker(60 * time.Millisecond)
|
||||
for {
|
||||
select {
|
||||
case <-b.done:
|
||||
return b.err
|
||||
case <-ticker.C:
|
||||
fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("pulling %s", b.Digest[7:19]),
|
||||
|
@ -395,10 +412,6 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
|
|||
Total: b.Total,
|
||||
Completed: b.Completed.Load(),
|
||||
})
|
||||
|
||||
if b.done || b.err != nil {
|
||||
return b.err
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue