fix model type for 70b
This commit is contained in:
parent
7dee25a07f
commit
0c5a454361
2 changed files with 15 additions and 3 deletions
|
@ -99,6 +99,12 @@ func (llm *ggufModel) ModelType() string {
|
||||||
switch llm.ModelFamily() {
|
switch llm.ModelFamily() {
|
||||||
case "llama":
|
case "llama":
|
||||||
if blocks, ok := llm.kv["llama.block_count"].(uint32); ok {
|
if blocks, ok := llm.kv["llama.block_count"].(uint32); ok {
|
||||||
|
heads, headsOK := llm.kv["llama.head_count"].(uint32)
|
||||||
|
headKVs, headsKVsOK := llm.kv["llama.head_count_kv"].(uint32)
|
||||||
|
if headsOK && headsKVsOK && heads/headKVs == 8 {
|
||||||
|
return "70B"
|
||||||
|
}
|
||||||
|
|
||||||
return llamaModelType(blocks)
|
return llamaModelType(blocks)
|
||||||
}
|
}
|
||||||
case "falcon":
|
case "falcon":
|
||||||
|
|
|
@ -498,6 +498,12 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if config.ModelType == "65B" {
|
||||||
|
if numGQA, ok := formattedParams["num_gqa"].(int); ok && numGQA == 8 {
|
||||||
|
config.ModelType = "70B"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bts, err := json.Marshal(formattedParams)
|
bts, err := json.Marshal(formattedParams)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -815,14 +821,14 @@ func formatParams(params map[string][]string) (map[string]interface{}, error) {
|
||||||
return nil, fmt.Errorf("invalid float value %s", vals)
|
return nil, fmt.Errorf("invalid float value %s", vals)
|
||||||
}
|
}
|
||||||
|
|
||||||
out[key] = floatVal
|
out[key] = float32(floatVal)
|
||||||
case reflect.Int:
|
case reflect.Int:
|
||||||
intVal, err := strconv.ParseInt(vals[0], 10, 0)
|
intVal, err := strconv.ParseInt(vals[0], 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid int value %s", vals)
|
return nil, fmt.Errorf("invalid int value %s", vals)
|
||||||
}
|
}
|
||||||
|
|
||||||
out[key] = intVal
|
out[key] = int(intVal)
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
boolVal, err := strconv.ParseBool(vals[0])
|
boolVal, err := strconv.ParseBool(vals[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
Loading…
Reference in a new issue