fix model type for 70b

This commit is contained in:
Michael Yang 2023-09-12 10:52:57 -07:00
parent 7dee25a07f
commit 0c5a454361
2 changed files with 15 additions and 3 deletions

View file

@@ -99,6 +99,12 @@ func (llm *ggufModel) ModelType() string {
 	switch llm.ModelFamily() {
 	case "llama":
 		if blocks, ok := llm.kv["llama.block_count"].(uint32); ok {
+			heads, headsOK := llm.kv["llama.head_count"].(uint32)
+			headKVs, headsKVsOK := llm.kv["llama.head_count_kv"].(uint32)
+			if headsOK && headsKVsOK && heads/headKVs == 8 {
+				return "70B"
+			}
+
 			return llamaModelType(blocks)
 		}
 	case "falcon":

View file

@@ -498,6 +498,12 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 		}
 	}
 
+	if config.ModelType == "65B" {
+		if numGQA, ok := formattedParams["num_gqa"].(int); ok && numGQA == 8 {
+			config.ModelType = "70B"
+		}
+	}
+
 	bts, err := json.Marshal(formattedParams)
 	if err != nil {
 		return err
@@ -815,14 +821,14 @@ func formatParams(params map[string][]string) (map[string]interface{}, error) {
 			return nil, fmt.Errorf("invalid float value %s", vals)
 		}
 
-		out[key] = floatVal
+		out[key] = float32(floatVal)
 	case reflect.Int:
-		intVal, err := strconv.ParseInt(vals[0], 10, 0)
+		intVal, err := strconv.ParseInt(vals[0], 10, 64)
 		if err != nil {
 			return nil, fmt.Errorf("invalid int value %s", vals)
 		}
 
-		out[key] = intVal
+		out[key] = int(intVal)
 	case reflect.Bool:
 		boolVal, err := strconv.ParseBool(vals[0])
 		if err != nil {