fix falcon decode

get model and file type from bin file
This commit is contained in:
Michael Yang 2023-09-12 10:01:20 -07:00
parent f221637053
commit 7dee25a07f
5 changed files with 123 additions and 158 deletions

View file

@ -8,54 +8,77 @@ import (
"sync" "sync"
) )
type ModelFamily string
const ModelFamilyUnknown ModelFamily = "unknown"
type ModelType uint32
const (
ModelType3B ModelType = 26
ModelType7B ModelType = 32
ModelType13B ModelType = 40
ModelType34B ModelType = 48
ModelType30B ModelType = 60
ModelType65B ModelType = 80
)
func (mt ModelType) String() string {
switch mt {
case ModelType3B:
return "3B"
case ModelType7B:
return "7B"
case ModelType13B:
return "13B"
case ModelType34B:
return "34B"
case ModelType30B:
return "30B"
case ModelType65B:
return "65B"
default:
return "Unknown"
}
}
type FileType interface {
String() string
}
type GGML struct { type GGML struct {
magic uint32 magic uint32
container container
model model
} }
const (
fileTypeF32 uint32 = iota
fileTypeF16
fileTypeQ4_0
fileTypeQ4_1
fileTypeQ4_1_F16
fileTypeQ8_0 uint32 = iota + 2
fileTypeQ5_0
fileTypeQ5_1
fileTypeQ2_K
fileTypeQ3_K_S
fileTypeQ3_K_M
fileTypeQ3_K_L
fileTypeQ4_K_S
fileTypeQ4_K_M
fileTypeQ5_K_S
fileTypeQ5_K_M
fileTypeQ6_K
)
func fileType(fileType uint32) string {
switch fileType {
case fileTypeF32:
return "F32"
case fileTypeF16:
return "F16"
case fileTypeQ4_0:
return "Q4_0"
case fileTypeQ4_1:
return "Q4_1"
case fileTypeQ4_1_F16:
return "Q4_1_F16"
case fileTypeQ8_0:
return "Q8_0"
case fileTypeQ5_0:
return "Q5_0"
case fileTypeQ5_1:
return "Q5_1"
case fileTypeQ2_K:
return "Q2_K"
case fileTypeQ3_K_S:
return "Q3_K_S"
case fileTypeQ3_K_M:
return "Q3_K_M"
case fileTypeQ3_K_L:
return "Q3_K_L"
case fileTypeQ4_K_S:
return "Q4_K_S"
case fileTypeQ4_K_M:
return "Q4_K_M"
case fileTypeQ5_K_S:
return "Q5_K_S"
case fileTypeQ5_K_M:
return "Q5_K_M"
case fileTypeQ6_K:
return "Q6_K"
default:
return "Unknown"
}
}
type model interface { type model interface {
ModelFamily() ModelFamily ModelFamily() string
ModelType() ModelType ModelType() string
FileType() FileType FileType() string
} }
type container interface { type container interface {

View file

@ -6,7 +6,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log"
"path" "path"
"sync" "sync"
) )
@ -87,38 +86,37 @@ func (llm *ggufModel) NumKV() uint64 {
return llm.V2.NumKV return llm.V2.NumKV
} }
func (llm *ggufModel) ModelFamily() ModelFamily { func (llm *ggufModel) ModelFamily() string {
t, ok := llm.kv["general.architecture"].(string) t, ok := llm.kv["general.architecture"].(string)
if ok { if ok {
return ModelFamily(t) return t
} }
log.Printf("unknown model family: %T", t) return "unknown"
return ModelFamilyUnknown
} }
func (llm *ggufModel) ModelType() ModelType { func (llm *ggufModel) ModelType() string {
switch llm.ModelFamily() { switch llm.ModelFamily() {
case ModelFamilyLlama: case "llama":
blocks, ok := llm.kv["llama.block_count"].(uint32) if blocks, ok := llm.kv["llama.block_count"].(uint32); ok {
if ok { return llamaModelType(blocks)
return ModelType(blocks) }
case "falcon":
if blocks, ok := llm.kv["falcon.block_count"].(uint32); ok {
return falconModelType(blocks)
} }
} }
return ModelType7B return "Unknown"
} }
func (llm *ggufModel) FileType() FileType { func (llm *ggufModel) FileType() string {
switch llm.ModelFamily() { t, ok := llm.kv["general.file_type"].(uint32)
case ModelFamilyLlama: if ok {
t, ok := llm.kv["general.file_type"].(uint32) return fileType(t)
if ok {
return llamaFileType(t)
}
} }
return llamaFileTypeF16 return "Unknown"
} }
func (llm *ggufModel) Decode(r io.Reader) error { func (llm *ggufModel) Decode(r io.Reader) error {

View file

@ -95,38 +95,39 @@ func chooseRunner(gpuPath, cpuPath string) string {
return runPath return runPath
} }
const ModelFamilyLlama ModelFamily = "llama"
type llamaModel struct { type llamaModel struct {
hyperparameters llamaHyperparameters hyperparameters llamaHyperparameters
} }
func (llm *llamaModel) ModelFamily() ModelFamily { func (llm *llamaModel) ModelFamily() string {
return ModelFamilyLlama return "llama"
} }
func (llm *llamaModel) ModelType() ModelType { func llamaModelType(numLayer uint32) string {
switch llm.hyperparameters.NumLayer { switch numLayer {
case 26: case 26:
return ModelType3B return "3B"
case 32: case 32:
return ModelType7B return "7B"
case 40: case 40:
return ModelType13B return "13B"
case 48: case 48:
return ModelType34B return "34B"
case 60: case 60:
return ModelType30B return "30B"
case 80: case 80:
return ModelType65B return "65B"
default:
return "Unknown"
} }
// TODO: find a better default
return ModelType7B
} }
func (llm *llamaModel) FileType() FileType { func (llm *llamaModel) ModelType() string {
return llm.hyperparameters.FileType return llamaModelType(llm.hyperparameters.NumLayer)
}
func (llm *llamaModel) FileType() string {
return fileType(llm.hyperparameters.FileType)
} }
type llamaHyperparameters struct { type llamaHyperparameters struct {
@ -143,70 +144,7 @@ type llamaHyperparameters struct {
NumRot uint32 NumRot uint32
// FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc. // FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
FileType llamaFileType FileType uint32
}
type llamaFileType uint32
const (
llamaFileTypeF32 llamaFileType = iota
llamaFileTypeF16
llamaFileTypeQ4_0
llamaFileTypeQ4_1
llamaFileTypeQ4_1_F16
llamaFileTypeQ8_0 llamaFileType = iota + 2
llamaFileTypeQ5_0
llamaFileTypeQ5_1
llamaFileTypeQ2_K
llamaFileTypeQ3_K_S
llamaFileTypeQ3_K_M
llamaFileTypeQ3_K_L
llamaFileTypeQ4_K_S
llamaFileTypeQ4_K_M
llamaFileTypeQ5_K_S
llamaFileTypeQ5_K_M
llamaFileTypeQ6_K
)
func (ft llamaFileType) String() string {
switch ft {
case llamaFileTypeF32:
return "F32"
case llamaFileTypeF16:
return "F16"
case llamaFileTypeQ4_0:
return "Q4_0"
case llamaFileTypeQ4_1:
return "Q4_1"
case llamaFileTypeQ4_1_F16:
return "Q4_1_F16"
case llamaFileTypeQ8_0:
return "Q8_0"
case llamaFileTypeQ5_0:
return "Q5_0"
case llamaFileTypeQ5_1:
return "Q5_1"
case llamaFileTypeQ2_K:
return "Q2_K"
case llamaFileTypeQ3_K_S:
return "Q3_K_S"
case llamaFileTypeQ3_K_M:
return "Q3_K_M"
case llamaFileTypeQ3_K_L:
return "Q3_K_L"
case llamaFileTypeQ4_K_S:
return "Q4_K_S"
case llamaFileTypeQ4_K_M:
return "Q4_K_M"
case llamaFileTypeQ5_K_S:
return "Q5_K_S"
case llamaFileTypeQ5_K_M:
return "Q5_K_M"
case llamaFileTypeQ6_K:
return "Q6_K"
default:
return "Unknown"
}
} }
type Running struct { type Running struct {

View file

@ -37,7 +37,7 @@ func New(model string, adapters []string, opts api.Options) (LLM, error) {
return nil, err return nil, err
} }
switch ggml.FileType().String() { switch ggml.FileType() {
case "Q8_0": case "Q8_0":
if ggml.Name() != "gguf" && opts.NumGPU != 0 { if ggml.Name() != "gguf" && opts.NumGPU != 0 {
// GGML Q8_0 do not support Metal API and will // GGML Q8_0 do not support Metal API and will
@ -56,30 +56,36 @@ func New(model string, adapters []string, opts api.Options) (LLM, error) {
totalResidentMemory := memory.TotalMemory() totalResidentMemory := memory.TotalMemory()
switch ggml.ModelType() { switch ggml.ModelType() {
case ModelType3B, ModelType7B: case "3B", "7B":
if ggml.FileType().String() == "F16" && totalResidentMemory < 16*1024*1024 { if ggml.FileType() == "F16" && totalResidentMemory < 16*1024*1024 {
return nil, fmt.Errorf("F16 model requires at least 16GB of memory") return nil, fmt.Errorf("F16 model requires at least 16GB of memory")
} else if totalResidentMemory < 8*1024*1024 { } else if totalResidentMemory < 8*1024*1024 {
return nil, fmt.Errorf("model requires at least 8GB of memory") return nil, fmt.Errorf("model requires at least 8GB of memory")
} }
case ModelType13B: case "13B":
if ggml.FileType().String() == "F16" && totalResidentMemory < 32*1024*1024 { if ggml.FileType() == "F16" && totalResidentMemory < 32*1024*1024 {
return nil, fmt.Errorf("F16 model requires at least 32GB of memory") return nil, fmt.Errorf("F16 model requires at least 32GB of memory")
} else if totalResidentMemory < 16*1024*1024 { } else if totalResidentMemory < 16*1024*1024 {
return nil, fmt.Errorf("model requires at least 16GB of memory") return nil, fmt.Errorf("model requires at least 16GB of memory")
} }
case ModelType30B, ModelType34B: case "30B", "34B", "40B":
if ggml.FileType().String() == "F16" && totalResidentMemory < 64*1024*1024 { if ggml.FileType() == "F16" && totalResidentMemory < 64*1024*1024 {
return nil, fmt.Errorf("F16 model requires at least 64GB of memory") return nil, fmt.Errorf("F16 model requires at least 64GB of memory")
} else if totalResidentMemory < 32*1024*1024 { } else if totalResidentMemory < 32*1024*1024 {
return nil, fmt.Errorf("model requires at least 32GB of memory") return nil, fmt.Errorf("model requires at least 32GB of memory")
} }
case ModelType65B: case "65B", "70B":
if ggml.FileType().String() == "F16" && totalResidentMemory < 128*1024*1024 { if ggml.FileType() == "F16" && totalResidentMemory < 128*1024*1024 {
return nil, fmt.Errorf("F16 model requires at least 128GB of memory") return nil, fmt.Errorf("F16 model requires at least 128GB of memory")
} else if totalResidentMemory < 64*1024*1024 { } else if totalResidentMemory < 64*1024*1024 {
return nil, fmt.Errorf("model requires at least 64GB of memory") return nil, fmt.Errorf("model requires at least 64GB of memory")
} }
case "180B":
if ggml.FileType() == "F16" && totalResidentMemory < 512*1024*1024 {
return nil, fmt.Errorf("F16 model requires at least 512GB of memory")
} else if totalResidentMemory < 128*1024*1024 {
return nil, fmt.Errorf("model requires at least 128GB of memory")
}
} }
switch ggml.Name() { switch ggml.Name() {

View file

@ -114,11 +114,11 @@ type LayerReader struct {
} }
type ConfigV2 struct { type ConfigV2 struct {
ModelFamily llm.ModelFamily `json:"model_family"` ModelFormat string `json:"model_format"`
ModelType string `json:"model_type"` ModelFamily string `json:"model_family"`
ModelFormat string `json:"model_format"` ModelType string `json:"model_type"`
FileType string `json:"file_type"` FileType string `json:"file_type"`
RootFS RootFS `json:"rootfs"` RootFS RootFS `json:"rootfs"`
// required by spec // required by spec
Architecture string `json:"architecture"` Architecture string `json:"architecture"`
@ -357,10 +357,10 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
return err return err
} }
config.ModelFamily = ggml.ModelFamily()
config.ModelType = ggml.ModelType().String()
config.ModelFormat = ggml.Name() config.ModelFormat = ggml.Name()
config.FileType = ggml.FileType().String() config.ModelFamily = ggml.ModelFamily()
config.ModelType = ggml.ModelType()
config.FileType = ggml.FileType()
// reset the file // reset the file
file.Seek(0, io.SeekStart) file.Seek(0, io.SeekStart)