Merge pull request #372 from jmorganca/mxyng/string-types

model and file type as strings
This commit is contained in:
Michael Yang 2023-08-17 15:10:59 -07:00 committed by GitHub
commit cabaada956
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 133 additions and 48 deletions

View file

@ -9,8 +9,6 @@ import (
type ModelFamily string
const ModelFamilyLlama ModelFamily = "llama"
type ModelType uint32
const (
@ -21,32 +19,37 @@ const (
ModelType65B ModelType = 80
)
type FileType uint32
func (mt ModelType) String() string {
switch mt {
case ModelType3B:
return "3B"
case ModelType7B:
return "7B"
case ModelType13B:
return "13B"
case ModelType30B:
return "30B"
case ModelType65B:
return "65B"
default:
return "Unknown"
}
}
const (
FileTypeF32 FileType = iota
FileTypeF16
FileTypeQ4_0
FileTypeQ4_1
FileTypeQ4_1_F16
FileTypeQ8_0 = iota + 2
FileTypeQ5_0
FileTypeQ5_1
FileTypeQ2_K
FileTypeQ3_K
FileTypeQ4_K
FileTypeQ5_K
FileTypeQ6_K
)
type FileType interface {
String() string
}
type GGML struct {
ModelFamily
ModelType
magic uint32
container
model
}
llamaHyperparameters
type model interface {
ModelFamily() ModelFamily
ModelType() ModelType
FileType() FileType
}
type container interface {
@ -166,14 +169,14 @@ func DecodeGGML(r io.ReadSeeker, hint ModelFamily) (*GGML, error) {
// different model types may have different layouts for hyperparameters
switch hint {
case ModelFamilyLlama:
binary.Read(r, binary.LittleEndian, &ggml.llamaHyperparameters)
var llama llamaModel
binary.Read(r, binary.LittleEndian, &llama.hyperparameters)
ggml.model = &llama
// TODO: sanity check hyperparameters
default:
return nil, fmt.Errorf("unsupported model type: %s", hint)
}
// final model type
ggml.ModelFamily = hint
ggml.ModelType = ModelType(ggml.NumLayer)
return &ggml, nil
}

View file

@ -106,19 +106,22 @@ import (
//go:embed ggml-metal.metal
var fs embed.FS
type llama struct {
params *C.struct_llama_context_params
model *C.struct_llama_model
ctx *C.struct_llama_context
const ModelFamilyLlama ModelFamily = "llama"
last []C.llama_token
embd []C.llama_token
cursor int
type llamaModel struct {
hyperparameters llamaHyperparameters
}
mu sync.Mutex
gc bool
func (llm *llamaModel) ModelFamily() ModelFamily {
return ModelFamilyLlama
}
api.Options
func (llm *llamaModel) ModelType() ModelType {
return ModelType30B
}
func (llm *llamaModel) FileType() FileType {
return llm.hyperparameters.FileType
}
type llamaHyperparameters struct {
@ -133,8 +136,87 @@ type llamaHyperparameters struct {
// NumLayer is the number of layers in the model.
NumLayer uint32
NumRot uint32
// FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
FileType
FileType llamaFileType
}
type llamaFileType uint32
const (
llamaFileTypeF32 llamaFileType = iota
llamaFileTypeF16
llamaFileTypeQ4_0
llamaFileTypeQ4_1
llamaFileTypeQ4_1_F16
llamaFileTypeQ8_0 llamaFileType = iota + 2
llamaFileTypeQ5_0
llamaFileTypeQ5_1
llamaFileTypeQ2_K
llamaFileTypeQ3_K_S
llamaFileTypeQ3_K_M
llamaFileTypeQ3_K_L
llamaFileTypeQ4_K_S
llamaFileTypeQ4_K_M
llamaFileTypeQ5_K_S
llamaFileTypeQ5_K_M
llamaFileTypeQ6_K
)
func (ft llamaFileType) String() string {
switch ft {
case llamaFileTypeF32:
return "F32"
case llamaFileTypeF16:
return "F16"
case llamaFileTypeQ4_0:
return "Q4_0"
case llamaFileTypeQ4_1:
return "Q4_1"
case llamaFileTypeQ4_1_F16:
return "Q4_1_F16"
case llamaFileTypeQ8_0:
return "Q8_0"
case llamaFileTypeQ5_0:
return "Q5_0"
case llamaFileTypeQ5_1:
return "Q5_1"
case llamaFileTypeQ2_K:
return "Q2_K"
case llamaFileTypeQ3_K_S:
return "Q3_K_S"
case llamaFileTypeQ3_K_M:
return "Q3_K_M"
case llamaFileTypeQ3_K_L:
return "Q3_K_L"
case llamaFileTypeQ4_K_S:
return "Q4_K_S"
case llamaFileTypeQ4_K_M:
return "Q4_K_M"
case llamaFileTypeQ5_K_S:
return "Q5_K_S"
case llamaFileTypeQ5_K_M:
return "Q5_K_M"
case llamaFileTypeQ6_K:
return "Q6_K"
default:
return "Unknown"
}
}
type llama struct {
params *C.struct_llama_context_params
model *C.struct_llama_model
ctx *C.struct_llama_context
last []C.llama_token
embd []C.llama_token
cursor int
mu sync.Mutex
gc bool
api.Options
}
func newLlama(model string, adapters []string, opts api.Options) (*llama, error) {

View file

@ -35,10 +35,10 @@ func New(model string, adapters []string, opts api.Options) (LLM, error) {
return nil, err
}
switch ggml.FileType {
case FileTypeF32, FileTypeF16, FileTypeQ5_0, FileTypeQ5_1, FileTypeQ8_0:
switch ggml.FileType().String() {
case "F32", "F16", "Q5_0", "Q5_1", "Q8_0":
if opts.NumGPU != 0 {
// Q5_0, Q5_1, and Q8_0 do not support Metal API and will
// F32, F16, Q5_0, Q5_1, and Q8_0 do not support Metal API and will
// cause the runner to segmentation fault so disable GPU
log.Printf("WARNING: GPU disabled for F32, F16, Q5_0, Q5_1, and Q8_0")
opts.NumGPU = 0
@ -46,7 +46,7 @@ func New(model string, adapters []string, opts api.Options) (LLM, error) {
}
totalResidentMemory := memory.TotalMemory()
switch ggml.ModelType {
switch ggml.ModelType() {
case ModelType3B, ModelType7B:
if totalResidentMemory < 8*1024*1024 {
return nil, fmt.Errorf("model requires at least 8GB of memory")
@ -65,10 +65,10 @@ func New(model string, adapters []string, opts api.Options) (LLM, error) {
}
}
switch ggml.ModelFamily {
switch ggml.ModelFamily() {
case ModelFamilyLlama:
return newLlama(model, adapters, opts)
default:
return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily)
return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily())
}
}

View file

@ -105,8 +105,8 @@ type LayerReader struct {
type ConfigV2 struct {
ModelFamily llm.ModelFamily `json:"model_family"`
ModelType llm.ModelType `json:"model_type"`
FileType llm.FileType `json:"file_type"`
ModelType string `json:"model_type"`
FileType string `json:"file_type"`
RootFS RootFS `json:"rootfs"`
// required by spec
@ -308,9 +308,9 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
return err
}
config.ModelFamily = ggml.ModelFamily
config.ModelType = ggml.ModelType
config.FileType = ggml.FileType
config.ModelFamily = ggml.ModelFamily()
config.ModelType = ggml.ModelType().String()
config.FileType = ggml.FileType().String()
// reset the file
file.Seek(0, io.SeekStart)