From 917bd6108458990b554a8ff8f6535a861fd37f0e Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Fri, 26 Jan 2024 14:18:45 -0800
Subject: [PATCH 1/6] refactor download run

---
 server/download.go | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/server/download.go b/server/download.go
index f6d199b9..d31797da 100644
--- a/server/download.go
+++ b/server/download.go
@@ -138,16 +138,13 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
 }
 
 func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) {
-	b.err = b.run(ctx, requestURL, opts)
-}
-
-func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
 	defer blobDownloadManager.Delete(b.Digest)
 	ctx, b.CancelFunc = context.WithCancel(ctx)
 
 	file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0o644)
 	if err != nil {
-		return err
+		b.err = err
+		return
 	}
 	defer file.Close()
 
@@ -188,26 +185,30 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 	}
 
 	if err := g.Wait(); err != nil {
-		return err
+		b.err = err
+		return
 	}
 
 	// explicitly close the file so we can rename it
 	if err := file.Close(); err != nil {
-		return err
+		b.err = err
+		return
 	}
 
 	for i := range b.Parts {
 		if err := os.Remove(file.Name() + "-" + strconv.Itoa(i)); err != nil {
-			return err
+			b.err = err
+			return
 		}
 	}
 
 	if err := os.Rename(file.Name(), b.Name); err != nil {
-		return err
+		b.err = err
+		return
 	}
 
 	b.done = true
-	return nil
+	return
 }
 
 func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *registryOptions) error {

From 0de12368a0f5c96c26e5411311f95b39c02f0df2 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Fri, 26 Jan 2024 13:50:03 -0800
Subject: [PATCH 2/6] add new LimitGroup for dynamic concurrency

---
 server/download.go | 42 ++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 40 insertions(+), 2 deletions(-)

diff --git a/server/download.go b/server/download.go
index d31797da..4985fb53 100644
--- a/server/download.go
+++ b/server/download.go
@@ -20,6 +20,7 @@ import (
 	"time"
 
 	"golang.org/x/sync/errgroup"
+	"golang.org/x/sync/semaphore"
 
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/format"
@@ -150,8 +151,7 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *regis
 
 	_ = file.Truncate(b.Total)
 
-	g, inner := errgroup.WithContext(ctx)
-	g.SetLimit(numDownloadParts)
+	g, inner := NewLimitGroup(ctx, numDownloadParts)
 	for i := range b.Parts {
 		part := b.Parts[i]
 		if part.Completed == part.Size {
@@ -378,3 +378,41 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error {
 
 	return download.Wait(ctx, opts.fn)
 }
+
+type LimitGroup struct {
+	*errgroup.Group
+	context.Context
+	Semaphore *semaphore.Weighted
+
+	weight, max_weight int64
+}
+
+func NewLimitGroup(ctx context.Context, n int64) (*LimitGroup, context.Context) {
+	g, ctx := errgroup.WithContext(ctx)
+	return &LimitGroup{
+		Group:      g,
+		Context:    ctx,
+		Semaphore:  semaphore.NewWeighted(n),
+		weight:     n,
+		max_weight: n,
+	}, ctx
+}
+
+func (g *LimitGroup) Go(fn func() error) {
+	weight := g.weight
+	g.Semaphore.Acquire(g.Context, weight)
+	if g.Context.Err() != nil {
+		return
+	}
+
+	g.Group.Go(func() error {
+		defer g.Semaphore.Release(weight)
+		return fn()
+	})
+}
+
+func (g *LimitGroup) SetLimit(n int64) {
+	if n > 0 {
+		g.weight = g.max_weight / n
+	}
+}
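Patch 2 swaps errgroup's fixed SetLimit for a weighted semaphore: the group is created with a total weight, every task acquires a share of it before starting, and shrinking the per-task share is what later lets the effective concurrency grow at runtime. A minimal, self-contained sketch of that mechanism, with made-up sizes and a stand-in workload rather than the actual download parts:

    package main

    import (
    	"context"
    	"fmt"
    	"time"

    	"golang.org/x/sync/errgroup"
    	"golang.org/x/sync/semaphore"
    )

    func main() {
    	const capacity = 64 // total weight, in the spirit of max_weight above
    	sem := semaphore.NewWeighted(capacity)
    	g, ctx := errgroup.WithContext(context.Background())

    	// Per-task weight decides effective concurrency: capacity/1 admits one
    	// task at a time, capacity/4 admits four, and so on.
    	weight := int64(capacity / 4)

    	for i := 0; i < 8; i++ {
    		i := i
    		if err := sem.Acquire(ctx, weight); err != nil {
    			break // context cancelled, stop queueing work
    		}
    		g.Go(func() error {
    			defer sem.Release(weight)
    			time.Sleep(100 * time.Millisecond) // stand-in for downloading a part
    			fmt.Println("finished part", i)
    			return nil
    		})
    	}

    	if err := g.Wait(); err != nil {
    		fmt.Println("error:", err)
    	}
    }

With a per-task weight of capacity/4, at most four of the eight tasks hold the semaphore at any one time; the rest block in Acquire until a running task releases its share.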
From 074934be030305f5e6c3743d0ecd0ee6fd172de5 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Fri, 26 Jan 2024 14:35:28 -0800
Subject: [PATCH 3/6] adjust group limit based on download speed

---
 server/download.go | 31 +++++++++++++++++++++++++++++++
 1 file changed, 31 insertions(+)

diff --git a/server/download.go b/server/download.go
index 4985fb53..d8a841e1 100644
--- a/server/download.go
+++ b/server/download.go
@@ -152,6 +152,36 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *regis
 	_ = file.Truncate(b.Total)
 
 	g, inner := NewLimitGroup(ctx, numDownloadParts)
+
+	go func() {
+		ticker := time.NewTicker(time.Second)
+		var n int64 = 1
+		var maxDelta float64
+		var buckets []int64
+		for {
+			select {
+			case <-ticker.C:
+				buckets = append(buckets, b.Completed.Load())
+				if len(buckets) < 2 {
+					continue
+				} else if len(buckets) > 10 {
+					buckets = buckets[1:]
+				}
+
+				delta := float64((buckets[len(buckets)-1] - buckets[0])) / float64(len(buckets))
+				slog.Debug(fmt.Sprintf("delta: %s/s max_delta: %s/s", format.HumanBytes(int64(delta)), format.HumanBytes(int64(maxDelta))))
+				if delta > maxDelta*1.5 {
+					maxDelta = delta
+					g.SetLimit(n)
+					n++
+				}
+
+			case <-ctx.Done():
+				return
+			}
+		}
+	}()
+
 	for i := range b.Parts {
 		part := b.Parts[i]
 		if part.Completed == part.Size {
@@ -413,6 +443,7 @@ func (g *LimitGroup) Go(fn func() error) {
 
 func (g *LimitGroup) SetLimit(n int64) {
 	if n > 0 {
+		slog.Debug(fmt.Sprintf("setting limit to %d", n))
 		g.weight = g.max_weight / n
 	}
 }

From bea007deb7bf3ba013aa503c9d9e3f23074c76ac Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Fri, 26 Jan 2024 15:10:45 -0800
Subject: [PATCH 4/6] use LimitGroup for uploads

---
 server/upload.go | 34 +++++++++++++++++++++++++++++++---
 1 file changed, 31 insertions(+), 3 deletions(-)

diff --git a/server/upload.go b/server/upload.go
index 4da34052..b090721e 100644
--- a/server/upload.go
+++ b/server/upload.go
@@ -18,7 +18,6 @@ import (
 
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/format"
-	"golang.org/x/sync/errgroup"
 )
 
 var blobUploadManager sync.Map
@@ -137,8 +136,37 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
 	}
 	defer b.file.Close()
 
-	g, inner := errgroup.WithContext(ctx)
-	g.SetLimit(numUploadParts)
+	g, inner := NewLimitGroup(ctx, numUploadParts)
+
+	go func() {
+		ticker := time.NewTicker(time.Second)
+		var n int64 = 1
+		var maxDelta float64
+		var buckets []int64
+		for {
+			select {
+			case <-ticker.C:
+				buckets = append(buckets, b.Completed.Load())
+				if len(buckets) < 2 {
+					continue
+				} else if len(buckets) > 10 {
+					buckets = buckets[1:]
+				}
+
+				delta := float64((buckets[len(buckets)-1] - buckets[0])) / float64(len(buckets))
+				slog.Debug(fmt.Sprintf("delta: %s/s max_delta: %s/s", format.HumanBytes(int64(delta)), format.HumanBytes(int64(maxDelta))))
+				if delta > maxDelta*1.5 {
+					maxDelta = delta
+					g.SetLimit(n)
+					n++
+				}
+
+			case <-ctx.Done():
+				return
+			}
+		}
+	}()
+
 	for i := range b.Parts {
 		part := &b.Parts[i]
 		select {
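Patches 3 and 4 add the same watcher goroutine in two places: once a second it appends the completed byte count to a sliding window of at most ten samples, takes the growth between the window's endpoints divided by the sample count as an approximate bytes-per-second figure, and calls SetLimit whenever that figure beats the previous best by more than 50%. Pulled out into a standalone program it looks roughly like this; the rateLimiter interface and the simulated progress in main are illustrative stand-ins, not part of the patch:

    package main

    import (
    	"context"
    	"fmt"
    	"sync/atomic"
    	"time"
    )

    // rateLimiter is a hypothetical stand-in for the LimitGroup in the patch.
    type rateLimiter interface {
    	SetLimit(n int64)
    }

    // watchThroughput mirrors the watcher goroutine from patches 3 and 4: sample
    // completed bytes once a second, keep a ~10 sample window, and raise the
    // limit whenever throughput beats the previous best by more than 50%.
    func watchThroughput(ctx context.Context, g rateLimiter, completed *atomic.Int64) {
    	ticker := time.NewTicker(time.Second)
    	defer ticker.Stop()

    	var n int64 = 1
    	var maxDelta float64
    	var buckets []int64

    	for {
    		select {
    		case <-ticker.C:
    			buckets = append(buckets, completed.Load())
    			if len(buckets) < 2 {
    				continue
    			} else if len(buckets) > 10 {
    				buckets = buckets[1:]
    			}

    			// growth across the window divided by the sample count,
    			// i.e. roughly bytes per second (as in the patch)
    			delta := float64(buckets[len(buckets)-1]-buckets[0]) / float64(len(buckets))
    			if delta > maxDelta*1.5 {
    				maxDelta = delta
    				g.SetLimit(n)
    				n++
    			}

    		case <-ctx.Done():
    			return
    		}
    	}
    }

    type printLimiter struct{}

    func (printLimiter) SetLimit(n int64) { fmt.Println("raising limit to", n) }

    func main() {
    	var completed atomic.Int64
    	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    	defer cancel()

    	// Simulate steadily arriving bytes so the watcher has something to see.
    	go func() {
    		for {
    			completed.Add(1 << 20)
    			time.Sleep(200 * time.Millisecond)
    		}
    	}()

    	watchThroughput(ctx, printLimiter{}, &completed)
    }

Patch 6 below performs exactly this kind of extraction, replacing both inline copies with a single watchDelta function.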
From 6a4b994433962ae95179c96f15c9888f14c97526 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Mon, 29 Jan 2024 17:01:31 -0800
Subject: [PATCH 5/6] lint

---
 server/download.go | 11 ++++-------
 server/upload.go   |  2 +-
 2 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/server/download.go b/server/download.go
index d8a841e1..5e3217f5 100644
--- a/server/download.go
+++ b/server/download.go
@@ -188,7 +188,7 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *regis
 			continue
 		}
 
-		g.Go(func() error {
+		g.Go(inner, func() error {
 			var err error
 			for try := 0; try < maxRetries; try++ {
 				w := io.NewOffsetWriter(file, part.StartsAt())
@@ -238,7 +238,6 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *regis
 	}
 
 	b.done = true
-	return
 }
 
 func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *registryOptions) error {
@@ -411,7 +410,6 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error {
 
 type LimitGroup struct {
 	*errgroup.Group
-	context.Context
 	Semaphore *semaphore.Weighted
 
 	weight, max_weight int64
@@ -421,17 +419,16 @@ func NewLimitGroup(ctx context.Context, n int64) (*LimitGroup, context.Context)
 	g, ctx := errgroup.WithContext(ctx)
 	return &LimitGroup{
 		Group:      g,
-		Context:    ctx,
 		Semaphore:  semaphore.NewWeighted(n),
 		weight:     n,
 		max_weight: n,
 	}, ctx
 }
 
-func (g *LimitGroup) Go(fn func() error) {
+func (g *LimitGroup) Go(ctx context.Context, fn func() error) {
 	weight := g.weight
-	g.Semaphore.Acquire(g.Context, weight)
-	if g.Context.Err() != nil {
+	_ = g.Semaphore.Acquire(ctx, weight)
+	if ctx.Err() != nil {
 		return
 	}
 
diff --git a/server/upload.go b/server/upload.go
index b090721e..590247d9 100644
--- a/server/upload.go
+++ b/server/upload.go
@@ -172,7 +172,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
 		select {
 		case <-inner.Done():
 		case requestURL := <-b.nextURL:
-			g.Go(func() error {
+			g.Go(inner, func() error {
 				var err error
 				for try := 0; try < maxRetries; try++ {
 					err = b.uploadPart(inner, http.MethodPatch, requestURL, part, opts)
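Besides dropping the now-redundant bare return, the lint pass stops storing a context on the LimitGroup and has callers pass one to Go explicitly; Acquire's return value is discarded because a failed Acquire only happens when the context is done, which the next line checks directly. A small sketch of that property of golang.org/x/sync/semaphore (the timeout and weights here are made up):

    package main

    import (
    	"context"
    	"fmt"
    	"time"

    	"golang.org/x/sync/semaphore"
    )

    func main() {
    	sem := semaphore.NewWeighted(1)

    	// Hold the only unit of weight so the next Acquire has to block.
    	_ = sem.Acquire(context.Background(), 1)

    	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
    	defer cancel()

    	// This Acquire blocks until ctx expires; on failure it returns ctx.Err(),
    	// which is why the patch can ignore the return value and check ctx.Err()
    	// on the next line instead.
    	err := sem.Acquire(ctx, 1)
    	fmt.Println("acquire returned:", err)
    	fmt.Println("ctx.Err():       ", ctx.Err())
    }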
From 084d8466216cc1f6ad1b00a16f309bcf07fc08bf Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Mon, 29 Jan 2024 17:16:37 -0800
Subject: [PATCH 6/6] refactor

---
 server/download.go | 110 ++++++++++++++++++++++++++-------------------
 server/upload.go   |  34 ++------------
 2 files changed, 69 insertions(+), 75 deletions(-)

diff --git a/server/download.go b/server/download.go
index 5e3217f5..a50c0cd1 100644
--- a/server/download.go
+++ b/server/download.go
@@ -151,36 +151,9 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *regis
 
 	_ = file.Truncate(b.Total)
 
-	g, inner := NewLimitGroup(ctx, numDownloadParts)
-
-	go func() {
-		ticker := time.NewTicker(time.Second)
-		var n int64 = 1
-		var maxDelta float64
-		var buckets []int64
-		for {
-			select {
-			case <-ticker.C:
-				buckets = append(buckets, b.Completed.Load())
-				if len(buckets) < 2 {
-					continue
-				} else if len(buckets) > 10 {
-					buckets = buckets[1:]
-				}
-
-				delta := float64((buckets[len(buckets)-1] - buckets[0])) / float64(len(buckets))
-				slog.Debug(fmt.Sprintf("delta: %s/s max_delta: %s/s", format.HumanBytes(int64(delta)), format.HumanBytes(int64(maxDelta))))
-				if delta > maxDelta*1.5 {
-					maxDelta = delta
-					g.SetLimit(n)
-					n++
-				}
-
-			case <-ctx.Done():
-				return
-			}
-		}
-	}()
+	var limit int64 = 2
+	g, inner := NewLimitGroup(ctx, numDownloadParts, limit)
+	go watchDelta(inner, g, &b.Completed, limit)
 
 	for i := range b.Parts {
 		part := b.Parts[i]
@@ -410,37 +383,84 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error {
 
 type LimitGroup struct {
 	*errgroup.Group
-	Semaphore *semaphore.Weighted
-
-	weight, max_weight int64
+	*semaphore.Weighted
+	size, limit int64
 }
 
-func NewLimitGroup(ctx context.Context, n int64) (*LimitGroup, context.Context) {
+func NewLimitGroup(ctx context.Context, size, limit int64) (*LimitGroup, context.Context) {
 	g, ctx := errgroup.WithContext(ctx)
 	return &LimitGroup{
-		Group:      g,
-		Semaphore:  semaphore.NewWeighted(n),
-		weight:     n,
-		max_weight: n,
+		Group:    g,
+		Weighted: semaphore.NewWeighted(size),
+		size:     size,
+		limit:    limit,
 	}, ctx
 }
 
 func (g *LimitGroup) Go(ctx context.Context, fn func() error) {
-	weight := g.weight
-	_ = g.Semaphore.Acquire(ctx, weight)
+	var weight int64 = 1
+	if g.limit > 0 {
+		weight = g.size / g.limit
+	}
+
+	_ = g.Acquire(ctx, weight)
 	if ctx.Err() != nil {
 		return
 	}
 
 	g.Group.Go(func() error {
-		defer g.Semaphore.Release(weight)
+		defer g.Release(weight)
 		return fn()
 	})
 }
 
-func (g *LimitGroup) SetLimit(n int64) {
-	if n > 0 {
-		slog.Debug(fmt.Sprintf("setting limit to %d", n))
-		g.weight = g.max_weight / n
+func (g *LimitGroup) SetLimit(limit int64) {
+	if limit > g.limit {
+		g.limit = limit
+	}
+}
+
+func watchDelta(ctx context.Context, g *LimitGroup, c *atomic.Int64, limit int64) {
+	var maxDelta float64
+	var buckets []int64
+
+	// 5s ramp up period
+	nextUpdate := time.Now().Add(5 * time.Second)
+
+	ticker := time.NewTicker(time.Second)
+	for {
+		select {
+		case <-ticker.C:
+			buckets = append(buckets, c.Load())
+			if len(buckets) < 2 {
+				continue
+			} else if len(buckets) > 10 {
+				buckets = buckets[1:]
+			}
+
+			delta := float64((buckets[len(buckets)-1] - buckets[0])) / float64(len(buckets))
+			slog.Debug("", "limit", limit, "delta", format.HumanBytes(int64(delta)), "max_delta", format.HumanBytes(int64(maxDelta)))
+
+			if time.Now().Before(nextUpdate) {
+				// quiet period; do not update ccy if recently updated
+				continue
+			} else if maxDelta > 0 {
+				x := delta / maxDelta
+				if x < 1.2 {
+					continue
+				}
+
+				limit += int64(x)
+				slog.Debug("setting", "limit", limit)
+				g.SetLimit(limit)
+			}
+
+			// 3s cooldown period
+			nextUpdate = time.Now().Add(3 * time.Second)
+			maxDelta = delta
+
+		case <-ctx.Done():
+			return
+		}
 	}
 }
diff --git a/server/upload.go b/server/upload.go
index 590247d9..c4517268 100644
--- a/server/upload.go
+++ b/server/upload.go
@@ -136,41 +136,15 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
 	}
 	defer b.file.Close()
 
-	g, inner := NewLimitGroup(ctx, numUploadParts)
-
-	go func() {
-		ticker := time.NewTicker(time.Second)
-		var n int64 = 1
-		var maxDelta float64
-		var buckets []int64
-		for {
-			select {
-			case <-ticker.C:
-				buckets = append(buckets, b.Completed.Load())
-				if len(buckets) < 2 {
-					continue
-				} else if len(buckets) > 10 {
-					buckets = buckets[1:]
-				}
-
-				delta := float64((buckets[len(buckets)-1] - buckets[0])) / float64(len(buckets))
-				slog.Debug(fmt.Sprintf("delta: %s/s max_delta: %s/s", format.HumanBytes(int64(delta)), format.HumanBytes(int64(maxDelta))))
-				if delta > maxDelta*1.5 {
-					maxDelta = delta
-					g.SetLimit(n)
-					n++
-				}
-
-			case <-ctx.Done():
-				return
-			}
-		}
-	}()
+	var limit int64 = 2
+	g, inner := NewLimitGroup(ctx, numUploadParts, limit)
+	go watchDelta(inner, g, &b.Completed, limit)
 
 	for i := range b.Parts {
 		part := &b.Parts[i]
 		select {
 		case <-inner.Done():
+			break
 		case requestURL := <-b.nextURL:
 			g.Go(inner, func() error {
 				var err error
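After the final refactor the group is described by two numbers: size, the total semaphore weight (numDownloadParts or numUploadParts), and limit, the target number of concurrent parts, which starts at 2 and is only ever raised by watchDelta once throughput grows past the 1.2x threshold after the ramp-up and cooldown windows. Each Go call acquires size/limit of the semaphore, so roughly limit parts run at once. With an illustrative size of 64 (not taken from the patch): limit 2 gives a per-task weight of 32 and two parts in flight, limit 5 gives 12 and five in flight, with 4 units of weight left idle by integer division. A tiny sketch of just that arithmetic:

    package main

    import "fmt"

    // perTaskWeight mirrors the calculation in the refactored LimitGroup.Go:
    // each task acquires size/limit of the semaphore, so roughly `limit` tasks
    // can hold it at once (integer division leaves a little slack).
    func perTaskWeight(size, limit int64) int64 {
    	var weight int64 = 1
    	if limit > 0 {
    		weight = size / limit
    	}
    	return weight
    }

    func main() {
    	const size int64 = 64 // illustrative stand-in for numDownloadParts
    	for _, limit := range []int64{2, 3, 5, 8} {
    		w := perTaskWeight(size, limit)
    		fmt.Printf("limit=%d  weight=%d  parts in flight=%d  idle weight=%d\n",
    			limit, w, size/w, size-(size/w)*w)
    	}
    }

Note that running tasks keep the weight they acquired, so raising the limit only lowers the weight of parts started after SetLimit takes effect.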