unbound max num gpu layers (#591)

---------

Co-authored-by: Michael Yang <mxyng@pm.me>
This commit is contained in:
Bruce MacDonald 2023-09-25 23:36:46 +01:00 committed by GitHub
parent b934bf23e6
commit 86279f4ae3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 36 additions and 29 deletions

View file

@ -77,6 +77,7 @@ type model interface {
ModelFamily() string
ModelType() string
FileType() string
NumLayers() int64
}
type container interface {

View file

@ -195,6 +195,16 @@ func (llm *ggufModel) Decode(r io.Reader) error {
return nil
}
func (llm *ggufModel) NumLayers() int64 {
value, exists := llm.kv[fmt.Sprintf("%s.block_count", llm.ModelFamily())]
if !exists {
return 0
}
v := value.(uint32)
return int64(v)
}
func (ggufModel) readU8(r io.Reader) uint8 {
var u8 uint8
binary.Read(r, binary.LittleEndian, &u8)

View file

@ -152,6 +152,10 @@ func (llm *llamaModel) FileType() string {
return fileType(llm.hyperparameters.FileType)
}
func (llm *llamaModel) NumLayers() int64 {
return int64(llm.hyperparameters.NumLayer)
}
type llamaHyperparameters struct {
// NumVocab is the size of the model's vocabulary.
NumVocab uint32
@ -207,13 +211,13 @@ func CheckVRAM() (int, error) {
return total, nil
}
func NumGPU(opts api.Options) int {
func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
if opts.NumGPU != -1 {
return opts.NumGPU
}
n := 1 // default to enable metal on macOS
if runtime.GOOS == "linux" {
vram, err := CheckVRAM()
vramMib, err := CheckVRAM()
if err != nil {
if err.Error() != "nvidia-smi command failed" {
log.Print(err.Error())
@ -221,33 +225,25 @@ func NumGPU(opts api.Options) int {
// nvidia driver not installed or no nvidia GPU found
return 0
}
// TODO: this is a very rough heuristic, better would be to calculate this based on number of layers and context size
switch {
case vram < 500:
log.Printf("WARNING: Low VRAM detected, disabling GPU")
n = 0
case vram < 1000:
n = 4
case vram < 2000:
n = 8
case vram < 4000:
n = 12
case vram < 8000:
n = 16
case vram < 12000:
n = 24
case vram < 16000:
n = 32
default:
n = 48
totalVramBytes := int64(vramMib) * 1024 * 1024 // 1 MiB = 1024^2 bytes
// Calculate bytes per layer
// TODO: this is a rough heuristic, better would be to calculate this based on number of layers and context size
bytesPerLayer := fileSizeBytes / numLayer
// set n to the max number of layers we can fit in VRAM
return int(totalVramBytes / bytesPerLayer)
log.Printf("%d MiB VRAM available, loading up to %d GPU layers", vramMib, n)
}
log.Printf("%d MB VRAM available, loading %d GPU layers", vram, n)
}
return n
// default to enable metal on macOS
return 1
}
func newLlama(model string, adapters []string, runners []ModelRunner, opts api.Options) (*llama, error) {
if _, err := os.Stat(model); err != nil {
func newLlama(model string, adapters []string, runners []ModelRunner, numLayers int64, opts api.Options) (*llama, error) {
fileInfo, err := os.Stat(model)
if err != nil {
return nil, err
}
@ -261,7 +257,7 @@ func newLlama(model string, adapters []string, runners []ModelRunner, opts api.O
"--rope-freq-base", fmt.Sprintf("%f", opts.RopeFrequencyBase),
"--rope-freq-scale", fmt.Sprintf("%f", opts.RopeFrequencyScale),
"--batch-size", fmt.Sprintf("%d", opts.NumBatch),
"--n-gpu-layers", fmt.Sprintf("%d", NumGPU(opts)),
"--n-gpu-layers", fmt.Sprintf("%d", NumGPU(numLayers, fileInfo.Size(), opts)),
"--embedding",
}

View file

@ -91,9 +91,9 @@ func New(workDir, model string, adapters []string, opts api.Options) (LLM, error
switch ggml.Name() {
case "gguf":
opts.NumGQA = 0 // TODO: remove this when llama.cpp runners differ enough to need separate newLlama functions
return newLlama(model, adapters, chooseRunners(workDir, "gguf"), opts)
return newLlama(model, adapters, chooseRunners(workDir, "gguf"), ggml.NumLayers(), opts)
case "ggml", "ggmf", "ggjt", "ggla":
return newLlama(model, adapters, chooseRunners(workDir, "ggml"), opts)
return newLlama(model, adapters, chooseRunners(workDir, "ggml"), ggml.NumLayers(), opts)
default:
return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily())
}