Merge pull request #3600 from ollama/mxyng/mixtral
This commit is contained in:
commit
786f3a1c44
1 changed files with 12 additions and 1 deletions
11
llm/ggml.go
11
llm/ggml.go
|
@ -330,6 +330,8 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
|
|||
headsKV := llm.KV().HeadCountKV()
|
||||
vocab := uint64(len(llm.KV()["tokenizer.ggml.tokens"].([]any)))
|
||||
|
||||
layers := llm.Tensors().Layers()
|
||||
|
||||
switch llm.KV().Architecture() {
|
||||
case "llama":
|
||||
fullOffload = 4 * batch * (1 + 4*embedding + context*(1+heads))
|
||||
|
@ -339,6 +341,15 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
|
|||
4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embedding/heads*headsKV),
|
||||
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
||||
)
|
||||
|
||||
if ffnGateWeight, ok := layers["0"]["ffn_gate.0.weight"]; ok {
|
||||
ffnGateWeight1 := ffnGateWeight.Shape[1]
|
||||
fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
|
||||
partialOffload = max(
|
||||
4*batch*(3+embedding/heads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
|
||||
4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
|
||||
)
|
||||
}
|
||||
case "gemma":
|
||||
fullOffload = 4 * batch * (embedding + vocab)
|
||||
partialOffload = 4*batch*(2*embedding+vocab+1) + embedding*vocab*105/128
|
||||
|
|
Loading…
Reference in a new issue