diff --git a/llm/ggml.go b/llm/ggml.go
index ff9380f1..4c1c673e 100644
--- a/llm/ggml.go
+++ b/llm/ggml.go
@@ -330,6 +330,8 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 	headsKV := llm.KV().HeadCountKV()
 	vocab := uint64(len(llm.KV()["tokenizer.ggml.tokens"].([]any)))
 
+	layers := llm.Tensors().Layers()
+
 	switch llm.KV().Architecture() {
 	case "llama":
 		fullOffload = 4 * batch * (1 + 4*embedding + context*(1+heads))
@@ -339,6 +341,15 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 			4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embedding/heads*headsKV),
 			4*batch*(embedding+vocab)+embedding*vocab*105/128,
 		)
+
+		if ffnGateWeight, ok := layers["0"]["ffn_gate.0.weight"]; ok {
+			ffnGateWeight1 := ffnGateWeight.Shape[1]
+			fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
+			partialOffload = max(
+				4*batch*(3+embedding/heads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
+				4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
+			)
+		}
 	case "gemma":
 		fullOffload = 4 * batch * (embedding + vocab)
 		partialOffload = 4*batch*(2*embedding+vocab+1) + embedding*vocab*105/128
@@ -350,7 +361,7 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 
 		partialOffload = max(
 			4*batch*(embedding+vocab)+embedding*vocab*105/128,
-			4*batch*(1+2*embedding+context*(1+heads))+ 4*embedding*context+embedding*embedding*9/16,
+			4*batch*(1+2*embedding+context*(1+heads))+4*embedding*context+embedding*embedding*9/16,
 		)
 	case "qwen2":
 		fullOffload = max(
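
For context: the new branch in the `llama` case appears to detect mixture-of-experts (MoE) models by the presence of a per-expert gate tensor, `ffn_gate.0.weight`, in layer `"0"`, and uses that tensor's second shape dimension in the revised graph-memory estimates. Below is a minimal, self-contained sketch of that detection logic; `Tensor`, the `layers` map, and `moeGraphSize` are hypothetical stand-ins for the real structures in `llm/ggml.go`, so treat it as an illustration of the formulas above, not the actual implementation:

```go
package main

import "fmt"

// Tensor is a simplified stand-in for the tensor metadata parsed out of a
// GGUF file; only the shape matters for the graph-size estimate.
type Tensor struct {
	Shape []uint64
}

// moeGraphSize mirrors the structure of the patched llama branch: if layer
// "0" carries an ffn_gate.0.weight tensor (an expert gate, present only in
// MoE models), return the MoE graph-size estimates; otherwise report that
// the caller should keep the dense-llama estimates. The 4* multipliers
// presumably correspond to 4-byte f32 activations.
func moeGraphSize(layers map[string]map[string]Tensor, batch, context, embedding, heads, headsKV uint64) (partialOffload, fullOffload uint64, isMoE bool) {
	ffnGateWeight, ok := layers["0"]["ffn_gate.0.weight"]
	if !ok {
		return 0, 0, false // dense model: keep the original llama estimates
	}
	// Second shape dimension of the gate tensor, as in the diff above.
	ffnGateWeight1 := ffnGateWeight.Shape[1]
	fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
	partialOffload = max( // builtin max, Go 1.21+
		4*batch*(3+embedding/heads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
		4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
	)
	return partialOffload, fullOffload, true
}

func main() {
	// Hypothetical Mixtral-like dimensions, for illustration only.
	layers := map[string]map[string]Tensor{
		"0": {"ffn_gate.0.weight": {Shape: []uint64{4096, 14336}}},
	}
	p, f, ok := moeGraphSize(layers, 512, 4096, 4096, 32, 8)
	fmt.Printf("moe=%v partialOffload=%d fullOffload=%d\n", ok, p, f)
}
```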