2023-09-07 13:55:37 -04:00
|
|
|
package llm
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bytes"
|
2024-05-31 20:00:49 -07:00
|
|
|
"cmp"
|
2023-09-07 13:55:37 -04:00
|
|
|
"encoding/binary"
|
2024-06-24 21:47:52 -07:00
|
|
|
"encoding/json"
|
2023-09-07 13:55:37 -04:00
|
|
|
"fmt"
|
|
|
|
"io"
|
2024-05-31 20:00:49 -07:00
|
|
|
"log/slog"
|
|
|
|
"slices"
|
2024-03-28 18:54:01 -07:00
|
|
|
"strings"
|
2024-05-31 20:00:49 -07:00
|
|
|
|
|
|
|
"golang.org/x/exp/maps"
|
2023-09-07 13:55:37 -04:00
|
|
|
)
|
|
|
|
|
2024-03-28 18:54:01 -07:00
|
|
|
type containerGGUF struct {
|
2024-03-06 21:01:51 -08:00
|
|
|
ByteOrder binary.ByteOrder
|
2023-10-23 09:33:13 -07:00
|
|
|
|
2023-09-07 13:55:37 -04:00
|
|
|
Version uint32
|
|
|
|
|
|
|
|
V1 struct {
|
|
|
|
NumTensor uint32
|
|
|
|
NumKV uint32
|
|
|
|
}
|
|
|
|
|
|
|
|
V2 struct {
|
|
|
|
NumTensor uint64
|
|
|
|
NumKV uint64
|
|
|
|
}
|
2024-03-06 21:01:51 -08:00
|
|
|
|
|
|
|
V3 struct {
|
|
|
|
NumTensor uint64
|
|
|
|
NumKV uint64
|
|
|
|
}
|
2024-06-24 21:47:52 -07:00
|
|
|
|
|
|
|
maxArraySize int
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *containerGGUF) canCollectArray(size int) bool {
|
|
|
|
return c.maxArraySize < 0 || size <= c.maxArraySize
|
2023-09-07 13:55:37 -04:00
|
|
|
}
|
|
|
|
|
2024-03-28 18:54:01 -07:00
|
|
|
func (c *containerGGUF) Name() string {
|
2023-09-07 13:55:37 -04:00
|
|
|
return "gguf"
|
|
|
|
}
|
|
|
|
|
2024-03-28 18:54:01 -07:00
|
|
|
// Decode reads the GGUF version and the version-specific tensor and
// key-value counts from rs (in c.ByteOrder), then decodes the rest of the
// stream as a gguf model. Version 1 uses 32-bit counts (V1), version 2 uses
// V2, and every other version is read with the V3 layout.
func (c *containerGGUF) Decode(rs io.ReadSeeker) (model, error) {
	if err := binary.Read(rs, c.ByteOrder, &c.Version); err != nil {
		return nil, err
	}

	// Select the header struct that matches the declared version.
	var header any = &c.V3
	switch c.Version {
	case 1:
		header = &c.V1
	case 2:
		header = &c.V2
	}

	if err := binary.Read(rs, c.ByteOrder, header); err != nil {
		return nil, err
	}

	g := newGGUF(c)
	if err := g.Decode(rs); err != nil {
		return nil, err
	}

	return g, nil
}
|
|
|
|
|
2024-03-06 21:01:51 -08:00
|
|
|
// GGUF metadata value type tags, as stored in the uint32 type field that
// precedes each key-value entry and each array's element type.
const (
	ggufTypeUint8   uint32 = 0
	ggufTypeInt8    uint32 = 1
	ggufTypeUint16  uint32 = 2
	ggufTypeInt16   uint32 = 3
	ggufTypeUint32  uint32 = 4
	ggufTypeInt32   uint32 = 5
	ggufTypeFloat32 uint32 = 6
	ggufTypeBool    uint32 = 7
	ggufTypeString  uint32 = 8
	ggufTypeArray   uint32 = 9
	ggufTypeUint64  uint32 = 10
	ggufTypeInt64   uint32 = 11
	ggufTypeFloat64 uint32 = 12
)
|
2023-09-07 13:55:37 -04:00
|
|
|
|
2024-03-28 18:54:01 -07:00
|
|
|
type gguf struct {
|
|
|
|
*containerGGUF
|
2024-03-06 21:01:51 -08:00
|
|
|
|
2024-03-13 11:03:56 -07:00
|
|
|
kv KV
|
|
|
|
tensors []*Tensor
|
2023-11-24 11:57:20 -08:00
|
|
|
|
2024-05-31 20:00:49 -07:00
|
|
|
parameters uint64
|
2024-06-03 09:49:13 -07:00
|
|
|
tensorOffset uint64
|
2024-06-24 21:47:52 -07:00
|
|
|
|
|
|
|
scratch [16 << 10]byte
|
2023-09-07 13:55:37 -04:00
|
|
|
}
|
|
|
|
|
2024-03-28 18:54:01 -07:00
|
|
|
func newGGUF(container *containerGGUF) *gguf {
|
|
|
|
return &gguf{
|
|
|
|
containerGGUF: container,
|
2024-03-13 11:03:56 -07:00
|
|
|
kv: make(KV),
|
2023-09-07 13:55:37 -04:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-03-13 11:03:56 -07:00
|
|
|
func (llm *gguf) KV() KV {
|
|
|
|
return llm.kv
|
|
|
|
}
|
|
|
|
|
2024-04-03 15:00:31 -07:00
|
|
|
func (llm *gguf) Tensors() Tensors {
|
2024-06-03 09:49:13 -07:00
|
|
|
return Tensors{
|
|
|
|
Items: llm.tensors,
|
|
|
|
Offset: llm.tensorOffset,
|
|
|
|
}
|
2024-03-13 11:03:56 -07:00
|
|
|
}
|
|
|
|
|
2024-03-28 18:54:01 -07:00
|
|
|
func (llm *gguf) numTensor() uint64 {
|
|
|
|
switch llm.Version {
|
|
|
|
case 1:
|
2023-11-08 19:55:46 -06:00
|
|
|
return uint64(llm.V1.NumTensor)
|
2024-03-28 18:54:01 -07:00
|
|
|
case 2:
|
|
|
|
return llm.V2.NumTensor
|
|
|
|
default:
|
|
|
|
return llm.V3.NumTensor
|
2023-11-08 19:55:46 -06:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-03-28 18:54:01 -07:00
|
|
|
func (llm *gguf) numKV() uint64 {
|
|
|
|
switch llm.Version {
|
|
|
|
case 1:
|
2023-09-07 13:55:37 -04:00
|
|
|
return uint64(llm.V1.NumKV)
|
2024-03-28 18:54:01 -07:00
|
|
|
case 2:
|
|
|
|
return llm.V2.NumKV
|
|
|
|
default:
|
|
|
|
return llm.V3.NumKV
|
2023-09-07 13:55:37 -04:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-03-28 18:54:01 -07:00
|
|
|
// Decode reads the key-value metadata and tensor-info sections of a GGUF
// stream from rs. It assumes the version and counts were already consumed
// (see containerGGUF.Decode).
//
// After the tensor infos it records the aligned start of the tensor data
// section in llm.tensorOffset. It then seeks over each tensor's padded data,
// leaving rs positioned after the last tensor. It also stores the summed
// parameter count of all tensors under the "general.parameter_count" key.
//
// The alignment comes from the uint32 "general.alignment" key, defaulting to
// 32 when absent. An explicit alignment of 0 is rejected as invalid.
func (llm *gguf) Decode(rs io.ReadSeeker) error {
	// decode key-values
	for range llm.numKV() {
		k, err := readGGUFString(llm, rs)
		if err != nil {
			return err
		}

		t, err := readGGUF[uint32](llm, rs)
		if err != nil {
			return err
		}

		var v any
		switch t {
		case ggufTypeUint8:
			v, err = readGGUF[uint8](llm, rs)
		case ggufTypeInt8:
			v, err = readGGUF[int8](llm, rs)
		case ggufTypeUint16:
			v, err = readGGUF[uint16](llm, rs)
		case ggufTypeInt16:
			v, err = readGGUF[int16](llm, rs)
		case ggufTypeUint32:
			v, err = readGGUF[uint32](llm, rs)
		case ggufTypeInt32:
			v, err = readGGUF[int32](llm, rs)
		case ggufTypeUint64:
			v, err = readGGUF[uint64](llm, rs)
		case ggufTypeInt64:
			v, err = readGGUF[int64](llm, rs)
		case ggufTypeFloat32:
			v, err = readGGUF[float32](llm, rs)
		case ggufTypeFloat64:
			v, err = readGGUF[float64](llm, rs)
		case ggufTypeBool:
			v, err = readGGUF[bool](llm, rs)
		case ggufTypeString:
			v, err = readGGUFString(llm, rs)
		case ggufTypeArray:
			v, err = readGGUFArray(llm, rs)
		default:
			return fmt.Errorf("invalid type: %d", t)
		}

		if err != nil {
			return err
		}

		llm.kv[k] = v
	}

	// decode tensors
	for range llm.numTensor() {
		name, err := readGGUFString(llm, rs)
		if err != nil {
			return fmt.Errorf("failed to read tensor name: %w", err)
		}

		// dims is the number of dimensions in the tensor
		dims, err := readGGUF[uint32](llm, rs)
		if err != nil {
			return fmt.Errorf("failed to read tensor dimensions: %w", err)
		}

		shape := make([]uint64, dims)
		for i := range shape {
			shape[i], err = readGGUF[uint64](llm, rs)
			if err != nil {
				return fmt.Errorf("failed to read tensor shape: %w", err)
			}
		}

		kind, err := readGGUF[uint32](llm, rs)
		if err != nil {
			return fmt.Errorf("failed to read tensor kind: %w", err)
		}

		offset, err := readGGUF[uint64](llm, rs)
		if err != nil {
			return fmt.Errorf("failed to read tensor offset: %w", err)
		}

		tensor := Tensor{
			Name:   name,
			Kind:   kind,
			Offset: offset,
			Shape:  shape,
		}

		llm.tensors = append(llm.tensors, &tensor)
		llm.parameters += tensor.parameters()
	}

	// patch KV with parameter count
	llm.kv["general.parameter_count"] = llm.parameters

	alignment, ok := llm.kv["general.alignment"].(uint32)
	if !ok {
		alignment = 32
	} else if alignment == 0 {
		// ggufPadding takes offset % alignment; a zero alignment from a
		// malformed file would panic with a division by zero.
		return fmt.Errorf("invalid general.alignment: %d", alignment)
	}

	offset, err := rs.Seek(0, io.SeekCurrent)
	if err != nil {
		return err
	}

	padding := ggufPadding(offset, int64(alignment))
	llm.tensorOffset = uint64(offset + padding)

	// Skip over the tensor data: each tensor starts at the next aligned
	// position after the previous one.
	for _, tensor := range llm.tensors {
		offset, err := rs.Seek(0, io.SeekCurrent)
		if err != nil {
			return fmt.Errorf("failed to get current offset: %w", err)
		}

		padding := ggufPadding(offset, int64(alignment))
		if _, err := rs.Seek(padding, io.SeekCurrent); err != nil {
			return fmt.Errorf("failed to seek to init padding: %w", err)
		}

		if _, err := rs.Seek(int64(tensor.Size()), io.SeekCurrent); err != nil {
			return fmt.Errorf("failed to seek to tensor: %w", err)
		}
	}

	return nil
}
|
|
|
|
|
2024-03-28 18:54:01 -07:00
|
|
|
// readGGUF decodes one fixed-size value of type T from r using the model's
// byte order. On error the returned value is whatever binary.Read left in it.
func readGGUF[T any](llm *gguf, r io.Reader) (T, error) {
	var value T
	if err := binary.Read(r, llm.ByteOrder, &value); err != nil {
		return value, err
	}
	return value, nil
}
|
|
|
|
|
2024-05-31 20:00:49 -07:00
|
|
|
func writeGGUF[V any](w io.Writer, t uint32, v V) error {
|
|
|
|
if err := binary.Write(w, binary.LittleEndian, t); err != nil {
|
2024-03-28 18:54:01 -07:00
|
|
|
return err
|
|
|
|
}
|
2023-09-07 13:55:37 -04:00
|
|
|
|
2024-05-31 20:00:49 -07:00
|
|
|
return binary.Write(w, binary.LittleEndian, v)
|
2023-09-07 13:55:37 -04:00
|
|
|
}
|
|
|
|
|
2024-03-28 18:54:01 -07:00
|
|
|
// readGGUFV1String reads a GGUF version 1 string from r. The string is a
// uint64 length (in llm.ByteOrder) followed by that many bytes, the last of
// which is a NUL terminator that is stripped from the result.
//
// A zero-length string yields "". A length that does not fit in an int64
// is reported as an error.
func readGGUFV1String(llm *gguf, r io.Reader) (string, error) {
	var length uint64
	if err := binary.Read(r, llm.ByteOrder, &length); err != nil {
		return "", err
	}

	// io.CopyN takes an int64; an oversized length would turn negative and
	// silently copy nothing.
	if int64(length) < 0 {
		return "", fmt.Errorf("invalid string length: %d", length)
	}

	var b bytes.Buffer
	if _, err := io.CopyN(&b, r, int64(length)); err != nil {
		return "", err
	}

	// gguf v1 strings are null-terminated; drop the terminator. An empty
	// string has nothing to drop, and Truncate(-1) would panic.
	if b.Len() > 0 {
		b.Truncate(b.Len() - 1)
	}

	return b.String(), nil
}
|
|
|
|
|
2024-06-24 21:47:52 -07:00
|
|
|
// discardGGUFString consumes a GGUF string from r without keeping it. It
// reads the uint64 length prefix (in llm.ByteOrder) and then skips that many
// bytes, reading them in chunks through llm.scratch.
func discardGGUFString(llm *gguf, r io.Reader) error {
	buf := llm.scratch[:8]
	if _, err := io.ReadFull(r, buf); err != nil {
		return err
	}

	length := llm.ByteOrder.Uint64(buf)
	size := int(length)
	// Reject lengths that do not round-trip through int. Otherwise a huge
	// value turns negative, the loop below skips nothing, and the stream
	// silently loses its position.
	if size < 0 || uint64(size) != length {
		return fmt.Errorf("invalid string length: %d", length)
	}

	for size > 0 {
		// io.ReadFull, unlike a bare Read, succeeds when the final chunk is
		// returned together with io.EOF, which io.Reader permits.
		n, err := io.ReadFull(r, llm.scratch[:min(size, len(llm.scratch))])
		if err != nil {
			return err
		}
		size -= n
	}

	return nil
}
|
|
|
|
|
2024-03-28 18:54:01 -07:00
|
|
|
func readGGUFString(llm *gguf, r io.Reader) (string, error) {
|
2023-11-08 19:55:46 -06:00
|
|
|
if llm.Version == 1 {
|
2024-03-28 18:54:01 -07:00
|
|
|
return readGGUFV1String(llm, r)
|
2023-11-08 19:55:46 -06:00
|
|
|
}
|
|
|
|
|
2024-06-24 21:47:52 -07:00
|
|
|
buf := llm.scratch[:8]
|
|
|
|
_, err := io.ReadFull(r, buf)
|
|
|
|
if err != nil {
|
2024-03-28 18:54:01 -07:00
|
|
|
return "", err
|
|
|
|
}
|
2023-09-07 13:55:37 -04:00
|
|
|
|
2024-06-24 21:47:52 -07:00
|
|
|
length := int(llm.ByteOrder.Uint64(buf))
|
|
|
|
if length > len(llm.scratch) {
|
|
|
|
buf = make([]byte, length)
|
|
|
|
} else {
|
|
|
|
buf = llm.scratch[:length]
|
2023-09-07 13:55:37 -04:00
|
|
|
}
|
2024-06-24 21:47:52 -07:00
|
|
|
clear(buf)
|
2023-09-07 13:55:37 -04:00
|
|
|
|
2024-06-24 21:47:52 -07:00
|
|
|
_, err = io.ReadFull(r, buf)
|
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
return string(buf), nil
|
2023-09-07 13:55:37 -04:00
|
|
|
}
|
|
|
|
|
2024-05-31 20:00:49 -07:00
|
|
|
func writeGGUFString(w io.Writer, s string) error {
|
|
|
|
if err := binary.Write(w, binary.LittleEndian, ggufTypeString); err != nil {
|
2024-03-28 18:54:01 -07:00
|
|
|
return err
|
|
|
|
}
|
2023-09-07 13:55:37 -04:00
|
|
|
|
2024-05-31 20:00:49 -07:00
|
|
|
if err := binary.Write(w, binary.LittleEndian, uint64(len(s))); err != nil {
|
2024-03-28 18:54:01 -07:00
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
_, err := io.Copy(w, strings.NewReader(s))
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2024-06-24 21:47:52 -07:00
|
|
|
type array struct {
|
|
|
|
size int
|
|
|
|
values []any
|
|
|
|
}
|
|
|
|
|
|
|
|
func (a *array) MarshalJSON() ([]byte, error) {
|
|
|
|
return json.Marshal(a.values)
|
|
|
|
}
|
|
|
|
|
|
|
|
// readGGUFV1Array reads a GGUF version 1 metadata array from r. The array is
// a uint32 element type tag and a uint32 element count, followed by the
// elements. Element values are retained only when llm.canCollectArray allows
// it; otherwise they are read and dropped, and only the size is kept.
func readGGUFV1Array(llm *gguf, r io.Reader) (*array, error) {
	t, err := readGGUF[uint32](llm, r)
	if err != nil {
		return nil, err
	}

	n, err := readGGUF[uint32](llm, r)
	if err != nil {
		return nil, err
	}

	a := &array{size: int(n)}
	if llm.canCollectArray(int(n)) {
		// Allocate with length n, not just capacity: elements are assigned
		// by index below, and a zero-length slice would panic.
		a.values = make([]any, n)
	}

	for i := range n {
		var e any
		switch t {
		case ggufTypeUint8:
			e, err = readGGUF[uint8](llm, r)
		case ggufTypeInt8:
			e, err = readGGUF[int8](llm, r)
		case ggufTypeUint16:
			e, err = readGGUF[uint16](llm, r)
		case ggufTypeInt16:
			e, err = readGGUF[int16](llm, r)
		case ggufTypeUint32:
			e, err = readGGUF[uint32](llm, r)
		case ggufTypeInt32:
			e, err = readGGUF[int32](llm, r)
		case ggufTypeUint64:
			e, err = readGGUF[uint64](llm, r)
		case ggufTypeInt64:
			e, err = readGGUF[int64](llm, r)
		case ggufTypeFloat32:
			e, err = readGGUF[float32](llm, r)
		case ggufTypeFloat64:
			e, err = readGGUF[float64](llm, r)
		case ggufTypeBool:
			e, err = readGGUF[bool](llm, r)
		case ggufTypeString:
			e, err = readGGUFV1String(llm, r)
		default:
			return nil, fmt.Errorf("invalid array type: %d", t)
		}
		if err != nil {
			return nil, err
		}

		if a.values != nil {
			a.values[i] = e
		}
	}

	return a, nil
}
|
|
|
|
|
2024-06-24 21:47:52 -07:00
|
|
|
func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
|
2023-11-08 19:55:46 -06:00
|
|
|
if llm.Version == 1 {
|
2024-03-28 18:54:01 -07:00
|
|
|
return readGGUFV1Array(llm, r)
|
2023-11-08 19:55:46 -06:00
|
|
|
}
|
|
|
|
|
2024-03-28 18:54:01 -07:00
|
|
|
t, err := readGGUF[uint32](llm, r)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
2023-09-07 13:55:37 -04:00
|
|
|
|
2024-03-28 18:54:01 -07:00
|
|
|
n, err := readGGUF[uint64](llm, r)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
2023-09-07 13:55:37 -04:00
|
|
|
|
2024-06-24 21:47:52 -07:00
|
|
|
a := &array{size: int(n)}
|
|
|
|
if llm.canCollectArray(int(n)) {
|
|
|
|
a.values = make([]any, int(n))
|
|
|
|
}
|
|
|
|
|
|
|
|
for i := range n {
|
2024-03-28 18:54:01 -07:00
|
|
|
var e any
|
|
|
|
switch t {
|
|
|
|
case ggufTypeUint8:
|
|
|
|
e, err = readGGUF[uint8](llm, r)
|
|
|
|
case ggufTypeInt8:
|
|
|
|
e, err = readGGUF[int8](llm, r)
|
|
|
|
case ggufTypeUint16:
|
|
|
|
e, err = readGGUF[uint16](llm, r)
|
|
|
|
case ggufTypeInt16:
|
|
|
|
e, err = readGGUF[int16](llm, r)
|
|
|
|
case ggufTypeUint32:
|
|
|
|
e, err = readGGUF[uint32](llm, r)
|
|
|
|
case ggufTypeInt32:
|
|
|
|
e, err = readGGUF[int32](llm, r)
|
|
|
|
case ggufTypeUint64:
|
|
|
|
e, err = readGGUF[uint64](llm, r)
|
|
|
|
case ggufTypeInt64:
|
|
|
|
e, err = readGGUF[int64](llm, r)
|
|
|
|
case ggufTypeFloat32:
|
|
|
|
e, err = readGGUF[float32](llm, r)
|
|
|
|
case ggufTypeFloat64:
|
|
|
|
e, err = readGGUF[float64](llm, r)
|
|
|
|
case ggufTypeBool:
|
|
|
|
e, err = readGGUF[bool](llm, r)
|
|
|
|
case ggufTypeString:
|
2024-06-24 21:47:52 -07:00
|
|
|
if a.values != nil {
|
|
|
|
e, err = readGGUFString(llm, r)
|
|
|
|
} else {
|
|
|
|
err = discardGGUFString(llm, r)
|
|
|
|
}
|
2023-09-07 13:55:37 -04:00
|
|
|
default:
|
2024-03-28 18:54:01 -07:00
|
|
|
return nil, fmt.Errorf("invalid array type: %d", t)
|
2023-09-07 13:55:37 -04:00
|
|
|
}
|
2024-03-28 18:54:01 -07:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
2024-06-24 21:47:52 -07:00
|
|
|
if a.values != nil {
|
|
|
|
a.values[i] = e
|
|
|
|
}
|
2023-09-07 13:55:37 -04:00
|
|
|
}
|
|
|
|
|
2024-06-24 21:47:52 -07:00
|
|
|
return a, nil
|
2023-09-07 13:55:37 -04:00
|
|
|
}
|
2024-03-28 18:54:01 -07:00
|
|
|
|
2024-07-08 16:59:48 -07:00
|
|
|
// writeGGUFArray writes a slice s of type E to the write with a gguf type of t
|
2024-05-31 20:00:49 -07:00
|
|
|
func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
|
|
|
|
if err := binary.Write(w, binary.LittleEndian, ggufTypeArray); err != nil {
|
2024-03-28 18:54:01 -07:00
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2024-05-31 20:00:49 -07:00
|
|
|
if err := binary.Write(w, binary.LittleEndian, t); err != nil {
|
2024-03-28 18:54:01 -07:00
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2024-05-31 20:00:49 -07:00
|
|
|
if err := binary.Write(w, binary.LittleEndian, uint64(len(s))); err != nil {
|
2024-03-28 18:54:01 -07:00
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2024-07-08 16:59:48 -07:00
|
|
|
return binary.Write(w, binary.LittleEndian, s)
|
2024-03-28 18:54:01 -07:00
|
|
|
}
|
|
|
|
|
2024-07-08 16:59:48 -07:00
|
|
|
// WriteGGUF writes a GGUF version 3 file to ws, in this order:
//   - the "GGUF" magic, the version, and the tensor and key-value counts;
//   - the key-values, in sorted key order;
//   - the tensor infos;
//   - the tensor data, each tensor padded to the file alignment.
//
// ts is sorted in place. Tensors whose names start with "blk.N" are ordered
// by N; any comparison involving another name falls back to comparing names.
// The sort is stable, so tensors with equal keys (for example, several
// tensors of the same block) keep the order the caller supplied.
//
// The alignment comes from kv["general.alignment"] when that is a non-zero
// uint32, and is 32 otherwise, which is the same default gguf.Decode uses.
// Each recorded tensor offset is relative to the start of the data section
// and includes the padding inserted before that tensor.
func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
	if err := binary.Write(ws, binary.LittleEndian, []byte("GGUF")); err != nil {
		return err
	}

	if err := binary.Write(ws, binary.LittleEndian, uint32(3)); err != nil {
		return err
	}

	if err := binary.Write(ws, binary.LittleEndian, uint64(len(ts))); err != nil {
		return err
	}

	if err := binary.Write(ws, binary.LittleEndian, uint64(len(kv))); err != nil {
		return err
	}

	keys := maps.Keys(kv)
	slices.Sort(keys)

	for _, key := range keys {
		if err := ggufWriteKV(ws, key, kv[key]); err != nil {
			return err
		}
	}

	slices.SortStableFunc(ts, func(a, b Tensor) int {
		var i, j int
		if n, err := fmt.Sscanf(a.Name, "blk.%d", &i); err != nil || n != 1 {
			return cmp.Compare(a.Name, b.Name)
		} else if n, err := fmt.Sscanf(b.Name, "blk.%d", &j); err != nil || n != 1 {
			return cmp.Compare(a.Name, b.Name)
		}

		return cmp.Compare(i, j)
	})

	var alignment int64 = 32
	if a, ok := kv["general.alignment"].(uint32); ok && a > 0 {
		alignment = int64(a)
	}

	// ggufWriteTensor pads before every tensor. The data section itself
	// starts aligned, so tensor k lands at the aligned running sum of the
	// preceding sizes. The recorded offsets must include that padding, or
	// readers will locate tensor data at the wrong position.
	var s uint64
	for _, t := range ts {
		t.Offset = s
		if err := ggufWriteTensorInfo(ws, t); err != nil {
			return err
		}
		s += t.Size()
		s += uint64(ggufPadding(int64(s), alignment))
	}

	for _, t := range ts {
		if err := ggufWriteTensor(ws, t, alignment); err != nil {
			return err
		}
	}

	return nil
}
|
2024-03-28 18:54:01 -07:00
|
|
|
|
2024-05-31 20:00:49 -07:00
|
|
|
// ggufWriteKV writes one GGUF metadata entry to ws. The key is written as a
// little-endian uint64 length followed by its bytes, then the value's GGUF
// type tag and payload.
//
// Go scalar integer, float, and bool values map to the GGUF scalar type of
// the same width. string maps to a GGUF string. []int32, []uint32, []float32
// and []string map to GGUF arrays.
//
// Any other value type returns an error. The key has already been written by
// then, so the stream should be considered corrupt.
func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
	slog.Debug(k, "type", fmt.Sprintf("%T", v))
	if err := binary.Write(ws, binary.LittleEndian, uint64(len(k))); err != nil {
		return err
	}

	if err := binary.Write(ws, binary.LittleEndian, []byte(k)); err != nil {
		return err
	}

	var err error
	switch v := v.(type) {
	case uint8:
		err = writeGGUF(ws, ggufTypeUint8, v)
	case int8:
		err = writeGGUF(ws, ggufTypeInt8, v)
	case uint16:
		err = writeGGUF(ws, ggufTypeUint16, v)
	case int16:
		err = writeGGUF(ws, ggufTypeInt16, v)
	case uint32:
		err = writeGGUF(ws, ggufTypeUint32, v)
	case int32:
		err = writeGGUF(ws, ggufTypeInt32, v)
	case uint64:
		// e.g. the "general.parameter_count" value that gguf.Decode inserts.
		err = writeGGUF(ws, ggufTypeUint64, v)
	case int64:
		err = writeGGUF(ws, ggufTypeInt64, v)
	case float32:
		err = writeGGUF(ws, ggufTypeFloat32, v)
	case float64:
		err = writeGGUF(ws, ggufTypeFloat64, v)
	case bool:
		err = writeGGUF(ws, ggufTypeBool, v)
	case string:
		err = writeGGUFString(ws, v)
	case []int32:
		err = writeGGUFArray(ws, ggufTypeInt32, v)
	case []uint32:
		err = writeGGUFArray(ws, ggufTypeUint32, v)
	case []float32:
		err = writeGGUFArray(ws, ggufTypeFloat32, v)
	case []string:
		// Array header: array tag, element type (string), element count.
		if err := binary.Write(ws, binary.LittleEndian, ggufTypeArray); err != nil {
			return err
		}

		if err := binary.Write(ws, binary.LittleEndian, ggufTypeString); err != nil {
			return err
		}

		if err := binary.Write(ws, binary.LittleEndian, uint64(len(v))); err != nil {
			return err
		}

		// Each element is a length-prefixed string without its own type tag.
		for _, e := range v {
			if err := binary.Write(ws, binary.LittleEndian, uint64(len(e))); err != nil {
				return err
			}

			if err := binary.Write(ws, binary.LittleEndian, []byte(e)); err != nil {
				return err
			}
		}
	default:
		return fmt.Errorf("improper type for '%s'", k)
	}

	return err
}
|
|
|
|
|
2024-07-08 16:59:48 -07:00
|
|
|
func ggufWriteTensorInfo(ws io.WriteSeeker, t Tensor) error {
|
2024-05-31 20:00:49 -07:00
|
|
|
slog.Debug(t.Name, "kind", t.Kind, "shape", t.Shape, "offset", t.Offset)
|
|
|
|
if err := binary.Write(ws, binary.LittleEndian, uint64(len(t.Name))); err != nil {
|
|
|
|
return err
|
2024-03-28 18:54:01 -07:00
|
|
|
}
|
|
|
|
|
2024-05-31 20:00:49 -07:00
|
|
|
if err := binary.Write(ws, binary.LittleEndian, []byte(t.Name)); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2024-03-28 18:54:01 -07:00
|
|
|
|
2024-05-31 20:00:49 -07:00
|
|
|
if err := binary.Write(ws, binary.LittleEndian, uint32(len(t.Shape))); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2024-06-07 08:55:46 -07:00
|
|
|
|
2024-05-31 20:00:49 -07:00
|
|
|
for i := range len(t.Shape) {
|
|
|
|
if err := binary.Write(ws, binary.LittleEndian, t.Shape[len(t.Shape)-i-1]); err != nil {
|
2024-06-07 08:55:46 -07:00
|
|
|
return err
|
|
|
|
}
|
2024-03-28 18:54:01 -07:00
|
|
|
}
|
|
|
|
|
2024-05-31 20:00:49 -07:00
|
|
|
if err := binary.Write(ws, binary.LittleEndian, t.Kind); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
return binary.Write(ws, binary.LittleEndian, t.Offset)
|
|
|
|
}
|
|
|
|
|
2024-07-08 16:59:48 -07:00
|
|
|
// ggufWriteTensor zero-pads ws up to the next multiple of alignment, then
// writes t's data with t.WriteTo. The padding is measured from ws's current
// absolute position.
func ggufWriteTensor(ws io.WriteSeeker, t Tensor, alignment int64) error {
	pos, err := ws.Seek(0, io.SeekCurrent)
	if err != nil {
		return err
	}

	zeros := make([]byte, int(ggufPadding(pos, alignment)))
	if err := binary.Write(ws, binary.LittleEndian, zeros); err != nil {
		return err
	}

	_, err = t.WriteTo(ws)
	return err
}
|
|
|
|
|
2024-05-31 20:00:49 -07:00
|
|
|
// ggufPadding returns how many bytes must be added to offset to reach the
// next multiple of align, or 0 when offset is already aligned. align must be
// non-zero; a zero align panics with an integer division by zero.
func ggufPadding(offset, align int64) int64 {
	rem := offset % align
	if rem == 0 {
		return 0
	}
	return (align - rem) % align
}
|