diff --git a/llm/ggml.go b/llm/ggml.go
index 4b73f510..d877acd1 100644
--- a/llm/ggml.go
+++ b/llm/ggml.go
@@ -148,15 +148,15 @@ func (kv KV) HeadCount() uint64 {
 }
 
 func (kv KV) HeadCountKV() uint64 {
-	return kv.u64(fmt.Sprintf("%s.attention.head_count_kv", kv.Architecture()))
+	if headCountKV := kv.u64(fmt.Sprintf("%s.attention.head_count_kv", kv.Architecture())); headCountKV > 0 {
+		return headCountKV // u64 produced a non-zero value; return it unchanged
+	}
+
+	return 1 // never return 0, so GQA can divide by this result unguarded. NOTE(review): GGUF docs appear to treat a missing head_count_kv as "no GQA" (equal to head_count) — confirm 1 is the intended fallback
 }
 
 func (kv KV) GQA() uint64 {
-	if headCountKV := kv.HeadCountKV(); headCountKV > 0 {
-		return kv.HeadCount() / headCountKV
-	}
-
-	return 0
+	return kv.HeadCount() / kv.HeadCountKV() // HeadCountKV is at least 1 with this change, so the old zero guard (which returned 0) is gone
 }
 
 func (kv KV) EmbeddingLength() uint64 {