mllama cross attention
parent 16f4eabe2d
commit 8c238e70ab

1 changed file with 24 additions and 0 deletions

llm/ggml.go  +24 −0
@@ -400,6 +400,30 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload uint64) {
 				4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
 			)
 		}
+	case "mllama":
+		var visionTokens, tiles uint64 = 1601, 4
+
+		fullOffload = max(
+			4*batch*(2+3*embedding+embeddingHeadsK*heads+context*(1+heads)),
+			// vocab graph
+			4*batch*(embedding+vocab),
+		)
+
+		var ropeFreqsCount uint64
+		if ropeFreqs, ok := llm.Tensors().Layers()["rope_freqs"]; ok {
+			if ropeFreqsWeights, ok := ropeFreqs["weights"]; ok {
+				ropeFreqsCount = ropeFreqsWeights.parameters()
+			}
+		}
+
+		partialOffload = max(
+			4*(batch*
+				(2*embedding+1+context*(1+heads)+embeddingHeadsK*heads)+
+				ropeFreqsCount+
+				embeddingHeadsK*context*headsKV),
+			// vocab graph
+			4*batch*(embedding+vocab)+embedding*vocab*105/128,
+		)
 	case "gemma", "gemma2":
 		fullOffload = max(
 			4*batch*(embedding+vocab),
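For context (not part of the commit): the new case estimates compute-graph scratch memory for the mllama architecture, i.e. Llama 3.2 Vision, where visionTokens and tiles plausibly reflect the vision encoder's 1601 tokens per image tile and a maximum of 4 tiles. The standalone Go sketch below evaluates the two new formulas with illustrative hyperparameters; every value is an assumption chosen for demonstration, and ropeFreqsCount would in practice be read from the model's rope_freqs tensor as in the diff.

// Worked example (assumptions only, not from the commit).
package main

import "fmt"

func main() {
	var (
		batch           uint64 = 512    // assumed n_batch
		context         uint64 = 2048   // assumed n_ctx
		embedding       uint64 = 4096   // assumed hidden size
		heads           uint64 = 32     // assumed attention heads
		headsKV         uint64 = 8      // assumed KV heads (GQA)
		embeddingHeadsK uint64 = 128    // assumed per-head key dimension
		vocab           uint64 = 128256 // assumed vocabulary size
		ropeFreqsCount  uint64 = 64     // assumed rope_freqs parameter count
	)

	// Same expression as the committed fullOffload estimate.
	fullOffload := max(
		4*batch*(2+3*embedding+embeddingHeadsK*heads+context*(1+heads)),
		// vocab graph
		4*batch*(embedding+vocab),
	)

	// Same expression as the committed partialOffload estimate.
	partialOffload := max(
		4*(batch*
			(2*embedding+1+context*(1+heads)+embeddingHeadsK*heads)+
			ropeFreqsCount+
			embeddingHeadsK*context*headsKV),
		// vocab graph
		4*batch*(embedding+vocab)+embedding*vocab*105/128,
	)

	fmt.Printf("fullOffload    = %d bytes (~%.0f MiB)\n", fullOffload, float64(fullOffload)/(1<<20))
	fmt.Printf("partialOffload = %d bytes (~%.0f MiB)\n", partialOffload, float64(partialOffload)/(1<<20))
}

With Go 1.21 or newer (for the builtin max), this prints roughly 258 MiB for fullOffload and roughly 670 MiB for partialOffload under these assumptions; in both cases the vocab-graph term dominates.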