diff --git a/go.mod b/go.mod index a0583e65..d4a460b9 100644 --- a/go.mod +++ b/go.mod @@ -32,6 +32,7 @@ require ( github.com/mattn/go-isatty v0.0.19 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 github.com/pelletier/go-toml/v2 v2.0.8 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect diff --git a/go.sum b/go.sum index 7ec060d3..37f5cea0 100644 --- a/go.sum +++ b/go.sum @@ -78,6 +78,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= +github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 h1:onHthvaw9LFnH4t2DcNVpwGmV9E1BkGknEliJkfwQj0= +github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58/go.mod h1:DXv8WO4yhMYhSNPKjeNKa5WY9YCIEBRbNzFFPJbWO6Y= github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo= github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= diff --git a/llm/llm.go b/llm/llm.go index 020e3c2f..e56ea24f 100644 --- a/llm/llm.go +++ b/llm/llm.go @@ -5,6 +5,8 @@ import ( "log" "os" + "github.com/pbnjay/memory" + "github.com/jmorganca/ollama/api" ) @@ -42,6 +44,26 @@ func New(model string, opts api.Options) (LLM, error) { } } + totalResidentMemory := memory.TotalMemory() + switch ggml.ModelType { + case ModelType3B, ModelType7B: + if totalResidentMemory < 8*1024*1024 { + return nil, fmt.Errorf("model requires at least 8GB of memory") + } + case ModelType13B: + if totalResidentMemory < 16*1024*1024 { + return nil, fmt.Errorf("model requires at least 16GB of memory") + } + case ModelType30B: + if totalResidentMemory < 32*1024*1024 { + return nil, fmt.Errorf("model requires at least 32GB of memory") + } + case ModelType65B: + if totalResidentMemory < 64*1024*1024 { + return nil, fmt.Errorf("model requires at least 64GB of memory") + } + } + switch ggml.ModelFamily { case ModelFamilyLlama: return newLlama(model, opts)