From a894cc792de0f2270bd504dcbc3ff61e3bbc4445 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Thu, 17 Aug 2023 11:37:27 -0700
Subject: [PATCH] model and file type as strings

---
 llm/ggml.go      |  53 ++++++++++++------------
 llm/llama.go     | 104 ++++++++++++++++++++++++++++++++++++++++++-----
 llm/llm.go       |  12 +++---
 server/images.go |  12 +++---
 4 files changed, 133 insertions(+), 48 deletions(-)

diff --git a/llm/ggml.go b/llm/ggml.go
index a732f074..eca9ebbb 100644
--- a/llm/ggml.go
+++ b/llm/ggml.go
@@ -9,8 +9,6 @@ import (
 
 type ModelFamily string
 
-const ModelFamilyLlama ModelFamily = "llama"
-
 type ModelType uint32
 
 const (
@@ -21,32 +19,37 @@ const (
 	ModelType65B ModelType = 80
 )
 
-type FileType uint32
+func (mt ModelType) String() string {
+	switch mt {
+	case ModelType3B:
+		return "3B"
+	case ModelType7B:
+		return "7B"
+	case ModelType13B:
+		return "13B"
+	case ModelType30B:
+		return "30B"
+	case ModelType65B:
+		return "65B"
+	default:
+		return "Unknown"
+	}
+}
 
-const (
-	FileTypeF32 FileType = iota
-	FileTypeF16
-	FileTypeQ4_0
-	FileTypeQ4_1
-	FileTypeQ4_1_F16
-	FileTypeQ8_0 = iota + 2
-	FileTypeQ5_0
-	FileTypeQ5_1
-	FileTypeQ2_K
-	FileTypeQ3_K
-	FileTypeQ4_K
-	FileTypeQ5_K
-	FileTypeQ6_K
-)
+type FileType interface {
+	String() string
+}
 
 type GGML struct {
-	ModelFamily
-	ModelType
-
 	magic uint32
 	container
+	model
+}
 
-	llamaHyperparameters
+type model interface {
+	ModelFamily() ModelFamily
+	ModelType() ModelType
+	FileType() FileType
 }
 
 type container interface {
@@ -166,14 +169,14 @@ func DecodeGGML(r io.ReadSeeker, hint ModelFamily) (*GGML, error) {
 	// different model types may have different layouts for hyperparameters
 	switch hint {
 	case ModelFamilyLlama:
-		binary.Read(r, binary.LittleEndian, &ggml.llamaHyperparameters)
+		var llama llamaModel
+		binary.Read(r, binary.LittleEndian, &llama.hyperparameters)
+		ggml.model = &llama
 		// TODO: sanity check hyperparameters
 	default:
 		return nil, fmt.Errorf("unsupported model type: %s", hint)
 	}
 
 	// final model type
-	ggml.ModelFamily = hint
-	ggml.ModelType = ModelType(ggml.NumLayer)
 	return &ggml, nil
 }
diff --git a/llm/llama.go b/llm/llama.go
index 2a310cac..8c5762b6 100644
--- a/llm/llama.go
+++ b/llm/llama.go
@@ -106,19 +106,22 @@ import (
 //go:embed ggml-metal.metal
 var fs embed.FS
 
-type llama struct {
-	params *C.struct_llama_context_params
-	model  *C.struct_llama_model
-	ctx    *C.struct_llama_context
+const ModelFamilyLlama ModelFamily = "llama"
 
-	last   []C.llama_token
-	embd   []C.llama_token
-	cursor int
+type llamaModel struct {
+	hyperparameters llamaHyperparameters
+}
 
-	mu sync.Mutex
-	gc bool
+func (llm *llamaModel) ModelFamily() ModelFamily {
+	return ModelFamilyLlama
+}
 
-	api.Options
+func (llm *llamaModel) ModelType() ModelType {
+	return ModelType(llm.hyperparameters.NumLayer)
+}
+
+func (llm *llamaModel) FileType() FileType {
+	return llm.hyperparameters.FileType
 }
 
 type llamaHyperparameters struct {
@@ -133,8 +136,87 @@ type llamaHyperparameters struct {
 	// NumLayer is the number of layers in the model.
 	NumLayer uint32
 	NumRot   uint32
+
 	// FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
-	FileType
+	FileType llamaFileType
+}
+
+type llamaFileType uint32
+
+const (
+	llamaFileTypeF32 llamaFileType = iota
+	llamaFileTypeF16
+	llamaFileTypeQ4_0
+	llamaFileTypeQ4_1
+	llamaFileTypeQ4_1_F16
+	llamaFileTypeQ8_0 llamaFileType = iota + 2
+	llamaFileTypeQ5_0
+	llamaFileTypeQ5_1
+	llamaFileTypeQ2_K
+	llamaFileTypeQ3_K_S
+	llamaFileTypeQ3_K_M
+	llamaFileTypeQ3_K_L
+	llamaFileTypeQ4_K_S
+	llamaFileTypeQ4_K_M
+	llamaFileTypeQ5_K_S
+	llamaFileTypeQ5_K_M
+	llamaFileTypeQ6_K
+)
+
+func (ft llamaFileType) String() string {
+	switch ft {
+	case llamaFileTypeF32:
+		return "F32"
+	case llamaFileTypeF16:
+		return "F16"
+	case llamaFileTypeQ4_0:
+		return "Q4_0"
+	case llamaFileTypeQ4_1:
+		return "Q4_1"
+	case llamaFileTypeQ4_1_F16:
+		return "Q4_1_F16"
+	case llamaFileTypeQ8_0:
+		return "Q8_0"
+	case llamaFileTypeQ5_0:
+		return "Q5_0"
+	case llamaFileTypeQ5_1:
+		return "Q5_1"
+	case llamaFileTypeQ2_K:
+		return "Q2_K"
+	case llamaFileTypeQ3_K_S:
+		return "Q3_K_S"
+	case llamaFileTypeQ3_K_M:
+		return "Q3_K_M"
+	case llamaFileTypeQ3_K_L:
+		return "Q3_K_L"
+	case llamaFileTypeQ4_K_S:
+		return "Q4_K_S"
+	case llamaFileTypeQ4_K_M:
+		return "Q4_K_M"
+	case llamaFileTypeQ5_K_S:
+		return "Q5_K_S"
+	case llamaFileTypeQ5_K_M:
+		return "Q5_K_M"
+	case llamaFileTypeQ6_K:
+		return "Q6_K"
+	default:
+		return "Unknown"
+	}
+}
+
+type llama struct {
+	params *C.struct_llama_context_params
+	model  *C.struct_llama_model
+	ctx    *C.struct_llama_context
+
+	last   []C.llama_token
+	embd   []C.llama_token
+	cursor int
+
+	mu sync.Mutex
+	gc bool
+
+	api.Options
 }
 
 func newLlama(model string, adapters []string, opts api.Options) (*llama, error) {
diff --git a/llm/llm.go b/llm/llm.go
index edc1107d..db4dcbe7 100644
--- a/llm/llm.go
+++ b/llm/llm.go
@@ -35,10 +35,10 @@ func New(model string, adapters []string, opts api.Options) (LLM, error) {
 		return nil, err
 	}
 
-	switch ggml.FileType {
-	case FileTypeF32, FileTypeF16, FileTypeQ5_0, FileTypeQ5_1, FileTypeQ8_0:
+	switch ggml.FileType().String() {
+	case "F32", "F16", "Q5_0", "Q5_1", "Q8_0":
 		if opts.NumGPU != 0 {
-			// Q5_0, Q5_1, and Q8_0 do not support Metal API and will
+			// F32, F16, Q5_0, Q5_1, and Q8_0 do not support Metal API and will
 			// cause the runner to segmentation fault so disable GPU
 			log.Printf("WARNING: GPU disabled for F32, F16, Q5_0, Q5_1, and Q8_0")
 			opts.NumGPU = 0
@@ -46,7 +46,7 @@ func New(model string, adapters []string, opts api.Options) (LLM, error) {
 	}
 
 	totalResidentMemory := memory.TotalMemory()
-	switch ggml.ModelType {
+	switch ggml.ModelType() {
 	case ModelType3B, ModelType7B:
 		if totalResidentMemory < 8*1024*1024 {
 			return nil, fmt.Errorf("model requires at least 8GB of memory")
@@ -65,10 +65,10 @@ func New(model string, adapters []string, opts api.Options) (LLM, error) {
 		}
 	}
 
-	switch ggml.ModelFamily {
+	switch ggml.ModelFamily() {
 	case ModelFamilyLlama:
 		return newLlama(model, adapters, opts)
 	default:
-		return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily)
+		return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily())
 	}
 }
diff --git a/server/images.go b/server/images.go
index bdf166de..df1b26fd 100644
--- a/server/images.go
+++ b/server/images.go
@@ -105,9 +105,9 @@ type LayerReader struct {
 
 type ConfigV2 struct {
 	ModelFamily llm.ModelFamily `json:"model_family"`
-	ModelType   llm.ModelType   `json:"model_type"`
-	FileType    llm.FileType    `json:"file_type"`
-	RootFS      RootFS          `json:"rootfs"`
+	ModelType   string          `json:"model_type"`
+	FileType    string          `json:"file_type"`
+	RootFS      RootFS          `json:"rootfs"`
 
 	// required by spec
 	Architecture string `json:"architecture"`
@@ -308,9 +308,9 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 		return err
 	}
 
-	config.ModelFamily = ggml.ModelFamily
-	config.ModelType = ggml.ModelType
-	config.FileType = ggml.FileType
+	config.ModelFamily = ggml.ModelFamily()
+	config.ModelType = ggml.ModelType().String()
+	config.FileType = ggml.FileType().String()
 
 	// reset the file
 	file.Seek(0, io.SeekStart)
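
Below the patch, a minimal sketch of how the reworked API reads from a caller's side; it is not part of the change itself. It assumes the module path github.com/jmorganca/ollama/llm of the time and a hypothetical local weights file model.bin; DecodeGGML, ModelFamilyLlama, and the ModelFamily/ModelType/FileType accessors are the ones introduced or reworked above.

package main

import (
	"fmt"
	"log"
	"os"

	"github.com/jmorganca/ollama/llm"
)

func main() {
	// model.bin is a placeholder path to llama-family ggml weights
	f, err := os.Open("model.bin")
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	// decode the header; the llama hint selects the hyperparameter layout
	ggml, err := llm.DecodeGGML(f, llm.ModelFamilyLlama)
	if err != nil {
		log.Fatal(err)
	}

	// model and file type now format directly as strings, e.g. "llama 30B Q4_0",
	// instead of callers switching on numeric enum values
	fmt.Println(ggml.ModelFamily(), ggml.ModelType(), ggml.FileType())
}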