From 8c238e70abe715ebe099657d110ee3a00876cc53 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 31 Oct 2024 13:40:06 -0700 Subject: [PATCH] mllama cross attention --- llm/ggml.go | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/llm/ggml.go b/llm/ggml.go index e857d4b8..9cf9172e 100644 --- a/llm/ggml.go +++ b/llm/ggml.go @@ -400,6 +400,30 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui 4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16), ) } + case "mllama": + var visionTokens, tiles uint64 = 1601, 4 + + fullOffload = max( + 4*batch*(2+3*embedding+embeddingHeadsK*heads+context*(1+heads)), + // vocab graph + 4*batch*(embedding+vocab), + ) + + var ropeFreqsCount uint64 + if ropeFreqs, ok := llm.Tensors().Layers()["rope_freqs"]; ok { + if ropeFreqsWeights, ok := ropeFreqs["weights"]; ok { + ropeFreqsCount = ropeFreqsWeights.parameters() + } + } + + partialOffload = max( + 4*(batch* + (2*embedding+1+context*(1+heads)+embeddingHeadsK*heads)+ + ropeFreqsCount+ + embeddingHeadsK*context*headsKV), + // vocab graph + 4*batch*(embedding+vocab)+embedding*vocab*105/128, + ) case "gemma", "gemma2": fullOffload = max( 4*batch*(embedding+vocab),