Merge pull request #3600 from ollama/mxyng/mixtral
This commit is contained in:
commit
786f3a1c44
1 changed files with 12 additions and 1 deletions
11
llm/ggml.go
11
llm/ggml.go
|
@ -330,6 +330,8 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
|
||||||
headsKV := llm.KV().HeadCountKV()
|
headsKV := llm.KV().HeadCountKV()
|
||||||
vocab := uint64(len(llm.KV()["tokenizer.ggml.tokens"].([]any)))
|
vocab := uint64(len(llm.KV()["tokenizer.ggml.tokens"].([]any)))
|
||||||
|
|
||||||
|
layers := llm.Tensors().Layers()
|
||||||
|
|
||||||
switch llm.KV().Architecture() {
|
switch llm.KV().Architecture() {
|
||||||
case "llama":
|
case "llama":
|
||||||
fullOffload = 4 * batch * (1 + 4*embedding + context*(1+heads))
|
fullOffload = 4 * batch * (1 + 4*embedding + context*(1+heads))
|
||||||
|
@ -339,6 +341,15 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
|
||||||
4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embedding/heads*headsKV),
|
4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embedding/heads*headsKV),
|
||||||
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if ffnGateWeight, ok := layers["0"]["ffn_gate.0.weight"]; ok {
|
||||||
|
ffnGateWeight1 := ffnGateWeight.Shape[1]
|
||||||
|
fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
|
||||||
|
partialOffload = max(
|
||||||
|
4*batch*(3+embedding/heads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
|
||||||
|
4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
|
||||||
|
)
|
||||||
|
}
|
||||||
case "gemma":
|
case "gemma":
|
||||||
fullOffload = 4 * batch * (embedding + vocab)
|
fullOffload = 4 * batch * (embedding + vocab)
|
||||||
partialOffload = 4*batch*(2*embedding+vocab+1) + embedding*vocab*105/128
|
partialOffload = 4*batch*(2*embedding+vocab+1) + embedding*vocab*105/128
|
||||||
|
|
Loading…
Reference in a new issue