llama3.1 memory

Michael Yang 2024-08-08 11:18:13 -07:00
parent 7d1c0047fa
commit 2003d60159

@@ -344,11 +344,13 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload uint64) {
     switch llm.KV().Architecture() {
     case "llama":
-        fullOffload = 4 * batch * (1 + 4*embedding + context*(1+heads))
+        fullOffload = max(
+            4*batch*(1+4*embedding+context*(1+heads)),
+            4*batch*(embedding+vocab),
+        )
 
         partialOffload = 4 * batch * embedding
         partialOffload += max(
-            // 4*batch*(4+6*embedding+context*(2*heads)+llm.KV().GQA()),
             4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embeddingHeads*headsKV),
             4*batch*(embedding+vocab)+embedding*vocab*105/128,
         )
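For context on the change: Llama 3.1 models ship a much larger vocabulary (128256 tokens), so the output-logits buffer, 4*batch*(embedding+vocab), can exceed the attention-scratch term that the old single-formula fullOffload estimate used; taking the max of the two covers both regimes. Below is a minimal standalone sketch of the new estimate, assuming illustrative Llama 3.1 8B-style values (embedding=4096, heads=32); the helper name fullOffloadEstimate is hypothetical and not part of the ollama codebase.

package main

import "fmt"

// fullOffloadEstimate mirrors the updated estimate above: the larger of
// the attention-scratch term and the output-logits term. The name and
// the parameter values used below are illustrative only.
func fullOffloadEstimate(context, batch, embedding, heads, vocab uint64) uint64 {
    return max(
        4*batch*(1+4*embedding+context*(1+heads)), // attention/work buffers, grows with context
        4*batch*(embedding+vocab),                 // output logits, grows with vocab
    )
}

func main() {
    const (
        embedding uint64 = 4096   // assumed Llama 3.1 8B hidden size
        heads     uint64 = 32     // assumed attention head count
        vocab     uint64 = 128256 // Llama 3.1 vocabulary size
    )
    fmt.Println(fullOffloadEstimate(8, 512, embedding, heads, vocab))    // small context: vocab term wins
    fmt.Println(fullOffloadEstimate(8192, 512, embedding, heads, vocab)) // long context: context term wins
}

With these assumed values, the vocab term dominates at small contexts (about 271 MB at context 8) while the context-dependent term takes over at longer ones (about 587 MB at context 8192), which is why a single formula undercounted one case or the other.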