From 0c5a454361c57f300254971b109e5e4ec937ebd3 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 12 Sep 2023 10:52:57 -0700 Subject: [PATCH] fix model type for 70b --- llm/gguf.go | 6 ++++++ server/images.go | 12 +++++++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/llm/gguf.go b/llm/gguf.go index 047d17cf..7680c90c 100644 --- a/llm/gguf.go +++ b/llm/gguf.go @@ -99,6 +99,12 @@ func (llm *ggufModel) ModelType() string { switch llm.ModelFamily() { case "llama": if blocks, ok := llm.kv["llama.block_count"].(uint32); ok { + heads, headsOK := llm.kv["llama.head_count"].(uint32) + headKVs, headsKVsOK := llm.kv["llama.head_count_kv"].(uint32) + if headsOK && headsKVsOK && heads/headKVs == 8 { + return "70B" + } + return llamaModelType(blocks) } case "falcon": diff --git a/server/images.go b/server/images.go index 01ec4306..faf0205d 100644 --- a/server/images.go +++ b/server/images.go @@ -498,6 +498,12 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api } } + if config.ModelType == "65B" { + if numGQA, ok := formattedParams["num_gqa"].(int); ok && numGQA == 8 { + config.ModelType = "70B" + } + } + bts, err := json.Marshal(formattedParams) if err != nil { return err @@ -815,14 +821,14 @@ func formatParams(params map[string][]string) (map[string]interface{}, error) { return nil, fmt.Errorf("invalid float value %s", vals) } - out[key] = floatVal + out[key] = float32(floatVal) case reflect.Int: - intVal, err := strconv.ParseInt(vals[0], 10, 0) + intVal, err := strconv.ParseInt(vals[0], 10, 64) if err != nil { return nil, fmt.Errorf("invalid int value %s", vals) } - out[key] = intVal + out[key] = int(intVal) case reflect.Bool: boolVal, err := strconv.ParseBool(vals[0]) if err != nil {