refactor kv estimation

parent 8c238e70ab
commit d07cf41a97

2 changed files with 20 additions and 8 deletions

llm/ggml.go (17 changes)
@@ -360,7 +360,7 @@ func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
 	}, offset, nil
 }
 
-func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload uint64) {
+func (llm GGML) GraphSize(context, batch uint64) (kv, partialOffload, fullOffload uint64) {
 	embedding := llm.KV().EmbeddingLength()
 	heads := llm.KV().HeadCount()
 	headsKV := llm.KV().HeadCountKV()
@@ -368,9 +368,12 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload uint64) {
 
 	embeddingHeads := llm.KV().EmbeddingHeadCount()
 	embeddingHeadsK := llm.KV().EmbeddingHeadCountK()
+	embeddingHeadsV := llm.KV().EmbeddingHeadCountV()
 
 	layers := llm.Tensors().Layers()
 
+	kv = 2 * context * llm.KV().BlockCount() * (embeddingHeadsK + embeddingHeadsV) * headsKV
+
 	switch llm.KV().Architecture() {
 	case "llama":
 		fullOffload = max(
@@ -403,6 +406,18 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload uint64) {
 	case "mllama":
 		var visionTokens, tiles uint64 = 1601, 4
 
+		if crossAttentionLayers, ok := llm.KV()["mllama.attention.cross_attention_layers"].(*array); ok {
+			kv = headsKV *
+				(embeddingHeadsK + embeddingHeadsV) * // one for K, one for V
+				(2* // sizeof(float16)
+					(llm.KV().BlockCount()-uint64(crossAttentionLayers.size))* // num non-cross attention layers
+					context +
+					4* // sizeof(float32)
+						uint64(crossAttentionLayers.size)* // num cross attention layers
+						visionTokens*
+						tiles)
+		}
+
 		fullOffload = max(
 			4*batch*(2+3*embedding+embeddingHeadsK*heads+context*(1+heads)),
 			// vocab graph
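As a quick sanity check of the fp16 KV-cache formula that GraphSize now returns, here is a small self-contained sketch. The model dimensions are made up for illustration and do not come from this commit.

package main

import "fmt"

func main() {
	// Hypothetical model shape (not taken from the commit): 32 blocks,
	// 8 KV heads, 128-dim K and V heads, and an 8192-token context.
	var (
		context         uint64 = 8192
		blockCount      uint64 = 32
		headsKV         uint64 = 8
		embeddingHeadsK uint64 = 128
		embeddingHeadsV uint64 = 128
	)

	// Same arithmetic as the new return value in GraphSize:
	// 2 bytes per fp16 element * n_ctx * n_layer * (n_embd_head_k + n_embd_head_v) * n_head_kv.
	kv := 2 * context * blockCount * (embeddingHeadsK + embeddingHeadsV) * headsKV

	fmt.Printf("estimated KV cache: %d bytes (%.2f GiB)\n", kv, float64(kv)/(1<<30))
	// prints: estimated KV cache: 1073741824 bytes (1.00 GiB)
}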
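The mllama branch above mixes fp16 self-attention KV for the regular layers with fp32 KV over the vision tokens for the cross-attention layers. Below is a worked example of that expression, again with hypothetical numbers: only visionTokens and tiles match the diff, the layer and head counts are invented.

package main

import "fmt"

func main() {
	var (
		context              uint64 = 8192 // hypothetical
		blockCount           uint64 = 40   // hypothetical
		crossAttentionLayers uint64 = 8    // hypothetical
		headsKV              uint64 = 8    // hypothetical
		embeddingHeadsK      uint64 = 128  // hypothetical
		embeddingHeadsV      uint64 = 128  // hypothetical
		visionTokens         uint64 = 1601 // from the diff
		tiles                uint64 = 4    // from the diff
	)

	// Same shape as the cross-attention branch: fp16 (2 bytes) KV over the
	// text context for the non-cross layers, plus fp32 (4 bytes) KV over
	// visionTokens*tiles positions for the cross-attention layers.
	kv := headsKV *
		(embeddingHeadsK + embeddingHeadsV) *
		(2*(blockCount-crossAttentionLayers)*context +
			4*crossAttentionLayers*visionTokens*tiles)

	fmt.Printf("estimated mllama KV cache: %d bytes (%.2f GiB)\n", kv, float64(kv)/(1<<30))
	// prints: estimated mllama KV cache: 1493434368 bytes (1.39 GiB)
}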
llm/memory.go

@@ -123,13 +123,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string,
 		slog.Warn("model missing blk.0 layer size")
 	}
 
-	// fp16 k,v = sizeof(float16) * n_ctx * n_layer * (n_embd_head_k + n_embd_head_v) * n_head_kv
-	var kv uint64 = 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * (ggml.KV().EmbeddingHeadCountK() + ggml.KV().EmbeddingHeadCountV()) * ggml.KV().HeadCountKV()
-
-	// KV is proportional to the number of layers
-	layerSize += kv / ggml.KV().BlockCount()
-
-	graphPartialOffload, graphFullOffload = ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
+	kv, graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
 	if graphPartialOffload == 0 {
 		graphPartialOffload = ggml.KV().GQA() * kv / 6
 	}
@@ -137,6 +131,9 @@ func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string,
 		graphFullOffload = graphPartialOffload
 	}
 
+	// KV is proportional to the number of layers
+	layerSize += kv / ggml.KV().BlockCount()
+
 	// on metal there's no partial offload overhead
 	if gpus[0].Library == "metal" {
 		graphPartialOffload = graphFullOffload
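For context on how the returned kv is consumed here: it is spread evenly across the repeating layers and also drives the fallback graph estimate when no architecture-specific value was computed. A rough sketch continuing the hypothetical 1 GiB / 32-block numbers from the earlier example (none of these values come from the commit):

package main

import "fmt"

func main() {
	// Continuing the hypothetical numbers: kv = 1 GiB, 32 blocks,
	// 32 attention heads, 8 KV heads.
	var (
		kv         uint64 = 1 << 30
		blockCount uint64 = 32
		gqa        uint64 = 32 / 8 // GQA ratio: attention heads per KV head
	)

	// Each repeating layer carries its share of the KV cache.
	perLayerKV := kv / blockCount

	// Fallback partial-offload graph estimate when GraphSize returned 0 for it.
	graphPartialOffload := gqa * kv / 6

	fmt.Printf("per-layer KV: %d MiB\n", perLayerKV/(1<<20))                      // 32 MiB
	fmt.Printf("fallback partial offload: %d MiB\n", graphPartialOffload/(1<<20)) // 682 MiB
}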