From cd22855ef868609d74c64516f9b9cf92f1c662c9 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 24 Jan 2024 10:48:31 -0800 Subject: [PATCH] refactor tensor read --- llm/gguf.go | 115 ++++++++++++++++++++++++++++------------------------ 1 file changed, 61 insertions(+), 54 deletions(-) diff --git a/llm/gguf.go b/llm/gguf.go index cfcab758..436be42c 100644 --- a/llm/gguf.go +++ b/llm/gguf.go @@ -69,12 +69,65 @@ type tensor struct { name string kind uint32 offset uint64 - size uint64 // shape is the number of elements in each dimension shape [4]uint64 } +func (t tensor) blockSize() uint64 { + switch { + case t.kind < 2: + return 1 + case t.kind < 10: + return 32 + default: + return 256 + } +} + +func (t tensor) typeSize() uint64 { + blockSize := t.blockSize() + + switch t.kind { + case 0: // FP32 + return 4 + case 1: // FP16 + return 2 + case 2: // Q4_0 + return 2 + blockSize/2 + case 3: // Q4_1 + return 2 + 2 + blockSize/2 + case 6: // Q5_0 + return 2 + 4 + blockSize/2 + case 7: // Q5_1 + return 2 + 2 + 4 + blockSize/2 + case 8: // Q8_0 + return 2 + blockSize + case 9: // Q8_1 + return 4 + 4 + blockSize + case 10: // Q2_K + return blockSize/16 + blockSize/4 + 2 + 2 + case 11: // Q3_K + return blockSize/8 + blockSize/4 + 12 + 2 + case 12: // Q4_K + return 2 + 2 + 12 + blockSize/2 + case 13: // Q5_K + return 2 + 2 + 12 + blockSize/8 + blockSize/2 + case 14: // Q6_K + return blockSize/2 + blockSize/4 + blockSize/16 + 2 + default: + return 0 + } +} + +func (t tensor) parameters() uint64 { + return t.shape[0] * t.shape[1] * t.shape[2] * t.shape[3] +} + +func (t tensor) size() uint64 { + return t.parameters() * t.typeSize() / t.blockSize() +} + type ggufModel struct { *containerGGUF @@ -201,61 +254,15 @@ func (llm *ggufModel) Decode(rso *readSeekOffset) error { shape[i] = llm.readU64(rso) } - kind := llm.readU32(rso) - offset := llm.readU64(rso) - - var blockSize uint64 - switch { - case kind < 2: - blockSize = 1 - case kind < 10: - blockSize = 32 - default: - blockSize = 256 - } - - var typeSize uint64 - switch kind { - case 0: // FP32 - typeSize = 4 - case 1: // FP16 - typeSize = 2 - case 2: // Q4_0 - typeSize = 2 + blockSize/2 - case 3: // Q4_1 - typeSize = 2 + 2 + blockSize/2 - case 6: // Q5_0 - typeSize = 2 + 4 + blockSize/2 - case 7: // Q5_1 - typeSize = 2 + 2 + 4 + blockSize/2 - case 8: // Q8_0 - typeSize = 2 + blockSize - case 9: // Q8_1 - typeSize = 4 + 4 + blockSize - case 10: // Q2_K - typeSize = blockSize/16 + blockSize/4 + 2 + 2 - case 11: // Q3_K - typeSize = blockSize/8 + blockSize/4 + 12 + 2 - case 12: // Q4_K - typeSize = 2 + 2 + 12 + blockSize/2 - case 13: // Q5_K - typeSize = 2 + 2 + 12 + blockSize/8 + blockSize/2 - case 14: // Q6_K - typeSize = blockSize/2 + blockSize/4 + blockSize/16 + 2 - } - - parameters := shape[0] * shape[1] * shape[2] * shape[3] - size := parameters * typeSize / blockSize - - llm.tensors = append(llm.tensors, tensor{ + tensor := tensor{ name: name, - kind: kind, - offset: offset, - size: size, + kind: llm.readU32(rso), + offset: llm.readU64(rso), shape: shape, - }) + } - llm.parameters += parameters + llm.tensors = append(llm.tensors, tensor) + llm.parameters += tensor.parameters() } alignment, ok := llm.kv["general.alignment"].(uint32) @@ -265,7 +272,7 @@ func (llm *ggufModel) Decode(rso *readSeekOffset) error { rso.Seek(int64(alignment)-rso.offset%int64(alignment), io.SeekCurrent) for _, tensor := range llm.tensors { - padded := (int64(tensor.size) + int64(alignment) - 1) & ^(int64(alignment) - 1) + padded := (int64(tensor.size()) + int64(alignment) - 1) & ^(int64(alignment) - 1) rso.Seek(padded, io.SeekCurrent) }