fix memory check

Michael Yang 2023-10-12 09:34:16 -07:00
parent d790bf9916
commit 92189a5855
4 changed files with 27 additions and 18 deletions
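The substance of the fix: memory.TotalMemory() reports bytes, but the old guards compared it against values like 16*1000*1000, which is 16 megabytes rather than 16 gigabytes, so the checks passed on machines far below the intended minimum. The commit adds named decimal size constants to the format package and restates every threshold as a multiple of format.GigaByte. A standalone sketch of the arithmetic, not code from the commit:

	package main

	import "fmt"

	// Decimal size constants mirroring the ones this commit adds to format.
	const (
		Byte     = 1
		KiloByte = Byte * 1000
		MegaByte = KiloByte * 1000
		GigaByte = MegaByte * 1000
	)

	func main() {
		fmt.Println(16 * 1000 * 1000) // 16000000: the old "16 GB" check was really 16 MB
		fmt.Println(16 * GigaByte)    // 16000000000: the intended threshold
	}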


@@ -14,6 +14,7 @@ import (
 	"runtime"
 	"strings"
 
+	"github.com/jmorganca/ollama/format"
 	"github.com/jmorganca/ollama/version"
 )
@@ -127,7 +128,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
 	return nil
 }
 
-const maxBufferSize = 512 * 1000 // 512KB
+const maxBufferSize = 512 * format.KiloByte
 
 func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
 	var buf *bytes.Buffer
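The value is unchanged here (512 * format.KiloByte is still 512,000 bytes); only the spelling improves, which is why the // 512KB comment can be dropped. The same one-line change appears again in the Predict file below. For context, maxBufferSize caps the per-line buffer of the scanner that walks the streamed response; a minimal sketch of how such a cap is typically applied to a bufio.Scanner (an assumption about the shape of the stream body, which this hunk does not show):

	// Hypothetical sketch: raise bufio.Scanner's 64 KB default line
	// limit to maxBufferSize before scanning the response body.
	scanner := bufio.NewScanner(response.Body)
	scanner.Buffer(make([]byte, 0, maxBufferSize), maxBufferSize)
	for scanner.Scan() {
		if err := fn(scanner.Bytes()); err != nil {
			return err
		}
	}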


@@ -2,14 +2,21 @@ package format
 
 import "fmt"
 
+const (
+	Byte     = 1
+	KiloByte = Byte * 1000
+	MegaByte = KiloByte * 1000
+	GigaByte = MegaByte * 1000
+)
+
 func HumanBytes(b int64) string {
 	switch {
-	case b > 1000*1000*1000:
-		return fmt.Sprintf("%d GB", b/1000/1000/1000)
-	case b > 1000*1000:
-		return fmt.Sprintf("%d MB", b/1000/1000)
-	case b > 1000:
-		return fmt.Sprintf("%d KB", b/1000)
+	case b > GigaByte:
+		return fmt.Sprintf("%d GB", b/GigaByte)
+	case b > MegaByte:
+		return fmt.Sprintf("%d MB", b/MegaByte)
+	case b > KiloByte:
+		return fmt.Sprintf("%d KB", b/KiloByte)
 	default:
 		return fmt.Sprintf("%d B", b)
 	}
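Note the constants are decimal (1 KB = 1000 B), not binary KiB, matching the literals they replace. With them in place, the unit boundaries in HumanBytes read directly. A quick usage sketch:

	fmt.Println(format.HumanBytes(16 * format.GigaByte)) // 16 GB
	fmt.Println(format.HumanBytes(8 * format.KiloByte))  // 8 KB
	fmt.Println(format.HumanBytes(512))                  // 512 B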


@@ -509,7 +509,7 @@ type PredictRequest struct {
 	Stop []string `json:"stop,omitempty"`
 }
 
-const maxBufferSize = 512 * 1000 // 512KB
+const maxBufferSize = 512 * format.KiloByte
 
 func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
 	prevConvo, err := llm.Decode(ctx, prevContext)


@@ -10,6 +10,7 @@ import (
 	"github.com/pbnjay/memory"
 
 	"github.com/jmorganca/ollama/api"
+	"github.com/jmorganca/ollama/format"
 )
 
 type LLM interface {
@@ -60,33 +61,33 @@ func New(workDir, model string, adapters []string, opts api.Options) (LLM, error
 	totalResidentMemory := memory.TotalMemory()
 	switch ggml.ModelType() {
 	case "3B", "7B":
-		if ggml.FileType() == "F16" && totalResidentMemory < 16*1000*1000 {
+		if ggml.FileType() == "F16" && totalResidentMemory < 16*format.GigaByte {
 			return nil, fmt.Errorf("F16 model requires at least 16 GB of memory")
-		} else if totalResidentMemory < 8*1000*1000 {
+		} else if totalResidentMemory < 8*format.GigaByte {
 			return nil, fmt.Errorf("model requires at least 8 GB of memory")
 		}
 	case "13B":
-		if ggml.FileType() == "F16" && totalResidentMemory < 32*1000*1000 {
+		if ggml.FileType() == "F16" && totalResidentMemory < 32*format.GigaByte {
 			return nil, fmt.Errorf("F16 model requires at least 32 GB of memory")
-		} else if totalResidentMemory < 16*1000*1000 {
+		} else if totalResidentMemory < 16*format.GigaByte {
 			return nil, fmt.Errorf("model requires at least 16 GB of memory")
 		}
 	case "30B", "34B", "40B":
-		if ggml.FileType() == "F16" && totalResidentMemory < 64*1000*1000 {
+		if ggml.FileType() == "F16" && totalResidentMemory < 64*format.GigaByte {
 			return nil, fmt.Errorf("F16 model requires at least 64 GB of memory")
-		} else if totalResidentMemory < 32*1000*1000 {
+		} else if totalResidentMemory < 32*format.GigaByte {
 			return nil, fmt.Errorf("model requires at least 32 GB of memory")
 		}
 	case "65B", "70B":
-		if ggml.FileType() == "F16" && totalResidentMemory < 128*1000*1000 {
+		if ggml.FileType() == "F16" && totalResidentMemory < 128*format.GigaByte {
 			return nil, fmt.Errorf("F16 model requires at least 128 GB of memory")
-		} else if totalResidentMemory < 64*1000*1000 {
+		} else if totalResidentMemory < 64*format.GigaByte {
 			return nil, fmt.Errorf("model requires at least 64 GB of memory")
 		}
 	case "180B":
-		if ggml.FileType() == "F16" && totalResidentMemory < 512*1000*1000 {
+		if ggml.FileType() == "F16" && totalResidentMemory < 512*format.GigaByte {
 			return nil, fmt.Errorf("F16 model requires at least 512GB of memory")
-		} else if totalResidentMemory < 128*1000*1000 {
+		} else if totalResidentMemory < 128*format.GigaByte {
 			return nil, fmt.Errorf("model requires at least 128GB of memory")
 		}
 	}
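The five cases follow a regular pattern (each F16 minimum is double the quantized one, except 180B at 4x), which suggests a table-driven form. A hedged sketch of an equivalent check, using a hypothetical checkMemory helper that is not part of this commit:

	// Hypothetical refactor, not in the commit: minimum total memory
	// in bytes per model size, copied from the thresholds above.
	type memRequirement struct {
		f16, quantized uint64
	}

	var minMemory = map[string]memRequirement{
		"3B":   {16 * format.GigaByte, 8 * format.GigaByte},
		"7B":   {16 * format.GigaByte, 8 * format.GigaByte},
		"13B":  {32 * format.GigaByte, 16 * format.GigaByte},
		"30B":  {64 * format.GigaByte, 32 * format.GigaByte},
		"34B":  {64 * format.GigaByte, 32 * format.GigaByte},
		"40B":  {64 * format.GigaByte, 32 * format.GigaByte},
		"65B":  {128 * format.GigaByte, 64 * format.GigaByte},
		"70B":  {128 * format.GigaByte, 64 * format.GigaByte},
		"180B": {512 * format.GigaByte, 128 * format.GigaByte},
	}

	func checkMemory(modelType, fileType string, total uint64) error {
		req, ok := minMemory[modelType]
		if !ok {
			return nil // unknown size: no check, matching the switch above
		}
		min := req.quantized
		if fileType == "F16" {
			min = req.f16
		}
		if total < min {
			return fmt.Errorf("model requires at least %s of memory", format.HumanBytes(int64(min)))
		}
		return nil
	}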