ollama/llm/gguf.go

663 lines
13 KiB
Go
Raw Normal View History

2023-09-07 13:55:37 -04:00
package llm
import (
"bytes"
2024-05-31 20:00:49 -07:00
"cmp"
2023-09-07 13:55:37 -04:00
"encoding/binary"
"encoding/json"
2023-09-07 13:55:37 -04:00
"fmt"
"io"
2024-05-31 20:00:49 -07:00
"log/slog"
"slices"
"strings"
2024-05-31 20:00:49 -07:00
"golang.org/x/exp/maps"
2023-09-07 13:55:37 -04:00
)
// containerGGUF holds the GGUF header fields that precede the metadata: the
// format version and the per-version tensor / key-value counts.
type containerGGUF struct {
	// ByteOrder is used for every multi-byte value read from the file.
	ByteOrder binary.ByteOrder

	// Version selects which of V1/V2/V3 holds the counts.
	Version uint32

	// V1 counts are 32-bit.
	V1 struct {
		NumTensor uint32
		NumKV     uint32
	}

	// V2 counts are 64-bit.
	V2 struct {
		NumTensor uint64
		NumKV     uint64
	}

	// V3 counts are 64-bit; also used for any version other than 1 or 2.
	V3 struct {
		NumTensor uint64
		NumKV     uint64
	}

	// maxArraySize caps how many elements of a metadata array are kept in
	// memory while decoding. A negative value means "no limit".
	maxArraySize int
}

// canCollectArray reports whether an array with size elements should have its
// values retained, according to maxArraySize.
func (c *containerGGUF) canCollectArray(size int) bool {
	if c.maxArraySize < 0 {
		// negative limit: keep everything
		return true
	}
	return size <= c.maxArraySize
}

// Name identifies this container format.
func (c *containerGGUF) Name() string {
	const name = "gguf"
	return name
}
func (c *containerGGUF) Decode(rs io.ReadSeeker) (model, error) {
if err := binary.Read(rs, c.ByteOrder, &c.Version); err != nil {
return nil, err
}
2023-09-07 13:55:37 -04:00
var err error
2023-09-07 13:55:37 -04:00
switch c.Version {
case 1:
err = binary.Read(rs, c.ByteOrder, &c.V1)
case 2:
err = binary.Read(rs, c.ByteOrder, &c.V2)
2023-09-07 13:55:37 -04:00
default:
err = binary.Read(rs, c.ByteOrder, &c.V3)
}
if err != nil {
return nil, err
2023-09-07 13:55:37 -04:00
}
model := newGGUF(c)
2024-03-09 12:28:36 -08:00
if err := model.Decode(rs); err != nil {
2023-09-07 13:55:37 -04:00
return nil, err
}
return model, nil
}
// GGUF metadata value type tags as stored on disk (uint32). The values are
// part of the file format, so they are spelled out explicitly rather than
// derived from declaration order.
const (
	ggufTypeUint8   uint32 = 0
	ggufTypeInt8    uint32 = 1
	ggufTypeUint16  uint32 = 2
	ggufTypeInt16   uint32 = 3
	ggufTypeUint32  uint32 = 4
	ggufTypeInt32   uint32 = 5
	ggufTypeFloat32 uint32 = 6
	ggufTypeBool    uint32 = 7
	ggufTypeString  uint32 = 8
	ggufTypeArray   uint32 = 9
	ggufTypeUint64  uint32 = 10
	ggufTypeInt64   uint32 = 11
	ggufTypeFloat64 uint32 = 12
)
2023-09-07 13:55:37 -04:00
type gguf struct {
*containerGGUF
2024-03-13 11:03:56 -07:00
kv KV
tensors []*Tensor
2023-11-24 11:57:20 -08:00
2024-05-31 20:00:49 -07:00
parameters uint64
tensorOffset uint64
scratch [16 << 10]byte
2023-09-07 13:55:37 -04:00
}
func newGGUF(container *containerGGUF) *gguf {
return &gguf{
containerGGUF: container,
2024-03-13 11:03:56 -07:00
kv: make(KV),
2023-09-07 13:55:37 -04:00
}
}
2024-03-13 11:03:56 -07:00
func (llm *gguf) KV() KV {
return llm.kv
}
2024-04-03 15:00:31 -07:00
func (llm *gguf) Tensors() Tensors {
return Tensors{
Items: llm.tensors,
Offset: llm.tensorOffset,
}
2024-03-13 11:03:56 -07:00
}
// numTensor returns the header's tensor count from whichever versioned counts
// struct containerGGUF.Decode filled in (v3 for any version other than 1 or 2).
func (llm *gguf) numTensor() uint64 {
	if llm.Version == 1 {
		return uint64(llm.V1.NumTensor)
	}
	if llm.Version == 2 {
		return llm.V2.NumTensor
	}
	return llm.V3.NumTensor
}
func (llm *gguf) numKV() uint64 {
switch llm.Version {
case 1:
2023-09-07 13:55:37 -04:00
return uint64(llm.V1.NumKV)
case 2:
return llm.V2.NumKV
default:
return llm.V3.NumKV
2023-09-07 13:55:37 -04:00
}
}
func (llm *gguf) Decode(rs io.ReadSeeker) error {
// decode key-values
for i := 0; uint64(i) < llm.numKV(); i++ {
k, err := readGGUFString(llm, rs)
if err != nil {
return err
}
t, err := readGGUF[uint32](llm, rs)
if err != nil {
return err
}
var v any
switch t {
case ggufTypeUint8:
v, err = readGGUF[uint8](llm, rs)
case ggufTypeInt8:
v, err = readGGUF[int8](llm, rs)
case ggufTypeUint16:
v, err = readGGUF[uint16](llm, rs)
case ggufTypeInt16:
v, err = readGGUF[int16](llm, rs)
case ggufTypeUint32:
v, err = readGGUF[uint32](llm, rs)
case ggufTypeInt32:
v, err = readGGUF[int32](llm, rs)
case ggufTypeUint64:
v, err = readGGUF[uint64](llm, rs)
case ggufTypeInt64:
v, err = readGGUF[int64](llm, rs)
case ggufTypeFloat32:
v, err = readGGUF[float32](llm, rs)
case ggufTypeFloat64:
v, err = readGGUF[float64](llm, rs)
case ggufTypeBool:
v, err = readGGUF[bool](llm, rs)
case ggufTypeString:
v, err = readGGUFString(llm, rs)
case ggufTypeArray:
v, err = readGGUFArray(llm, rs)
default:
return fmt.Errorf("invalid type: %d", t)
}
if err != nil {
return err
}
2024-03-13 11:03:56 -07:00
llm.kv[k] = v
}
// decode tensors
for range llm.numTensor() {
name, err := readGGUFString(llm, rs)
if err != nil {
return fmt.Errorf("failed to read tensor name: %w", err)
}
// dims is the number of dimensions in the tensor
dims, err := readGGUF[uint32](llm, rs)
2023-09-07 13:55:37 -04:00
if err != nil {
return fmt.Errorf("failed to read tensor dimensions: %w", err)
2023-09-07 13:55:37 -04:00
}
2024-05-31 20:00:49 -07:00
shape := make([]uint64, dims)
for i := 0; uint32(i) < dims; i++ {
shape[i], err = readGGUF[uint64](llm, rs)
2023-09-07 13:55:37 -04:00
if err != nil {
return fmt.Errorf("failed to read tensor shape: %w", err)
2023-09-07 13:55:37 -04:00
}
}
kind, err := readGGUF[uint32](llm, rs)
2023-11-24 11:57:20 -08:00
if err != nil {
return fmt.Errorf("failed to read tensor kind: %w", err)
}
offset, err := readGGUF[uint64](llm, rs)
if err != nil {
return fmt.Errorf("failed to read tensor offset: %w", err)
2023-11-24 11:57:20 -08:00
}
tensor := Tensor{
Name: name,
Kind: kind,
Offset: offset,
2024-03-08 15:38:53 -08:00
Shape: shape[:],
2024-01-24 10:48:31 -08:00
}
2023-11-24 11:57:20 -08:00
2024-03-13 11:03:56 -07:00
llm.tensors = append(llm.tensors, &tensor)
llm.parameters += tensor.parameters()
2023-11-24 11:57:20 -08:00
}
2024-03-13 11:03:56 -07:00
// patch KV with parameter count
llm.kv["general.parameter_count"] = llm.parameters
alignment, ok := llm.kv["general.alignment"].(uint32)
2023-11-24 11:57:20 -08:00
if !ok {
alignment = 32
}
offset, err := rs.Seek(0, io.SeekCurrent)
if err != nil {
return err
}
2024-05-31 20:00:49 -07:00
padding := ggufPadding(offset, int64(alignment))
llm.tensorOffset = uint64(offset + padding)
2024-03-13 11:03:56 -07:00
for _, tensor := range llm.tensors {
offset, err := rs.Seek(0, io.SeekCurrent)
if err != nil {
return fmt.Errorf("failed to get current offset: %w", err)
2024-04-15 17:31:11 -07:00
}
2024-05-31 20:00:49 -07:00
padding := ggufPadding(offset, int64(alignment))
2024-04-15 17:31:11 -07:00
if _, err := rs.Seek(padding, io.SeekCurrent); err != nil {
return fmt.Errorf("failed to seek to init padding: %w", err)
}
if _, err := rs.Seek(int64(tensor.Size()), io.SeekCurrent); err != nil {
return fmt.Errorf("failed to seek to tensor: %w", err)
2024-03-09 12:28:36 -08:00
}
}
2023-09-07 13:55:37 -04:00
return nil
}
func readGGUF[T any](llm *gguf, r io.Reader) (T, error) {
var t T
err := binary.Read(r, llm.ByteOrder, &t)
return t, err
2023-09-07 13:55:37 -04:00
}
2024-05-31 20:00:49 -07:00
// writeGGUF writes a typed GGUF value: the uint32 type tag t followed by v,
// both little-endian. v must be a fixed-size value accepted by binary.Write.
func writeGGUF[V any](w io.Writer, t uint32, v V) error {
	for _, field := range []any{t, v} {
		if err := binary.Write(w, binary.LittleEndian, field); err != nil {
			return err
		}
	}
	return nil
}
func readGGUFV1String(llm *gguf, r io.Reader) (string, error) {
var length uint64
if err := binary.Read(r, llm.ByteOrder, &length); err != nil {
return "", err
}
2023-09-07 13:55:37 -04:00
var b bytes.Buffer
if _, err := io.CopyN(&b, r, int64(length)); err != nil {
2023-09-07 13:55:37 -04:00
return "", err
}
// gguf v1 strings are null-terminated
b.Truncate(b.Len() - 1)
return b.String(), nil
}
// discardGGUFString consumes a length-prefixed string from r without keeping
// it. The prefix is a uint64 byte count in the container's byte order,
// followed by that many bytes, which are read through llm.scratch and dropped.
//
// io.ReadFull is used so that data returned together with an error (allowed
// by the io.Reader contract) is still counted, and a short stream is
// reported as io.ErrUnexpectedEOF instead of looping or being ignored.
func discardGGUFString(llm *gguf, r io.Reader) error {
	buf := llm.scratch[:8]
	if _, err := io.ReadFull(r, buf); err != nil {
		return err
	}

	// Keep the count unsigned: converting an oversized length to int could
	// go negative and silently skip nothing.
	remaining := llm.ByteOrder.Uint64(buf)
	for remaining > 0 {
		n := min(remaining, uint64(len(llm.scratch)))
		if _, err := io.ReadFull(r, llm.scratch[:n]); err != nil {
			return err
		}
		remaining -= n
	}
	return nil
}
func readGGUFString(llm *gguf, r io.Reader) (string, error) {
if llm.Version == 1 {
return readGGUFV1String(llm, r)
}
buf := llm.scratch[:8]
_, err := io.ReadFull(r, buf)
if err != nil {
return "", err
}
2023-09-07 13:55:37 -04:00
length := int(llm.ByteOrder.Uint64(buf))
if length > len(llm.scratch) {
buf = make([]byte, length)
} else {
buf = llm.scratch[:length]
2023-09-07 13:55:37 -04:00
}
clear(buf)
2023-09-07 13:55:37 -04:00
_, err = io.ReadFull(r, buf)
if err != nil {
return "", err
}
return string(buf), nil
2023-09-07 13:55:37 -04:00
}
2024-05-31 20:00:49 -07:00
func writeGGUFString(w io.Writer, s string) error {
if err := binary.Write(w, binary.LittleEndian, ggufTypeString); err != nil {
return err
}
2023-09-07 13:55:37 -04:00
2024-05-31 20:00:49 -07:00
if err := binary.Write(w, binary.LittleEndian, uint64(len(s))); err != nil {
return err
}
_, err := io.Copy(w, strings.NewReader(s))
return err
}
// array is a decoded GGUF metadata array. size is the element count recorded
// in the file. values holds the elements only when the decoder chose to
// collect them (see canCollectArray) and is nil otherwise.
type array struct {
	size   int
	values []any
}

// MarshalJSON encodes just the collected elements, so an array whose values
// were not collected encodes as JSON null.
func (a *array) MarshalJSON() ([]byte, error) {
	elems := a.values
	return json.Marshal(elems)
}
func readGGUFV1Array(llm *gguf, r io.Reader) (*array, error) {
t, err := readGGUF[uint32](llm, r)
if err != nil {
return nil, err
}
n, err := readGGUF[uint32](llm, r)
if err != nil {
return nil, err
}
2023-09-07 13:55:37 -04:00
a := &array{size: int(n)}
if llm.canCollectArray(int(n)) {
a.values = make([]any, 0, int(n))
}
for i := range n {
var e any
switch t {
case ggufTypeUint8:
e, err = readGGUF[uint8](llm, r)
case ggufTypeInt8:
e, err = readGGUF[int8](llm, r)
case ggufTypeUint16:
e, err = readGGUF[uint16](llm, r)
case ggufTypeInt16:
e, err = readGGUF[int16](llm, r)
case ggufTypeUint32:
e, err = readGGUF[uint32](llm, r)
case ggufTypeInt32:
e, err = readGGUF[int32](llm, r)
case ggufTypeUint64:
e, err = readGGUF[uint64](llm, r)
case ggufTypeInt64:
e, err = readGGUF[int64](llm, r)
case ggufTypeFloat32:
e, err = readGGUF[float32](llm, r)
case ggufTypeFloat64:
e, err = readGGUF[float64](llm, r)
case ggufTypeBool:
e, err = readGGUF[bool](llm, r)
case ggufTypeString:
e, err = readGGUFV1String(llm, r)
2023-09-07 13:55:37 -04:00
default:
return nil, fmt.Errorf("invalid array type: %d", t)
2023-09-07 13:55:37 -04:00
}
if err != nil {
return nil, err
}
if a.values != nil {
a.values[i] = e
}
2023-09-07 13:55:37 -04:00
}
return a, nil
2023-09-07 13:55:37 -04:00
}
func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
if llm.Version == 1 {
return readGGUFV1Array(llm, r)
}
t, err := readGGUF[uint32](llm, r)
if err != nil {
return nil, err
}
2023-09-07 13:55:37 -04:00
n, err := readGGUF[uint64](llm, r)
if err != nil {
return nil, err
}
2023-09-07 13:55:37 -04:00
a := &array{size: int(n)}
if llm.canCollectArray(int(n)) {
a.values = make([]any, int(n))
}
for i := range n {
var e any
switch t {
case ggufTypeUint8:
e, err = readGGUF[uint8](llm, r)
case ggufTypeInt8:
e, err = readGGUF[int8](llm, r)
case ggufTypeUint16:
e, err = readGGUF[uint16](llm, r)
case ggufTypeInt16:
e, err = readGGUF[int16](llm, r)
case ggufTypeUint32:
e, err = readGGUF[uint32](llm, r)
case ggufTypeInt32:
e, err = readGGUF[int32](llm, r)
case ggufTypeUint64:
e, err = readGGUF[uint64](llm, r)
case ggufTypeInt64:
e, err = readGGUF[int64](llm, r)
case ggufTypeFloat32:
e, err = readGGUF[float32](llm, r)
case ggufTypeFloat64:
e, err = readGGUF[float64](llm, r)
case ggufTypeBool:
e, err = readGGUF[bool](llm, r)
case ggufTypeString:
if a.values != nil {
e, err = readGGUFString(llm, r)
} else {
err = discardGGUFString(llm, r)
}
2023-09-07 13:55:37 -04:00
default:
return nil, fmt.Errorf("invalid array type: %d", t)
2023-09-07 13:55:37 -04:00
}
if err != nil {
return nil, err
}
if a.values != nil {
a.values[i] = e
}
2023-09-07 13:55:37 -04:00
}
return a, nil
2023-09-07 13:55:37 -04:00
}
2024-07-08 16:59:48 -07:00
// writeGGUFArray writes a slice s of type E to the write with a gguf type of t
2024-05-31 20:00:49 -07:00
func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
if err := binary.Write(w, binary.LittleEndian, ggufTypeArray); err != nil {
return err
}
2024-05-31 20:00:49 -07:00
if err := binary.Write(w, binary.LittleEndian, t); err != nil {
return err
}
2024-05-31 20:00:49 -07:00
if err := binary.Write(w, binary.LittleEndian, uint64(len(s))); err != nil {
return err
}
2024-07-08 16:59:48 -07:00
return binary.Write(w, binary.LittleEndian, s)
}
2024-07-08 16:59:48 -07:00
func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
2024-05-31 20:00:49 -07:00
if err := binary.Write(ws, binary.LittleEndian, []byte("GGUF")); err != nil {
return err
}
2024-05-31 20:00:49 -07:00
if err := binary.Write(ws, binary.LittleEndian, uint32(3)); err != nil {
return err
}
2024-05-31 20:00:49 -07:00
if err := binary.Write(ws, binary.LittleEndian, uint64(len(ts))); err != nil {
return err
}
2024-05-31 20:00:49 -07:00
if err := binary.Write(ws, binary.LittleEndian, uint64(len(kv))); err != nil {
return err
}
2024-05-31 20:00:49 -07:00
keys := maps.Keys(kv)
slices.Sort(keys)
for _, key := range keys {
if err := ggufWriteKV(ws, key, kv[key]); err != nil {
return err
}
}
slices.SortStableFunc(ts, func(a, b Tensor) int {
if i, j := a.block(), b.block(); i < 0 && j > 0 {
return 1
} else if i > 0 && j < 0 {
return -1
} else {
return cmp.Compare(i, j)
}
2024-05-31 20:00:49 -07:00
})
var s uint64
for _, t := range ts {
t.Offset = s
if err := ggufWriteTensorInfo(ws, t); err != nil {
return err
}
2024-05-31 20:00:49 -07:00
s += t.Size()
}
2024-05-31 20:00:49 -07:00
var alignment int64 = 32
for _, t := range ts {
if err := ggufWriteTensor(ws, t, alignment); err != nil {
return err
}
2024-05-31 20:00:49 -07:00
}
2024-05-31 20:00:49 -07:00
return nil
}
2024-05-31 20:00:49 -07:00
func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
slog.Debug(k, "type", fmt.Sprintf("%T", v))
if err := binary.Write(ws, binary.LittleEndian, uint64(len(k))); err != nil {
return err
}
2024-05-31 20:00:49 -07:00
if err := binary.Write(ws, binary.LittleEndian, []byte(k)); err != nil {
return err
}
2024-05-31 20:00:49 -07:00
var err error
switch v := v.(type) {
case uint32:
err = writeGGUF(ws, ggufTypeUint32, v)
case float32:
err = writeGGUF(ws, ggufTypeFloat32, v)
case bool:
err = writeGGUF(ws, ggufTypeBool, v)
case string:
err = writeGGUFString(ws, v)
case []int32:
err = writeGGUFArray(ws, ggufTypeInt32, v)
case []uint32:
err = writeGGUFArray(ws, ggufTypeUint32, v)
case []float32:
err = writeGGUFArray(ws, ggufTypeFloat32, v)
case []string:
if err := binary.Write(ws, binary.LittleEndian, ggufTypeArray); err != nil {
return err
}
2024-05-31 20:00:49 -07:00
if err := binary.Write(ws, binary.LittleEndian, ggufTypeString); err != nil {
return err
}
2024-05-31 20:00:49 -07:00
if err := binary.Write(ws, binary.LittleEndian, uint64(len(v))); err != nil {
return err
}
2024-05-31 20:00:49 -07:00
for _, e := range v {
if err := binary.Write(ws, binary.LittleEndian, uint64(len(e))); err != nil {
return err
}
2024-05-31 20:00:49 -07:00
if err := binary.Write(ws, binary.LittleEndian, []byte(e)); err != nil {
return err
}
}
2024-05-31 20:00:49 -07:00
default:
return fmt.Errorf("improper type for '%s'", k)
}
2024-05-31 20:00:49 -07:00
return err
}
2024-07-08 16:59:48 -07:00
func ggufWriteTensorInfo(ws io.WriteSeeker, t Tensor) error {
2024-05-31 20:00:49 -07:00
slog.Debug(t.Name, "kind", t.Kind, "shape", t.Shape, "offset", t.Offset)
if err := binary.Write(ws, binary.LittleEndian, uint64(len(t.Name))); err != nil {
return err
}
2024-05-31 20:00:49 -07:00
if err := binary.Write(ws, binary.LittleEndian, []byte(t.Name)); err != nil {
return err
}
2024-05-31 20:00:49 -07:00
if err := binary.Write(ws, binary.LittleEndian, uint32(len(t.Shape))); err != nil {
return err
}
2024-05-31 20:00:49 -07:00
for i := range len(t.Shape) {
if err := binary.Write(ws, binary.LittleEndian, t.Shape[len(t.Shape)-i-1]); err != nil {
return err
}
}
2024-05-31 20:00:49 -07:00
if err := binary.Write(ws, binary.LittleEndian, t.Kind); err != nil {
return err
}
return binary.Write(ws, binary.LittleEndian, t.Offset)
}
2024-07-08 16:59:48 -07:00
func ggufWriteTensor(ws io.WriteSeeker, t Tensor, alignment int64) error {
2024-05-31 20:00:49 -07:00
offset, err := ws.Seek(0, io.SeekCurrent)
if err != nil {
return err
}
if err := binary.Write(ws, binary.LittleEndian, bytes.Repeat([]byte{0}, int(ggufPadding(offset, alignment)))); err != nil {
return err
}
_, err = t.WriteTo(ws)
return err
}
2024-05-31 20:00:49 -07:00
// ggufPadding returns how many bytes must be added to offset to reach the next
// multiple of align. For a non-negative offset and positive align, the result
// is 0 when offset is already aligned. align must be non-zero, otherwise the
// modulo panics.
func ggufPadding(offset, align int64) int64 {
	misalignment := offset % align
	return (align - misalignment) % align
}