model and file type as strings

Michael Yang 2023-08-17 11:37:27 -07:00
parent 519f4d98ef
commit a894cc792d
4 changed files with 133 additions and 48 deletions

View file

@@ -9,8 +9,6 @@ import (
 type ModelFamily string
 
-const ModelFamilyLlama ModelFamily = "llama"
-
 type ModelType uint32
 
 const (
@@ -21,32 +19,37 @@ const (
 	ModelType65B ModelType = 80
 )
 
-type FileType uint32
-
-const (
-	FileTypeF32 FileType = iota
-	FileTypeF16
-	FileTypeQ4_0
-	FileTypeQ4_1
-	FileTypeQ4_1_F16
-	FileTypeQ8_0 = iota + 2
-	FileTypeQ5_0
-	FileTypeQ5_1
-	FileTypeQ2_K
-	FileTypeQ3_K
-	FileTypeQ4_K
-	FileTypeQ5_K
-	FileTypeQ6_K
-)
+func (mt ModelType) String() string {
+	switch mt {
+	case ModelType3B:
+		return "3B"
+	case ModelType7B:
+		return "7B"
+	case ModelType13B:
+		return "13B"
+	case ModelType30B:
+		return "30B"
+	case ModelType65B:
+		return "65B"
+	default:
+		return "Unknown"
+	}
+}
+
+type FileType interface {
+	String() string
+}
 
 type GGML struct {
-	ModelFamily
-	ModelType
-
 	magic uint32
 	container
-
-	llamaHyperparameters
+	model
+}
+
+type model interface {
+	ModelFamily() ModelFamily
+	ModelType() ModelType
+	FileType() FileType
 }
 
 type container interface {
@@ -166,14 +169,14 @@ func DecodeGGML(r io.ReadSeeker, hint ModelFamily) (*GGML, error) {
 	// different model types may have different layouts for hyperparameters
 	switch hint {
 	case ModelFamilyLlama:
-		binary.Read(r, binary.LittleEndian, &ggml.llamaHyperparameters)
+		var llama llamaModel
+		binary.Read(r, binary.LittleEndian, &llama.hyperparameters)
+		ggml.model = &llama
 		// TODO: sanity check hyperparameters
 	default:
 		return nil, fmt.Errorf("unsupported model type: %s", hint)
 	}
 
 	// final model type
-	ggml.ModelFamily = hint
-	ggml.ModelType = ModelType(ggml.NumLayer)
-
 	return &ggml, nil
 }
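
For context, a minimal sketch of how a caller might consume the reworked GGML type after this hunk. The wrapper function and file path are hypothetical; only DecodeGGML and the ModelFamily/ModelType/FileType accessors come from the diff above.

func describeModel(path string) error {
	f, err := os.Open(path) // hypothetical model file
	if err != nil {
		return err
	}
	defer f.Close()

	// DecodeGGML fills ggml.model, so the accessors are promoted onto *GGML
	// through the embedded model interface.
	ggml, err := DecodeGGML(f, ModelFamilyLlama)
	if err != nil {
		return err
	}

	// ModelType and FileType now render as strings via their String methods.
	log.Printf("family=%s type=%s quantization=%s",
		ggml.ModelFamily(), ggml.ModelType(), ggml.FileType())
	return nil
}

The embedded model interface is what lets ggml.ModelType() and friends resolve even though GGML no longer stores those values as fields.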

View file

@@ -106,19 +106,22 @@ import (
 //go:embed ggml-metal.metal
 var fs embed.FS
 
-type llama struct {
-	params *C.struct_llama_context_params
-	model  *C.struct_llama_model
-	ctx    *C.struct_llama_context
-
-	last   []C.llama_token
-	embd   []C.llama_token
-	cursor int
-
-	mu sync.Mutex
-	gc bool
-
-	api.Options
-}
+const ModelFamilyLlama ModelFamily = "llama"
+
+type llamaModel struct {
+	hyperparameters llamaHyperparameters
+}
+
+func (llm *llamaModel) ModelFamily() ModelFamily {
+	return ModelFamilyLlama
+}
+
+func (llm *llamaModel) ModelType() ModelType {
+	return ModelType30B
+}
+
+func (llm *llamaModel) FileType() FileType {
+	return llm.hyperparameters.FileType
+}
 
 type llamaHyperparameters struct {
@@ -133,8 +136,87 @@ type llamaHyperparameters struct {
 	// NumLayer is the number of layers in the model.
 	NumLayer uint32
 	NumRot   uint32
 
 	// FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
-	FileType FileType
+	FileType llamaFileType
 }
+
+type llamaFileType uint32
+
+const (
+	llamaFileTypeF32 llamaFileType = iota
+	llamaFileTypeF16
+	llamaFileTypeQ4_0
+	llamaFileTypeQ4_1
+	llamaFileTypeQ4_1_F16
+	llamaFileTypeQ8_0 llamaFileType = iota + 2
+	llamaFileTypeQ5_0
+	llamaFileTypeQ5_1
+	llamaFileTypeQ2_K
+	llamaFileTypeQ3_K_S
+	llamaFileTypeQ3_K_M
+	llamaFileTypeQ3_K_L
+	llamaFileTypeQ4_K_S
+	llamaFileTypeQ4_K_M
+	llamaFileTypeQ5_K_S
+	llamaFileTypeQ5_K_M
+	llamaFileTypeQ6_K
+)
+
+func (ft llamaFileType) String() string {
+	switch ft {
+	case llamaFileTypeF32:
+		return "F32"
+	case llamaFileTypeF16:
+		return "F16"
+	case llamaFileTypeQ4_0:
+		return "Q4_0"
+	case llamaFileTypeQ4_1:
+		return "Q4_1"
+	case llamaFileTypeQ4_1_F16:
+		return "Q4_1_F16"
+	case llamaFileTypeQ8_0:
+		return "Q8_0"
+	case llamaFileTypeQ5_0:
+		return "Q5_0"
+	case llamaFileTypeQ5_1:
+		return "Q5_1"
+	case llamaFileTypeQ2_K:
+		return "Q2_K"
+	case llamaFileTypeQ3_K_S:
+		return "Q3_K_S"
+	case llamaFileTypeQ3_K_M:
+		return "Q3_K_M"
+	case llamaFileTypeQ3_K_L:
+		return "Q3_K_L"
+	case llamaFileTypeQ4_K_S:
+		return "Q4_K_S"
+	case llamaFileTypeQ4_K_M:
+		return "Q4_K_M"
+	case llamaFileTypeQ5_K_S:
+		return "Q5_K_S"
+	case llamaFileTypeQ5_K_M:
+		return "Q5_K_M"
+	case llamaFileTypeQ6_K:
+		return "Q6_K"
+	default:
+		return "Unknown"
+	}
+}
+
+type llama struct {
+	params *C.struct_llama_context_params
+	model  *C.struct_llama_model
+	ctx    *C.struct_llama_context
+
+	last   []C.llama_token
+	embd   []C.llama_token
+	cursor int
+
+	mu sync.Mutex
+	gc bool
+
+	api.Options
+}
 
 func newLlama(model string, adapters []string, opts api.Options) (*llama, error) {
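
One detail worth calling out: FileType is satisfied implicitly, so llamaFileType only needs its String() string method to flow back through llamaModel.FileType(). A small illustrative sketch (not part of this commit) of compile-time assertions that would document those relationships:

var (
	// llamaFileType has a String() string method, so it satisfies FileType.
	_ FileType = llamaFileType(0)

	// *llamaModel provides ModelFamily, ModelType, and FileType, so it
	// satisfies the model interface embedded in GGML.
	_ model = (*llamaModel)(nil)
)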

View file

@@ -35,10 +35,10 @@ func New(model string, adapters []string, opts api.Options) (LLM, error) {
 		return nil, err
 	}
 
-	switch ggml.FileType {
-	case FileTypeF32, FileTypeF16, FileTypeQ5_0, FileTypeQ5_1, FileTypeQ8_0:
+	switch ggml.FileType().String() {
+	case "F32", "F16", "Q5_0", "Q5_1", "Q8_0":
 		if opts.NumGPU != 0 {
-			// Q5_0, Q5_1, and Q8_0 do not support Metal API and will
+			// F32, F16, Q5_0, Q5_1, and Q8_0 do not support Metal API and will
 			// cause the runner to segmentation fault so disable GPU
 			log.Printf("WARNING: GPU disabled for F32, F16, Q5_0, Q5_1, and Q8_0")
 			opts.NumGPU = 0
@@ -46,7 +46,7 @@ func New(model string, adapters []string, opts api.Options) (LLM, error) {
 	}
 
 	totalResidentMemory := memory.TotalMemory()
-	switch ggml.ModelType {
+	switch ggml.ModelType() {
 	case ModelType3B, ModelType7B:
 		if totalResidentMemory < 8*1024*1024 {
 			return nil, fmt.Errorf("model requires at least 8GB of memory")
@@ -65,10 +65,10 @@ func New(model string, adapters []string, opts api.Options) (LLM, error) {
 		}
 	}
 
-	switch ggml.ModelFamily {
+	switch ggml.ModelFamily() {
 	case ModelFamilyLlama:
 		return newLlama(model, adapters, opts)
 	default:
-		return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily)
+		return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily())
 	}
 }

View file

@@ -105,9 +105,9 @@ type LayerReader struct {
 type ConfigV2 struct {
 	ModelFamily llm.ModelFamily `json:"model_family"`
-	ModelType   llm.ModelType   `json:"model_type"`
-	FileType    llm.FileType    `json:"file_type"`
+	ModelType   string          `json:"model_type"`
+	FileType    string          `json:"file_type"`
 	RootFS      RootFS          `json:"rootfs"`
 
 	// required by spec
 	Architecture string `json:"architecture"`
@@ -308,9 +308,9 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 		return err
 	}
 
-	config.ModelFamily = ggml.ModelFamily
-	config.ModelType = ggml.ModelType
-	config.FileType = ggml.FileType
+	config.ModelFamily = ggml.ModelFamily()
+	config.ModelType = ggml.ModelType().String()
+	config.FileType = ggml.FileType().String()
 
 	// reset the file
 	file.Seek(0, io.SeekStart)
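
The net effect on the manifest config is that model_type and file_type are now stored as readable strings rather than raw integers. A hedged sketch of the result (the ggml value is assumed to come from the DecodeGGML call earlier in CreateModel; the example values are hypothetical):

// Populated exactly as in the diff above.
config := ConfigV2{
	ModelFamily: ggml.ModelFamily(),        // e.g. "llama"
	ModelType:   ggml.ModelType().String(), // e.g. "65B" (previously the layer count, e.g. 80)
	FileType:    ggml.FileType().String(),  // e.g. "Q4_0" (previously the enum value, e.g. 2)
}

b, _ := json.Marshal(config)
fmt.Println(string(b))
// e.g. {"model_family":"llama","model_type":"65B","file_type":"Q4_0",...}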