diff --git a/go.mod b/go.mod
index a0583e65..d4a460b9 100644
--- a/go.mod
+++ b/go.mod
@@ -32,6 +32,7 @@ require (
 	github.com/mattn/go-isatty v0.0.19 // indirect
 	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
 	github.com/modern-go/reflect2 v1.0.2 // indirect
+	github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58
 	github.com/pelletier/go-toml/v2 v2.0.8 // indirect
 	github.com/spf13/pflag v1.0.5 // indirect
 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
diff --git a/go.sum b/go.sum
index 7ec060d3..37f5cea0 100644
--- a/go.sum
+++ b/go.sum
@@ -78,6 +78,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
 github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
 github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
 github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
+github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 h1:onHthvaw9LFnH4t2DcNVpwGmV9E1BkGknEliJkfwQj0=
+github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58/go.mod h1:DXv8WO4yhMYhSNPKjeNKa5WY9YCIEBRbNzFFPJbWO6Y=
 github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
 github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
 github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
diff --git a/llm/llm.go b/llm/llm.go
index b537865e..e56ea24f 100644
--- a/llm/llm.go
+++ b/llm/llm.go
@@ -2,8 +2,11 @@ package llm
 
 import (
 	"fmt"
+	"log"
 	"os"
 
+	"github.com/pbnjay/memory"
+
 	"github.com/jmorganca/ollama/api"
 )
 
@@ -31,6 +34,39 @@ func New(model string, opts api.Options) (LLM, error) {
 		return nil, err
 	}
 
+	switch ggml.FileType {
+	case FileTypeF32, FileTypeF16, FileTypeQ5_0, FileTypeQ5_1, FileTypeQ8_0:
+		if opts.NumGPU != 0 {
+			// Q5_0, Q5_1, and Q8_0 do not support the Metal API and will
+			// cause the runner to segfault, so disable the GPU. F32 and F16
+			// are disabled alongside them.
+			log.Printf("WARNING: GPU disabled for F32, F16, Q5_0, Q5_1, and Q8_0")
+			opts.NumGPU = 0
+		}
+	}
+
+	// memory.TotalMemory reports bytes, so each minimum below is spelled
+	// out in bytes: N*1024*1024*1024 is N GiB.
+	totalResidentMemory := memory.TotalMemory()
+	switch ggml.ModelType {
+	case ModelType3B, ModelType7B:
+		if totalResidentMemory < 8*1024*1024*1024 {
+			return nil, fmt.Errorf("model requires at least 8GB of memory")
+		}
+	case ModelType13B:
+		if totalResidentMemory < 16*1024*1024*1024 {
+			return nil, fmt.Errorf("model requires at least 16GB of memory")
+		}
+	case ModelType30B:
+		if totalResidentMemory < 32*1024*1024*1024 {
+			return nil, fmt.Errorf("model requires at least 32GB of memory")
+		}
+	case ModelType65B:
+		if totalResidentMemory < 64*1024*1024*1024 {
+			return nil, fmt.Errorf("model requires at least 64GB of memory")
+		}
+	}
+
 	switch ggml.ModelFamily {
 	case ModelFamilyLlama:
 		return newLlama(model, opts)