From 435cc866a3fbabb5029b8a2496631847a871616f Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Mon, 22 Apr 2024 16:57:05 -0700 Subject: [PATCH] fix: mixtral graph --- llm/ggml.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/llm/ggml.go b/llm/ggml.go index f40f17e5..1b094027 100644 --- a/llm/ggml.go +++ b/llm/ggml.go @@ -343,7 +343,15 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui 4*batch*(embedding+vocab)+embedding*vocab*105/128, ) - if ffnGateWeight, ok := layers["0"]["ffn_gate.0.weight"]; ok { + if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok { + // mixtral 8x22b + ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32)) + partialOffload = max( + 3*ffnGateExpsWeight.size()+4*batch*(2*ff+headsKV+embedding+context+embedding/heads*headsKV), + 4*(context*batch*heads+context*embedding/heads*headsKV+batch*1024+embedding/heads*headsKV*batch), + ) + } else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok { + // mixtral 8x7b ffnGateWeight1 := ffnGateWeight.Shape[1] fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1) partialOffload = max(