ollama/llm/gguf.go

465 lines
8.7 KiB
Go
Raw Normal View History

2023-09-07 13:55:37 -04:00
package llm
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"github.com/jmorganca/ollama/format"
2023-09-07 13:55:37 -04:00
)
// containerGGUF holds the fixed-size GGUF header fields read by Decode: the
// format version and, depending on that version, the tensor and key/value
// counts. Version 1 stores the counts as uint32 (V1); every other version
// is read into the uint64 layout (V2).
type containerGGUF struct {
	// bo is the byte order used for every binary read from the container.
	bo binary.ByteOrder

	// Version is the GGUF format version read from the header.
	Version uint32

	// V1 holds the header counts when Version == 1. The field order must
	// match the on-disk layout because the struct is filled by one
	// binary.Read call.
	V1 struct {
		NumTensor uint32
		NumKV     uint32
	}

	// V2 holds the header counts for any other version; as with V1, the
	// field order mirrors the on-disk layout.
	V2 struct {
		NumTensor uint64
		NumKV     uint64
	}
}

// Name reports the container format identifier, which is always "gguf".
func (*containerGGUF) Name() string { return "gguf" }
2023-11-29 10:31:58 -08:00
func (c *containerGGUF) Decode(rso *readSeekOffset) (model, error) {
binary.Read(rso, c.bo, &c.Version)
2023-09-07 13:55:37 -04:00
switch c.Version {
case 1:
2023-11-29 10:31:58 -08:00
binary.Read(rso, c.bo, &c.V1)
2023-09-07 13:55:37 -04:00
default:
2023-11-29 10:31:58 -08:00
binary.Read(rso, c.bo, &c.V2)
2023-09-07 13:55:37 -04:00
}
model := newGGUFModel(c)
2023-11-29 10:31:58 -08:00
if err := model.Decode(rso); err != nil {
2023-09-07 13:55:37 -04:00
return nil, err
}
return model, nil
}
// GGUF metadata value type tags. Decode compares these against the uint32
// tag read before each metadata value, and the array readers compare them
// against the tag read before an array's elements. The numbers are written
// out explicitly so the mapping to on-disk tags is visible at a glance.
const (
	ggufTypeUint8   uint32 = 0
	ggufTypeInt8    uint32 = 1
	ggufTypeUint16  uint32 = 2
	ggufTypeInt16   uint32 = 3
	ggufTypeUint32  uint32 = 4
	ggufTypeInt32   uint32 = 5
	ggufTypeFloat32 uint32 = 6
	ggufTypeBool    uint32 = 7
	ggufTypeString  uint32 = 8
	ggufTypeArray   uint32 = 9
	ggufTypeUint64  uint32 = 10
	ggufTypeInt64   uint32 = 11
	ggufTypeFloat64 uint32 = 12
)
type (
	// kv maps GGUF metadata keys to their decoded values.
	kv map[string]any

	// tensor is one entry of the GGUF tensor-info table as decoded by
	// ggufModel.Decode.
	tensor struct {
		name string

		// kind is the element/quantization type tag read from the file
		// (0 = FP32, 1 = FP16, 2 = Q4_0, ...).
		kind uint32

		// offset is the data offset exactly as stored in the entry.
		offset uint64

		// size is the estimated byte size of the tensor's data.
		size uint64

		// shape is the number of elements in each dimension
		shape [4]uint64
	}
)
2023-09-07 13:55:37 -04:00
type ggufModel struct {
*containerGGUF
2023-11-24 11:57:20 -08:00
2023-09-07 13:55:37 -04:00
kv
2023-11-24 11:57:20 -08:00
tensors []tensor
parameters uint64
2023-09-07 13:55:37 -04:00
}
// newGGUFModel returns an empty ggufModel bound to container, with a
// non-nil metadata map ready for Decode to populate.
func newGGUFModel(container *containerGGUF) *ggufModel {
	m := new(ggufModel)
	m.containerGGUF = container
	m.kv = kv{}
	return m
}
// NumTensor returns the tensor count from the header. The 32-bit
// version 1 field is widened so callers always receive a uint64.
func (llm *ggufModel) NumTensor() uint64 {
	switch llm.Version {
	case 1:
		return uint64(llm.V1.NumTensor)
	default:
		return llm.V2.NumTensor
	}
}
2023-09-07 13:55:37 -04:00
func (llm *ggufModel) NumKV() uint64 {
if llm.Version == 1 {
return uint64(llm.V1.NumKV)
}
return llm.V2.NumKV
}
func (llm *ggufModel) ModelFamily() string {
2023-11-29 10:54:23 -08:00
if t, ok := llm.kv["general.architecture"].(string); ok {
return t
2023-09-07 13:55:37 -04:00
}
return "unknown"
2023-09-07 13:55:37 -04:00
}
func (llm *ggufModel) ModelType() string {
if llm.parameters > 0 {
return format.HumanNumber(llm.parameters)
}
2023-10-02 19:52:25 -07:00
return "unknown"
2023-09-07 13:55:37 -04:00
}
func (llm *ggufModel) FileType() string {
2023-11-29 10:54:23 -08:00
if t, ok := llm.kv["general.file_type"].(uint32); ok {
return fileType(t)
2023-09-07 13:55:37 -04:00
}
2023-10-02 19:52:25 -07:00
return "unknown"
2023-09-07 13:55:37 -04:00
}
2023-11-29 10:31:58 -08:00
func (llm *ggufModel) Decode(rso *readSeekOffset) error {
// decode key-values
2023-09-07 13:55:37 -04:00
for i := 0; uint64(i) < llm.NumKV(); i++ {
2023-11-29 10:31:58 -08:00
k, err := llm.readString(rso)
2023-09-07 13:55:37 -04:00
if err != nil {
return err
}
2023-11-29 10:31:58 -08:00
vtype := llm.readU32(rso)
2023-09-07 13:55:37 -04:00
var v any
switch vtype {
case ggufTypeUint8:
2023-11-29 10:31:58 -08:00
v = llm.readU8(rso)
2023-09-07 13:55:37 -04:00
case ggufTypeInt8:
2023-11-29 10:31:58 -08:00
v = llm.readI8(rso)
2023-09-07 13:55:37 -04:00
case ggufTypeUint16:
2023-11-29 10:31:58 -08:00
v = llm.readU16(rso)
2023-09-07 13:55:37 -04:00
case ggufTypeInt16:
2023-11-29 10:31:58 -08:00
v = llm.readI16(rso)
2023-09-07 13:55:37 -04:00
case ggufTypeUint32:
2023-11-29 10:31:58 -08:00
v = llm.readU32(rso)
2023-09-07 13:55:37 -04:00
case ggufTypeInt32:
2023-11-29 10:31:58 -08:00
v = llm.readI32(rso)
2023-09-07 13:55:37 -04:00
case ggufTypeUint64:
2023-11-29 10:31:58 -08:00
v = llm.readU64(rso)
2023-09-07 13:55:37 -04:00
case ggufTypeInt64:
2023-11-29 10:31:58 -08:00
v = llm.readI64(rso)
2023-09-07 13:55:37 -04:00
case ggufTypeFloat32:
2023-11-29 10:31:58 -08:00
v = llm.readF32(rso)
2023-09-07 13:55:37 -04:00
case ggufTypeFloat64:
2023-11-29 10:31:58 -08:00
v = llm.readF64(rso)
2023-09-07 13:55:37 -04:00
case ggufTypeBool:
2023-11-29 10:31:58 -08:00
v = llm.readBool(rso)
2023-09-07 13:55:37 -04:00
case ggufTypeString:
2023-11-29 10:31:58 -08:00
s, err := llm.readString(rso)
2023-09-07 13:55:37 -04:00
if err != nil {
return err
}
v = s
case ggufTypeArray:
2023-11-29 10:31:58 -08:00
a, err := llm.readArray(rso)
2023-09-07 13:55:37 -04:00
if err != nil {
return err
}
v = a
default:
return fmt.Errorf("invalid type: %d", vtype)
}
llm.kv[k] = v
}
// decode tensors
for i := 0; uint64(i) < llm.NumTensor(); i++ {
2023-11-29 10:31:58 -08:00
name, err := llm.readString(rso)
2023-11-24 11:57:20 -08:00
if err != nil {
return err
}
2023-11-29 10:54:23 -08:00
// dims is the number of dimensions in the tensor
2023-11-29 10:31:58 -08:00
dims := llm.readU32(rso)
2023-11-24 11:57:20 -08:00
shape := [4]uint64{1, 1, 1, 1}
for i := 0; uint32(i) < dims; i++ {
2023-11-29 10:31:58 -08:00
shape[i] = llm.readU64(rso)
2023-11-24 11:57:20 -08:00
}
2023-11-29 10:31:58 -08:00
kind := llm.readU32(rso)
offset := llm.readU64(rso)
2023-11-24 11:57:20 -08:00
var blockSize uint64
switch {
case kind < 2:
blockSize = 1
case kind < 10:
blockSize = 32
default:
blockSize = 256
}
2023-11-24 11:57:20 -08:00
var typeSize uint64
switch kind {
case 0: // FP32
typeSize = 4
case 1: // FP16
typeSize = 2
case 2: // Q4_0
typeSize = 2 + blockSize/2
case 3: // Q4_1
typeSize = 2 + 2 + blockSize/2
case 6: // Q5_0
typeSize = 2 + 4 + blockSize/2
case 7: // Q5_1
typeSize = 2 + 2 + 4 + blockSize/2
case 8: // Q8_0
typeSize = 2 + blockSize
case 9: // Q8_1
typeSize = 4 + 4 + blockSize
case 10: // Q2_K
typeSize = blockSize/16 + blockSize/4 + 2 + 2
case 11: // Q3_K
typeSize = blockSize/8 + blockSize/4 + 12 + 2
case 12: // Q4_K
typeSize = 2 + 2 + 12 + blockSize/2
case 13: // Q5_K
typeSize = 2 + 2 + 12 + blockSize/8 + blockSize/2
case 14: // Q6_K
typeSize = blockSize/2 + blockSize/4 + blockSize/16 + 2
}
2023-11-24 11:57:20 -08:00
parameters := shape[0] * shape[1] * shape[2] * shape[3]
size := parameters * typeSize / blockSize
llm.tensors = append(llm.tensors, tensor{
name: name,
kind: kind,
offset: offset,
size: size,
shape: shape,
})
llm.parameters += parameters
}
alignment, ok := llm.kv["general.alignment"].(uint32)
if !ok {
alignment = 32
}
2023-11-29 10:31:58 -08:00
rso.Seek(int64(alignment)-rso.offset%int64(alignment), io.SeekCurrent)
2023-11-24 11:57:20 -08:00
for _, tensor := range llm.tensors {
padded := (int64(tensor.size) + int64(alignment) - 1) & ^(int64(alignment) - 1)
2023-11-29 10:31:58 -08:00
rso.Seek(padded, io.SeekCurrent)
}
2023-09-07 13:55:37 -04:00
return nil
}
// NumLayers returns the "<family>.block_count" metadata value, where
// <family> is ModelFamily().
//
// It returns 0 when the key is missing or its value is not a uint32. The
// comma-ok assertion keeps a malformed file from panicking the caller.
func (llm *ggufModel) NumLayers() int64 {
	v, ok := llm.kv[llm.ModelFamily()+".block_count"].(uint32)
	if !ok {
		return 0
	}
	return int64(v)
}
func (llm ggufModel) readU8(r io.Reader) uint8 {
2023-09-07 13:55:37 -04:00
var u8 uint8
binary.Read(r, llm.bo, &u8)
2023-09-07 13:55:37 -04:00
return u8
}
func (llm ggufModel) readI8(r io.Reader) int8 {
2023-09-07 13:55:37 -04:00
var i8 int8
binary.Read(r, llm.bo, &i8)
2023-09-07 13:55:37 -04:00
return i8
}
func (llm ggufModel) readU16(r io.Reader) uint16 {
2023-09-07 13:55:37 -04:00
var u16 uint16
binary.Read(r, llm.bo, &u16)
2023-09-07 13:55:37 -04:00
return u16
}
func (llm ggufModel) readI16(r io.Reader) int16 {
2023-09-07 13:55:37 -04:00
var i16 int16
binary.Read(r, llm.bo, &i16)
2023-09-07 13:55:37 -04:00
return i16
}
func (llm ggufModel) readU32(r io.Reader) uint32 {
2023-09-07 13:55:37 -04:00
var u32 uint32
binary.Read(r, llm.bo, &u32)
2023-09-07 13:55:37 -04:00
return u32
}
func (llm ggufModel) readI32(r io.Reader) int32 {
2023-09-07 13:55:37 -04:00
var i32 int32
binary.Read(r, llm.bo, &i32)
2023-09-07 13:55:37 -04:00
return i32
}
func (llm ggufModel) readU64(r io.Reader) uint64 {
2023-09-07 13:55:37 -04:00
var u64 uint64
binary.Read(r, llm.bo, &u64)
2023-09-07 13:55:37 -04:00
return u64
}
func (llm ggufModel) readI64(r io.Reader) int64 {
2023-09-07 13:55:37 -04:00
var i64 int64
binary.Read(r, llm.bo, &i64)
2023-09-07 13:55:37 -04:00
return i64
}
func (llm ggufModel) readF32(r io.Reader) float32 {
2023-09-07 13:55:37 -04:00
var f32 float32
binary.Read(r, llm.bo, &f32)
2023-09-07 13:55:37 -04:00
return f32
}
func (llm ggufModel) readF64(r io.Reader) float64 {
2023-09-07 13:55:37 -04:00
var f64 float64
binary.Read(r, llm.bo, &f64)
2023-09-07 13:55:37 -04:00
return f64
}
func (llm ggufModel) readBool(r io.Reader) bool {
2023-09-07 13:55:37 -04:00
var b bool
binary.Read(r, llm.bo, &b)
2023-09-07 13:55:37 -04:00
return b
}
func (llm ggufModel) readStringV1(r io.Reader) (string, error) {
2023-09-07 13:55:37 -04:00
var nameLength uint32
binary.Read(r, llm.bo, &nameLength)
2023-09-07 13:55:37 -04:00
var b bytes.Buffer
if _, err := io.CopyN(&b, r, int64(nameLength)); err != nil {
return "", err
}
// gguf v1 strings are null-terminated
b.Truncate(b.Len() - 1)
return b.String(), nil
}
func (llm ggufModel) readString(r io.Reader) (string, error) {
if llm.Version == 1 {
return llm.readStringV1(r)
}
2023-09-07 13:55:37 -04:00
var nameLength uint64
binary.Read(r, llm.bo, &nameLength)
2023-09-07 13:55:37 -04:00
var b bytes.Buffer
if _, err := io.CopyN(&b, r, int64(nameLength)); err != nil {
return "", err
}
return b.String(), nil
}
func (llm *ggufModel) readArrayV1(r io.Reader) (arr []any, err error) {
atype := llm.readU32(r)
n := llm.readU32(r)
for i := 0; uint32(i) < n; i++ {
switch atype {
case ggufTypeUint8:
arr = append(arr, llm.readU8(r))
case ggufTypeInt8:
2023-11-22 11:40:30 -08:00
arr = append(arr, llm.readI8(r))
2023-09-07 13:55:37 -04:00
case ggufTypeUint16:
arr = append(arr, llm.readU16(r))
case ggufTypeInt16:
arr = append(arr, llm.readI16(r))
case ggufTypeUint32:
arr = append(arr, llm.readU32(r))
case ggufTypeInt32:
arr = append(arr, llm.readI32(r))
case ggufTypeFloat32:
arr = append(arr, llm.readF32(r))
case ggufTypeBool:
arr = append(arr, llm.readBool(r))
case ggufTypeString:
s, err := llm.readStringV1(r)
if err != nil {
return nil, err
}
arr = append(arr, s)
default:
return nil, fmt.Errorf("invalid array type: %d", atype)
}
}
return
}
func (llm *ggufModel) readArray(r io.Reader) (arr []any, err error) {
if llm.Version == 1 {
return llm.readArrayV1(r)
}
2023-09-07 13:55:37 -04:00
atype := llm.readU32(r)
n := llm.readU64(r)
for i := 0; uint64(i) < n; i++ {
switch atype {
case ggufTypeUint8:
arr = append(arr, llm.readU8(r))
case ggufTypeInt8:
2023-11-22 11:40:30 -08:00
arr = append(arr, llm.readI8(r))
2023-09-07 13:55:37 -04:00
case ggufTypeUint16:
arr = append(arr, llm.readU16(r))
case ggufTypeInt16:
arr = append(arr, llm.readI16(r))
case ggufTypeUint32:
arr = append(arr, llm.readU32(r))
case ggufTypeInt32:
arr = append(arr, llm.readI32(r))
case ggufTypeUint64:
arr = append(arr, llm.readU64(r))
case ggufTypeInt64:
arr = append(arr, llm.readI64(r))
case ggufTypeFloat32:
arr = append(arr, llm.readF32(r))
case ggufTypeFloat64:
arr = append(arr, llm.readF64(r))
case ggufTypeBool:
arr = append(arr, llm.readBool(r))
case ggufTypeString:
s, err := llm.readString(r)
if err != nil {
return nil, err
}
arr = append(arr, s)
default:
return nil, fmt.Errorf("invalid array type: %d", atype)
}
}
return
}