add new LimitGroup for dynamic concurrency

This commit is contained in:
Michael Yang 2024-01-26 13:50:03 -08:00
parent 917bd61084
commit 0de12368a0

View file

@ -20,6 +20,7 @@ import (
"time" "time"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/format" "github.com/jmorganca/ollama/format"
@ -150,8 +151,7 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *regis
_ = file.Truncate(b.Total) _ = file.Truncate(b.Total)
g, inner := errgroup.WithContext(ctx) g, inner := NewLimitGroup(ctx, numDownloadParts)
g.SetLimit(numDownloadParts)
for i := range b.Parts { for i := range b.Parts {
part := b.Parts[i] part := b.Parts[i]
if part.Completed == part.Size { if part.Completed == part.Size {
@ -378,3 +378,41 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error {
return download.Wait(ctx, opts.fn) return download.Wait(ctx, opts.fn)
} }
type LimitGroup struct {
*errgroup.Group
context.Context
Semaphore *semaphore.Weighted
weight, max_weight int64
}
func NewLimitGroup(ctx context.Context, n int64) (*LimitGroup, context.Context) {
g, ctx := errgroup.WithContext(ctx)
return &LimitGroup{
Group: g,
Context: ctx,
Semaphore: semaphore.NewWeighted(n),
weight: n,
max_weight: n,
}, ctx
}
func (g *LimitGroup) Go(fn func() error) {
weight := g.weight
g.Semaphore.Acquire(g.Context, weight)
if g.Context.Err() != nil {
return
}
g.Group.Go(func() error {
defer g.Semaphore.Release(weight)
return fn()
})
}
func (g *LimitGroup) SetLimit(n int64) {
if n > 0 {
g.weight = g.max_weight / n
}
}