Merge pull request #760 from jmorganca/mxyng/more-downloads
Mxyng/more downloads
This commit is contained in:
commit
788637918a
5 changed files with 70 additions and 33 deletions
|
@ -127,7 +127,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
const maxBufferSize = 512 * 1024 // 512KB
|
const maxBufferSize = 512 * 1000 // 512KB
|
||||||
|
|
||||||
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
|
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
|
||||||
var buf *bytes.Buffer
|
var buf *bytes.Buffer
|
||||||
|
|
16
format/bytes.go
Normal file
16
format/bytes.go
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
package format
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
func HumanBytes(b int64) string {
|
||||||
|
switch {
|
||||||
|
case b > 1000*1000*1000:
|
||||||
|
return fmt.Sprintf("%d GB", b/1000/1000/1000)
|
||||||
|
case b > 1000*1000:
|
||||||
|
return fmt.Sprintf("%d MB", b/1000/1000)
|
||||||
|
case b > 1000:
|
||||||
|
return fmt.Sprintf("%d KB", b/1000)
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("%d B", b)
|
||||||
|
}
|
||||||
|
}
|
|
@ -454,7 +454,7 @@ type PredictRequest struct {
|
||||||
Stop []string `json:"stop,omitempty"`
|
Stop []string `json:"stop,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
const maxBufferSize = 512 * 1024 // 512KB
|
const maxBufferSize = 512 * 1000 // 512KB
|
||||||
|
|
||||||
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
|
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
|
||||||
prevConvo, err := llm.Decode(ctx, prevContext)
|
prevConvo, err := llm.Decode(ctx, prevContext)
|
||||||
|
|
20
llm/llm.go
20
llm/llm.go
|
@ -60,33 +60,33 @@ func New(workDir, model string, adapters []string, opts api.Options) (LLM, error
|
||||||
totalResidentMemory := memory.TotalMemory()
|
totalResidentMemory := memory.TotalMemory()
|
||||||
switch ggml.ModelType() {
|
switch ggml.ModelType() {
|
||||||
case "3B", "7B":
|
case "3B", "7B":
|
||||||
if ggml.FileType() == "F16" && totalResidentMemory < 16*1024*1024 {
|
if ggml.FileType() == "F16" && totalResidentMemory < 16*1000*1000 {
|
||||||
return nil, fmt.Errorf("F16 model requires at least 16 GB of memory")
|
return nil, fmt.Errorf("F16 model requires at least 16 GB of memory")
|
||||||
} else if totalResidentMemory < 8*1024*1024 {
|
} else if totalResidentMemory < 8*1000*1000 {
|
||||||
return nil, fmt.Errorf("model requires at least 8 GB of memory")
|
return nil, fmt.Errorf("model requires at least 8 GB of memory")
|
||||||
}
|
}
|
||||||
case "13B":
|
case "13B":
|
||||||
if ggml.FileType() == "F16" && totalResidentMemory < 32*1024*1024 {
|
if ggml.FileType() == "F16" && totalResidentMemory < 32*1000*1000 {
|
||||||
return nil, fmt.Errorf("F16 model requires at least 32 GB of memory")
|
return nil, fmt.Errorf("F16 model requires at least 32 GB of memory")
|
||||||
} else if totalResidentMemory < 16*1024*1024 {
|
} else if totalResidentMemory < 16*1000*1000 {
|
||||||
return nil, fmt.Errorf("model requires at least 16 GB of memory")
|
return nil, fmt.Errorf("model requires at least 16 GB of memory")
|
||||||
}
|
}
|
||||||
case "30B", "34B", "40B":
|
case "30B", "34B", "40B":
|
||||||
if ggml.FileType() == "F16" && totalResidentMemory < 64*1024*1024 {
|
if ggml.FileType() == "F16" && totalResidentMemory < 64*1000*1000 {
|
||||||
return nil, fmt.Errorf("F16 model requires at least 64 GB of memory")
|
return nil, fmt.Errorf("F16 model requires at least 64 GB of memory")
|
||||||
} else if totalResidentMemory < 32*1024*1024 {
|
} else if totalResidentMemory < 32*1000*1000 {
|
||||||
return nil, fmt.Errorf("model requires at least 32 GB of memory")
|
return nil, fmt.Errorf("model requires at least 32 GB of memory")
|
||||||
}
|
}
|
||||||
case "65B", "70B":
|
case "65B", "70B":
|
||||||
if ggml.FileType() == "F16" && totalResidentMemory < 128*1024*1024 {
|
if ggml.FileType() == "F16" && totalResidentMemory < 128*1000*1000 {
|
||||||
return nil, fmt.Errorf("F16 model requires at least 128 GB of memory")
|
return nil, fmt.Errorf("F16 model requires at least 128 GB of memory")
|
||||||
} else if totalResidentMemory < 64*1024*1024 {
|
} else if totalResidentMemory < 64*1000*1000 {
|
||||||
return nil, fmt.Errorf("model requires at least 64 GB of memory")
|
return nil, fmt.Errorf("model requires at least 64 GB of memory")
|
||||||
}
|
}
|
||||||
case "180B":
|
case "180B":
|
||||||
if ggml.FileType() == "F16" && totalResidentMemory < 512*1024*1024 {
|
if ggml.FileType() == "F16" && totalResidentMemory < 512*1000*1000 {
|
||||||
return nil, fmt.Errorf("F16 model requires at least 512GB of memory")
|
return nil, fmt.Errorf("F16 model requires at least 512GB of memory")
|
||||||
} else if totalResidentMemory < 128*1024*1024 {
|
} else if totalResidentMemory < 128*1000*1000 {
|
||||||
return nil, fmt.Errorf("model requires at least 128GB of memory")
|
return nil, fmt.Errorf("model requires at least 128GB of memory")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/jmorganca/ollama/api"
|
"github.com/jmorganca/ollama/api"
|
||||||
|
"github.com/jmorganca/ollama/format"
|
||||||
)
|
)
|
||||||
|
|
||||||
var blobDownloadManager sync.Map
|
var blobDownloadManager sync.Map
|
||||||
|
@ -34,6 +35,9 @@ type blobDownload struct {
|
||||||
Parts []*blobDownloadPart
|
Parts []*blobDownloadPart
|
||||||
|
|
||||||
context.CancelFunc
|
context.CancelFunc
|
||||||
|
|
||||||
|
done bool
|
||||||
|
err error
|
||||||
references atomic.Int32
|
references atomic.Int32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -46,6 +50,12 @@ type blobDownloadPart struct {
|
||||||
*blobDownload `json:"-"`
|
*blobDownload `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
numDownloadParts = 64
|
||||||
|
minDownloadPartSize int64 = 32 * 1000 * 1000
|
||||||
|
maxDownloadPartSize int64 = 256 * 1000 * 1000
|
||||||
|
)
|
||||||
|
|
||||||
func (p *blobDownloadPart) Name() string {
|
func (p *blobDownloadPart) Name() string {
|
||||||
return strings.Join([]string{
|
return strings.Join([]string{
|
||||||
p.blobDownload.Name, "partial", strconv.Itoa(p.N),
|
p.blobDownload.Name, "partial", strconv.Itoa(p.N),
|
||||||
|
@ -91,9 +101,15 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *R
|
||||||
|
|
||||||
b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
|
b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
|
||||||
|
|
||||||
var offset int64
|
var size = b.Total / numDownloadParts
|
||||||
var size int64 = 64 * 1024 * 1024
|
switch {
|
||||||
|
case size < minDownloadPartSize:
|
||||||
|
size = minDownloadPartSize
|
||||||
|
case size > maxDownloadPartSize:
|
||||||
|
size = maxDownloadPartSize
|
||||||
|
}
|
||||||
|
|
||||||
|
var offset int64
|
||||||
for offset < b.Total {
|
for offset < b.Total {
|
||||||
if offset+size > b.Total {
|
if offset+size > b.Total {
|
||||||
size = b.Total - offset
|
size = b.Total - offset
|
||||||
|
@ -107,11 +123,15 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *R
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("downloading %s in %d part(s)", b.Digest[7:19], len(b.Parts))
|
log.Printf("downloading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) (err error) {
|
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) {
|
||||||
|
b.err = b.run(ctx, requestURL, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
|
||||||
defer blobDownloadManager.Delete(b.Digest)
|
defer blobDownloadManager.Delete(b.Digest)
|
||||||
|
|
||||||
ctx, b.CancelFunc = context.WithCancel(ctx)
|
ctx, b.CancelFunc = context.WithCancel(ctx)
|
||||||
|
@ -124,9 +144,8 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis
|
||||||
|
|
||||||
file.Truncate(b.Total)
|
file.Truncate(b.Total)
|
||||||
|
|
||||||
g, ctx := errgroup.WithContext(ctx)
|
g, _ := errgroup.WithContext(ctx)
|
||||||
// TODO(mxyng): download concurrency should be configurable
|
g.SetLimit(numDownloadParts)
|
||||||
g.SetLimit(64)
|
|
||||||
for i := range b.Parts {
|
for i := range b.Parts {
|
||||||
part := b.Parts[i]
|
part := b.Parts[i]
|
||||||
if part.Completed == part.Size {
|
if part.Completed == part.Size {
|
||||||
|
@ -168,7 +187,12 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return os.Rename(file.Name(), b.Name)
|
if err := os.Rename(file.Name(), b.Name); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
b.done = true
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
|
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
|
||||||
|
@ -267,11 +291,8 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
|
||||||
Completed: b.Completed.Load(),
|
Completed: b.Completed.Load(),
|
||||||
})
|
})
|
||||||
|
|
||||||
if b.Completed.Load() >= b.Total {
|
if b.done || b.err != nil {
|
||||||
// wait for the file to get renamed
|
return b.err
|
||||||
if _, err := os.Stat(b.Name); err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue