Merge pull request #2221 from ollama/mxyng/up-down-ccy

adjust download and upload concurrency based on available bandwidth
This commit is contained in:
Michael Yang 2024-03-07 09:27:33 -08:00 committed by GitHub
commit 2e20110e50
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 106 additions and 17 deletions

View file

@ -20,6 +20,7 @@ import (
"time" "time"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/format" "github.com/jmorganca/ollama/format"
@ -138,30 +139,29 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
} }
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) { func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) {
b.err = b.run(ctx, requestURL, opts)
}
func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
defer blobDownloadManager.Delete(b.Digest) defer blobDownloadManager.Delete(b.Digest)
ctx, b.CancelFunc = context.WithCancel(ctx) ctx, b.CancelFunc = context.WithCancel(ctx)
file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0o644) file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0o644)
if err != nil { if err != nil {
return err b.err = err
return
} }
defer file.Close() defer file.Close()
_ = file.Truncate(b.Total) _ = file.Truncate(b.Total)
g, inner := errgroup.WithContext(ctx) var limit int64 = 2
g.SetLimit(numDownloadParts) g, inner := NewLimitGroup(ctx, numDownloadParts, limit)
go watchDelta(inner, g, &b.Completed, limit)
for i := range b.Parts { for i := range b.Parts {
part := b.Parts[i] part := b.Parts[i]
if part.Completed == part.Size { if part.Completed == part.Size {
continue continue
} }
g.Go(func() error { g.Go(inner, func() error {
var err error var err error
for try := 0; try < maxRetries; try++ { for try := 0; try < maxRetries; try++ {
w := io.NewOffsetWriter(file, part.StartsAt()) w := io.NewOffsetWriter(file, part.StartsAt())
@ -188,26 +188,29 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
} }
if err := g.Wait(); err != nil { if err := g.Wait(); err != nil {
return err b.err = err
return
} }
// explicitly close the file so we can rename it // explicitly close the file so we can rename it
if err := file.Close(); err != nil { if err := file.Close(); err != nil {
return err b.err = err
return
} }
for i := range b.Parts { for i := range b.Parts {
if err := os.Remove(file.Name() + "-" + strconv.Itoa(i)); err != nil { if err := os.Remove(file.Name() + "-" + strconv.Itoa(i)); err != nil {
return err b.err = err
return
} }
} }
if err := os.Rename(file.Name(), b.Name); err != nil { if err := os.Rename(file.Name(), b.Name); err != nil {
return err b.err = err
return
} }
b.done = true b.done = true
return nil
} }
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *registryOptions) error { func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *registryOptions) error {
@ -377,3 +380,87 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error {
return download.Wait(ctx, opts.fn) return download.Wait(ctx, opts.fn)
} }
type LimitGroup struct {
*errgroup.Group
*semaphore.Weighted
size, limit int64
}
func NewLimitGroup(ctx context.Context, size, limit int64) (*LimitGroup, context.Context) {
g, ctx := errgroup.WithContext(ctx)
return &LimitGroup{
Group: g,
Weighted: semaphore.NewWeighted(size),
size: size,
limit: limit,
}, ctx
}
func (g *LimitGroup) Go(ctx context.Context, fn func() error) {
var weight int64 = 1
if g.limit > 0 {
weight = g.size / g.limit
}
_ = g.Acquire(ctx, weight)
if ctx.Err() != nil {
return
}
g.Group.Go(func() error {
defer g.Release(weight)
return fn()
})
}
func (g *LimitGroup) SetLimit(limit int64) {
if limit > g.limit {
g.limit = limit
}
}
func watchDelta(ctx context.Context, g *LimitGroup, c *atomic.Int64, limit int64) {
var maxDelta float64
var buckets []int64
// 5s ramp up period
nextUpdate := time.Now().Add(5 * time.Second)
ticker := time.NewTicker(time.Second)
for {
select {
case <-ticker.C:
buckets = append(buckets, c.Load())
if len(buckets) < 2 {
continue
} else if len(buckets) > 10 {
buckets = buckets[1:]
}
delta := float64((buckets[len(buckets)-1] - buckets[0])) / float64(len(buckets))
slog.Debug("", "limit", limit, "delta", format.HumanBytes(int64(delta)), "max_delta", format.HumanBytes(int64(maxDelta)))
if time.Now().Before(nextUpdate) {
// quiet period; do not update ccy if recently updated
continue
} else if maxDelta > 0 {
x := delta / maxDelta
if x < 1.2 {
continue
}
limit += int64(x)
slog.Debug("setting", "limit", limit)
g.SetLimit(limit)
}
// 3s cooldown period
nextUpdate = time.Now().Add(3 * time.Second)
maxDelta = delta
case <-ctx.Done():
return
}
}
}

View file

@ -18,7 +18,6 @@ import (
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/format" "github.com/jmorganca/ollama/format"
"golang.org/x/sync/errgroup"
) )
var blobUploadManager sync.Map var blobUploadManager sync.Map
@ -137,14 +136,17 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
} }
defer b.file.Close() defer b.file.Close()
g, inner := errgroup.WithContext(ctx) var limit int64 = 2
g.SetLimit(numUploadParts) g, inner := NewLimitGroup(ctx, numUploadParts, limit)
go watchDelta(inner, g, &b.Completed, limit)
for i := range b.Parts { for i := range b.Parts {
part := &b.Parts[i] part := &b.Parts[i]
select { select {
case <-inner.Done(): case <-inner.Done():
break
case requestURL := <-b.nextURL: case requestURL := <-b.nextURL:
g.Go(func() error { g.Go(inner, func() error {
var err error var err error
for try := 0; try < maxRetries; try++ { for try := 0; try < maxRetries; try++ {
err = b.uploadPart(inner, http.MethodPatch, requestURL, part, opts) err = b.uploadPart(inner, http.MethodPatch, requestURL, part, opts)