mixtral mem

This commit is contained in:
Michael Yang 2024-04-11 10:26:35 -07:00
parent 5a25f93522
commit 3397eff0cd

View file

@@ -330,6 +330,8 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
headsKV := llm.KV().HeadCountKV()
vocab := uint64(len(llm.KV()["tokenizer.ggml.tokens"].([]any)))
layers := llm.Tensors().Layers()
switch llm.KV().Architecture() {
case "llama":
fullOffload = 4 * batch * (1 + 4*embedding + context*(1+heads))
@@ -339,6 +341,15 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embedding/heads*headsKV),
4*batch*(embedding+vocab)+embedding*vocab*105/128,
)
if ffnGateWeight, ok := layers["0"]["ffn_gate.0.weight"]; ok {
ffnGateWeight1 := ffnGateWeight.Shape[1]
fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
partialOffload = max(
4*batch*(3+embedding/heads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
)
}
case "gemma":
fullOffload = 4 * batch * (embedding + vocab)
partialOffload = 4*batch*(2*embedding+vocab+1) + embedding*vocab*105/128
@@ -350,7 +361,7 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
partialOffload = max(
4*batch*(embedding+vocab)+embedding*vocab*105/128,
4*batch*(1+2*embedding+context*(1+heads))+4*embedding*context+embedding*embedding*9/16,
)
case "qwen2":
fullOffload = max(