diff --git a/server/download.go b/server/download.go index a50c0cd1..f6d199b9 100644 --- a/server/download.go +++ b/server/download.go @@ -20,7 +20,6 @@ import ( "time" "golang.org/x/sync/errgroup" - "golang.org/x/sync/semaphore" "github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/format" @@ -139,29 +138,30 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r } func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) { + b.err = b.run(ctx, requestURL, opts) +} + +func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error { defer blobDownloadManager.Delete(b.Digest) ctx, b.CancelFunc = context.WithCancel(ctx) file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0o644) if err != nil { - b.err = err - return + return err } defer file.Close() _ = file.Truncate(b.Total) - var limit int64 = 2 - g, inner := NewLimitGroup(ctx, numDownloadParts, limit) - go watchDelta(inner, g, &b.Completed, limit) - + g, inner := errgroup.WithContext(ctx) + g.SetLimit(numDownloadParts) for i := range b.Parts { part := b.Parts[i] if part.Completed == part.Size { continue } - g.Go(inner, func() error { + g.Go(func() error { var err error for try := 0; try < maxRetries; try++ { w := io.NewOffsetWriter(file, part.StartsAt()) @@ -188,29 +188,26 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *regis } if err := g.Wait(); err != nil { - b.err = err - return + return err } // explicitly close the file so we can rename it if err := file.Close(); err != nil { - b.err = err - return + return err } for i := range b.Parts { if err := os.Remove(file.Name() + "-" + strconv.Itoa(i)); err != nil { - b.err = err - return + return err } } if err := os.Rename(file.Name(), b.Name); err != nil { - b.err = err - return + return err } b.done = true + return nil } func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *registryOptions) error { @@ -380,87 +377,3 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error { return download.Wait(ctx, opts.fn) } - -type LimitGroup struct { - *errgroup.Group - *semaphore.Weighted - size, limit int64 -} - -func NewLimitGroup(ctx context.Context, size, limit int64) (*LimitGroup, context.Context) { - g, ctx := errgroup.WithContext(ctx) - return &LimitGroup{ - Group: g, - Weighted: semaphore.NewWeighted(size), - size: size, - limit: limit, - }, ctx -} - -func (g *LimitGroup) Go(ctx context.Context, fn func() error) { - var weight int64 = 1 - if g.limit > 0 { - weight = g.size / g.limit - } - - _ = g.Acquire(ctx, weight) - if ctx.Err() != nil { - return - } - - g.Group.Go(func() error { - defer g.Release(weight) - return fn() - }) -} - -func (g *LimitGroup) SetLimit(limit int64) { - if limit > g.limit { - g.limit = limit - } -} - -func watchDelta(ctx context.Context, g *LimitGroup, c *atomic.Int64, limit int64) { - var maxDelta float64 - var buckets []int64 - - // 5s ramp up period - nextUpdate := time.Now().Add(5 * time.Second) - - ticker := time.NewTicker(time.Second) - for { - select { - case <-ticker.C: - buckets = append(buckets, c.Load()) - if len(buckets) < 2 { - continue - } else if len(buckets) > 10 { - buckets = buckets[1:] - } - - delta := float64((buckets[len(buckets)-1] - buckets[0])) / float64(len(buckets)) - slog.Debug("", "limit", limit, "delta", format.HumanBytes(int64(delta)), "max_delta", format.HumanBytes(int64(maxDelta))) - - if time.Now().Before(nextUpdate) { - // quiet period; do not update ccy if recently updated - continue - } else if maxDelta > 0 { - x := delta / maxDelta - if x < 1.2 { - continue - } - - limit += int64(x) - slog.Debug("setting", "limit", limit) - g.SetLimit(limit) - } - - // 3s cooldown period - nextUpdate = time.Now().Add(3 * time.Second) - maxDelta = delta - - case <-ctx.Done(): - return - } - } -} diff --git a/server/upload.go b/server/upload.go index c4517268..4da34052 100644 --- a/server/upload.go +++ b/server/upload.go @@ -18,6 +18,7 @@ import ( "github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/format" + "golang.org/x/sync/errgroup" ) var blobUploadManager sync.Map @@ -136,17 +137,14 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) { } defer b.file.Close() - var limit int64 = 2 - g, inner := NewLimitGroup(ctx, numUploadParts, limit) - go watchDelta(inner, g, &b.Completed, limit) - + g, inner := errgroup.WithContext(ctx) + g.SetLimit(numUploadParts) for i := range b.Parts { part := &b.Parts[i] select { case <-inner.Done(): - break case requestURL := <-b.nextURL: - g.Go(inner, func() error { + g.Go(func() error { var err error for try := 0; try < maxRetries; try++ { err = b.uploadPart(inner, http.MethodPatch, requestURL, part, opts)