model and file type as strings
This commit is contained in:
parent
519f4d98ef
commit
a894cc792d
4 changed files with 133 additions and 48 deletions
53
llm/ggml.go
53
llm/ggml.go
|
@ -9,8 +9,6 @@ import (
|
|||
|
||||
type ModelFamily string
|
||||
|
||||
const ModelFamilyLlama ModelFamily = "llama"
|
||||
|
||||
type ModelType uint32
|
||||
|
||||
const (
|
||||
|
@ -21,32 +19,37 @@ const (
|
|||
ModelType65B ModelType = 80
|
||||
)
|
||||
|
||||
type FileType uint32
|
||||
func (mt ModelType) String() string {
|
||||
switch mt {
|
||||
case ModelType3B:
|
||||
return "3B"
|
||||
case ModelType7B:
|
||||
return "7B"
|
||||
case ModelType13B:
|
||||
return "13B"
|
||||
case ModelType30B:
|
||||
return "30B"
|
||||
case ModelType65B:
|
||||
return "65B"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
FileTypeF32 FileType = iota
|
||||
FileTypeF16
|
||||
FileTypeQ4_0
|
||||
FileTypeQ4_1
|
||||
FileTypeQ4_1_F16
|
||||
FileTypeQ8_0 = iota + 2
|
||||
FileTypeQ5_0
|
||||
FileTypeQ5_1
|
||||
FileTypeQ2_K
|
||||
FileTypeQ3_K
|
||||
FileTypeQ4_K
|
||||
FileTypeQ5_K
|
||||
FileTypeQ6_K
|
||||
)
|
||||
type FileType interface {
|
||||
String() string
|
||||
}
|
||||
|
||||
type GGML struct {
|
||||
ModelFamily
|
||||
ModelType
|
||||
|
||||
magic uint32
|
||||
container
|
||||
model
|
||||
}
|
||||
|
||||
llamaHyperparameters
|
||||
type model interface {
|
||||
ModelFamily() ModelFamily
|
||||
ModelType() ModelType
|
||||
FileType() FileType
|
||||
}
|
||||
|
||||
type container interface {
|
||||
|
@ -166,14 +169,14 @@ func DecodeGGML(r io.ReadSeeker, hint ModelFamily) (*GGML, error) {
|
|||
// different model types may have different layouts for hyperparameters
|
||||
switch hint {
|
||||
case ModelFamilyLlama:
|
||||
binary.Read(r, binary.LittleEndian, &ggml.llamaHyperparameters)
|
||||
var llama llamaModel
|
||||
binary.Read(r, binary.LittleEndian, &llama.hyperparameters)
|
||||
ggml.model = &llama
|
||||
// TODO: sanity check hyperparameters
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported model type: %s", hint)
|
||||
}
|
||||
|
||||
// final model type
|
||||
ggml.ModelFamily = hint
|
||||
ggml.ModelType = ModelType(ggml.NumLayer)
|
||||
return &ggml, nil
|
||||
}
|
||||
|
|
104
llm/llama.go
104
llm/llama.go
|
@ -106,19 +106,22 @@ import (
|
|||
//go:embed ggml-metal.metal
|
||||
var fs embed.FS
|
||||
|
||||
type llama struct {
|
||||
params *C.struct_llama_context_params
|
||||
model *C.struct_llama_model
|
||||
ctx *C.struct_llama_context
|
||||
const ModelFamilyLlama ModelFamily = "llama"
|
||||
|
||||
last []C.llama_token
|
||||
embd []C.llama_token
|
||||
cursor int
|
||||
type llamaModel struct {
|
||||
hyperparameters llamaHyperparameters
|
||||
}
|
||||
|
||||
mu sync.Mutex
|
||||
gc bool
|
||||
func (llm *llamaModel) ModelFamily() ModelFamily {
|
||||
return ModelFamilyLlama
|
||||
}
|
||||
|
||||
api.Options
|
||||
func (llm *llamaModel) ModelType() ModelType {
|
||||
return ModelType30B
|
||||
}
|
||||
|
||||
func (llm *llamaModel) FileType() FileType {
|
||||
return llm.hyperparameters.FileType
|
||||
}
|
||||
|
||||
type llamaHyperparameters struct {
|
||||
|
@ -133,8 +136,87 @@ type llamaHyperparameters struct {
|
|||
// NumLayer is the number of layers in the model.
|
||||
NumLayer uint32
|
||||
NumRot uint32
|
||||
|
||||
// FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
|
||||
FileType
|
||||
FileType llamaFileType
|
||||
}
|
||||
|
||||
type llamaFileType uint32
|
||||
|
||||
const (
|
||||
llamaFileTypeF32 llamaFileType = iota
|
||||
llamaFileTypeF16
|
||||
llamaFileTypeQ4_0
|
||||
llamaFileTypeQ4_1
|
||||
llamaFileTypeQ4_1_F16
|
||||
llamaFileTypeQ8_0 llamaFileType = iota + 2
|
||||
llamaFileTypeQ5_0
|
||||
llamaFileTypeQ5_1
|
||||
llamaFileTypeQ2_K
|
||||
llamaFileTypeQ3_K_S
|
||||
llamaFileTypeQ3_K_M
|
||||
llamaFileTypeQ3_K_L
|
||||
llamaFileTypeQ4_K_S
|
||||
llamaFileTypeQ4_K_M
|
||||
llamaFileTypeQ5_K_S
|
||||
llamaFileTypeQ5_K_M
|
||||
llamaFileTypeQ6_K
|
||||
)
|
||||
|
||||
func (ft llamaFileType) String() string {
|
||||
switch ft {
|
||||
case llamaFileTypeF32:
|
||||
return "F32"
|
||||
case llamaFileTypeF16:
|
||||
return "F16"
|
||||
case llamaFileTypeQ4_0:
|
||||
return "Q4_0"
|
||||
case llamaFileTypeQ4_1:
|
||||
return "Q4_1"
|
||||
case llamaFileTypeQ4_1_F16:
|
||||
return "Q4_1_F16"
|
||||
case llamaFileTypeQ8_0:
|
||||
return "Q8_0"
|
||||
case llamaFileTypeQ5_0:
|
||||
return "Q5_0"
|
||||
case llamaFileTypeQ5_1:
|
||||
return "Q5_1"
|
||||
case llamaFileTypeQ2_K:
|
||||
return "Q2_K"
|
||||
case llamaFileTypeQ3_K_S:
|
||||
return "Q3_K_S"
|
||||
case llamaFileTypeQ3_K_M:
|
||||
return "Q3_K_M"
|
||||
case llamaFileTypeQ3_K_L:
|
||||
return "Q3_K_L"
|
||||
case llamaFileTypeQ4_K_S:
|
||||
return "Q4_K_S"
|
||||
case llamaFileTypeQ4_K_M:
|
||||
return "Q4_K_M"
|
||||
case llamaFileTypeQ5_K_S:
|
||||
return "Q5_K_S"
|
||||
case llamaFileTypeQ5_K_M:
|
||||
return "Q5_K_M"
|
||||
case llamaFileTypeQ6_K:
|
||||
return "Q6_K"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
type llama struct {
|
||||
params *C.struct_llama_context_params
|
||||
model *C.struct_llama_model
|
||||
ctx *C.struct_llama_context
|
||||
|
||||
last []C.llama_token
|
||||
embd []C.llama_token
|
||||
cursor int
|
||||
|
||||
mu sync.Mutex
|
||||
gc bool
|
||||
|
||||
api.Options
|
||||
}
|
||||
|
||||
func newLlama(model string, adapters []string, opts api.Options) (*llama, error) {
|
||||
|
|
12
llm/llm.go
12
llm/llm.go
|
@ -35,10 +35,10 @@ func New(model string, adapters []string, opts api.Options) (LLM, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
switch ggml.FileType {
|
||||
case FileTypeF32, FileTypeF16, FileTypeQ5_0, FileTypeQ5_1, FileTypeQ8_0:
|
||||
switch ggml.FileType().String() {
|
||||
case "F32", "F16", "Q5_0", "Q5_1", "Q8_0":
|
||||
if opts.NumGPU != 0 {
|
||||
// Q5_0, Q5_1, and Q8_0 do not support Metal API and will
|
||||
// F32, F16, Q5_0, Q5_1, and Q8_0 do not support Metal API and will
|
||||
// cause the runner to segmentation fault so disable GPU
|
||||
log.Printf("WARNING: GPU disabled for F32, F16, Q5_0, Q5_1, and Q8_0")
|
||||
opts.NumGPU = 0
|
||||
|
@ -46,7 +46,7 @@ func New(model string, adapters []string, opts api.Options) (LLM, error) {
|
|||
}
|
||||
|
||||
totalResidentMemory := memory.TotalMemory()
|
||||
switch ggml.ModelType {
|
||||
switch ggml.ModelType() {
|
||||
case ModelType3B, ModelType7B:
|
||||
if totalResidentMemory < 8*1024*1024 {
|
||||
return nil, fmt.Errorf("model requires at least 8GB of memory")
|
||||
|
@ -65,10 +65,10 @@ func New(model string, adapters []string, opts api.Options) (LLM, error) {
|
|||
}
|
||||
}
|
||||
|
||||
switch ggml.ModelFamily {
|
||||
switch ggml.ModelFamily() {
|
||||
case ModelFamilyLlama:
|
||||
return newLlama(model, adapters, opts)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily)
|
||||
return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -105,8 +105,8 @@ type LayerReader struct {
|
|||
|
||||
type ConfigV2 struct {
|
||||
ModelFamily llm.ModelFamily `json:"model_family"`
|
||||
ModelType llm.ModelType `json:"model_type"`
|
||||
FileType llm.FileType `json:"file_type"`
|
||||
ModelType string `json:"model_type"`
|
||||
FileType string `json:"file_type"`
|
||||
RootFS RootFS `json:"rootfs"`
|
||||
|
||||
// required by spec
|
||||
|
@ -308,9 +308,9 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
|
|||
return err
|
||||
}
|
||||
|
||||
config.ModelFamily = ggml.ModelFamily
|
||||
config.ModelType = ggml.ModelType
|
||||
config.FileType = ggml.FileType
|
||||
config.ModelFamily = ggml.ModelFamily()
|
||||
config.ModelType = ggml.ModelType().String()
|
||||
config.FileType = ggml.FileType().String()
|
||||
|
||||
// reset the file
|
||||
file.Seek(0, io.SeekStart)
|
||||
|
|
Loading…
Reference in a new issue