diff --git a/llm/ggla.go b/llm/ggla.go index e22dd59f..21a386f8 100644 --- a/llm/ggla.go +++ b/llm/ggla.go @@ -15,8 +15,8 @@ func (c *ContainerGGLA) Name() string { return "ggla" } -func (c *ContainerGGLA) Decode(rso *readSeekOffset) (model, error) { - binary.Read(rso, binary.LittleEndian, &c.version) +func (c *ContainerGGLA) Decode(rs io.ReadSeeker) (model, error) { + binary.Read(rs, binary.LittleEndian, &c.version) switch c.version { case 1: @@ -25,7 +25,7 @@ func (c *ContainerGGLA) Decode(rso *readSeekOffset) (model, error) { } model := newModelGGLA(c) - err := model.decode(rso) + err := model.decode(rs) return model, err } @@ -43,39 +43,39 @@ func newModelGGLA(container *ContainerGGLA) *ModelGGLA { } } -func (m *ModelGGLA) decode(rso *readSeekOffset) error { +func (m *ModelGGLA) decode(rs io.ReadSeeker) error { var r uint32 - if err := binary.Read(rso, binary.LittleEndian, &r); err != nil { + if err := binary.Read(rs, binary.LittleEndian, &r); err != nil { return err } m.kv["r"] = r var alpha uint32 - if err := binary.Read(rso, binary.LittleEndian, &alpha); err != nil { + if err := binary.Read(rs, binary.LittleEndian, &alpha); err != nil { return err } m.kv["alpha"] = alpha for { var dims uint32 - if err := binary.Read(rso, binary.LittleEndian, &dims); err != nil { + if err := binary.Read(rs, binary.LittleEndian, &dims); err != nil { return err } var namesize uint32 - if err := binary.Read(rso, binary.LittleEndian, &namesize); err != nil { + if err := binary.Read(rs, binary.LittleEndian, &namesize); err != nil { return err } var t Tensor - if err := binary.Read(rso, binary.LittleEndian, &t.Kind); err != nil { + if err := binary.Read(rs, binary.LittleEndian, &t.Kind); err != nil { return err } t.Shape = make([]uint64, dims) for i := 0; uint32(i) < dims; i++ { var shape32 uint32 - if err := binary.Read(rso, binary.LittleEndian, &shape32); err != nil { + if err := binary.Read(rs, binary.LittleEndian, &shape32); err != nil { return err } @@ -87,19 +87,29 @@ func (m *ModelGGLA) decode(rso *readSeekOffset) error { slices.Reverse(t.Shape) name := make([]byte, namesize) - if err := binary.Read(rso, binary.LittleEndian, &name); err != nil { + if err := binary.Read(rs, binary.LittleEndian, &name); err != nil { return err } t.Name = string(name) - if _, err := rso.Seek((rso.offset+31)&-32, io.SeekStart); err != nil { + offset, err := rs.Seek(0, io.SeekCurrent) + if err != nil { return err } - t.Offset = uint64(rso.offset) + if _, err := rs.Seek((offset+31)&-32, io.SeekStart); err != nil { + return err + } - if _, err := rso.Seek(int64(t.Size()), io.SeekCurrent); err != nil { + offset, err = rs.Seek(0, io.SeekCurrent) + if err != nil { + return err + } + + t.Offset = uint64(offset) + + if _, err := rs.Seek(int64(t.Size()), io.SeekCurrent); err != nil { return err } diff --git a/llm/ggml.go b/llm/ggml.go index 88cd9e13..b7f29768 100644 --- a/llm/ggml.go +++ b/llm/ggml.go @@ -103,7 +103,7 @@ type model interface { type container interface { Name() string - Decode(*readSeekOffset) (model, error) + Decode(io.ReadSeeker) (model, error) } const ( @@ -122,11 +122,9 @@ const ( var ErrUnsupportedFormat = errors.New("unsupported model format") -func DecodeGGML(r io.ReadSeeker) (*GGML, error) { - ro := readSeekOffset{ReadSeeker: r} - +func DecodeGGML(rs io.ReadSeeker) (*GGML, error) { var magic uint32 - if err := binary.Read(&ro, binary.LittleEndian, &magic); err != nil { + if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil { return nil, err } @@ -144,38 +142,22 @@ func DecodeGGML(r io.ReadSeeker) (*GGML, error) { return nil, errors.New("invalid file magic") } - model, err := c.Decode(&ro) + model, err := c.Decode(rs) if errors.Is(err, io.EOF) { // noop } else if err != nil { return nil, err } + offset, err := rs.Seek(0, io.SeekCurrent) + if err != nil { + return nil, err + } + // final model type return &GGML{ container: c, model: model, - Size: ro.offset, + Size: offset, }, nil } - -type readSeekOffset struct { - io.ReadSeeker - offset int64 -} - -func (rso *readSeekOffset) Seek(offset int64, whence int) (int64, error) { - offset, err := rso.ReadSeeker.Seek(offset, whence) - if err != nil { - return 0, err - } - - rso.offset = offset - return offset, nil -} - -func (rso *readSeekOffset) Read(p []byte) (int, error) { - n, err := rso.ReadSeeker.Read(p) - rso.offset += int64(n) - return n, err -} diff --git a/llm/gguf.go b/llm/gguf.go index 61c55148..8c983095 100644 --- a/llm/gguf.go +++ b/llm/gguf.go @@ -42,18 +42,18 @@ func (c *ContainerGGUF) Name() string { return "gguf" } -func (c *ContainerGGUF) Decode(rso *readSeekOffset) (model, error) { - binary.Read(rso, c.ByteOrder, &c.Version) +func (c *ContainerGGUF) Decode(rs io.ReadSeeker) (model, error) { + binary.Read(rs, c.ByteOrder, &c.Version) switch c.Version { case 1: - binary.Read(rso, c.ByteOrder, &c.V1) + binary.Read(rs, c.ByteOrder, &c.V1) default: - binary.Read(rso, c.ByteOrder, &c.V2) + binary.Read(rs, c.ByteOrder, &c.V2) } model := NewGGUFModel(c) - if err := model.Decode(rso); err != nil { + if err := model.Decode(rs); err != nil { return nil, err } @@ -633,49 +633,49 @@ func (llm *GGUFModel) writeString(f *os.File, s string) error { return nil } -func (llm *GGUFModel) Decode(rso *readSeekOffset) error { +func (llm *GGUFModel) Decode(rs io.ReadSeeker) error { // decode key-values for i := 0; uint64(i) < llm.NumKV(); i++ { - k, err := llm.readString(rso) + k, err := llm.readString(rs) if err != nil { return err } - vtype := llm.readU32(rso) + vtype := llm.readU32(rs) var v any switch vtype { case GGUFTypeUint8: - v = llm.readU8(rso) + v = llm.readU8(rs) case GGUFTypeInt8: - v = llm.readI8(rso) + v = llm.readI8(rs) case GGUFTypeUint16: - v = llm.readU16(rso) + v = llm.readU16(rs) case GGUFTypeInt16: - v = llm.readI16(rso) + v = llm.readI16(rs) case GGUFTypeUint32: - v = llm.readU32(rso) + v = llm.readU32(rs) case GGUFTypeInt32: - v = llm.readI32(rso) + v = llm.readI32(rs) case GGUFTypeUint64: - v = llm.readU64(rso) + v = llm.readU64(rs) case GGUFTypeInt64: - v = llm.readI64(rso) + v = llm.readI64(rs) case GGUFTypeFloat32: - v = llm.readF32(rso) + v = llm.readF32(rs) case GGUFTypeFloat64: - v = llm.readF64(rso) + v = llm.readF64(rs) case GGUFTypeBool: - v = llm.readBool(rso) + v = llm.readBool(rs) case GGUFTypeString: - s, err := llm.readString(rso) + s, err := llm.readString(rs) if err != nil { return err } v = s case GGUFTypeArray: - a, err := llm.readArray(rso) + a, err := llm.readArray(rs) if err != nil { return err } @@ -690,23 +690,23 @@ func (llm *GGUFModel) Decode(rso *readSeekOffset) error { // decode tensors for i := 0; uint64(i) < llm.NumTensor(); i++ { - name, err := llm.readString(rso) + name, err := llm.readString(rs) if err != nil { return err } // dims is the number of dimensions in the tensor - dims := llm.readU32(rso) + dims := llm.readU32(rs) shape := [4]uint64{1, 1, 1, 1} for i := 0; uint32(i) < dims; i++ { - shape[i] = llm.readU64(rso) + shape[i] = llm.readU64(rs) } tensor := Tensor{ Name: name, - Kind: llm.readU32(rso), - Offset: llm.readU64(rso), + Kind: llm.readU32(rs), + Offset: llm.readU64(rs), Shape: shape[:], } @@ -719,10 +719,20 @@ func (llm *GGUFModel) Decode(rso *readSeekOffset) error { alignment = 32 } - rso.Seek(int64(alignment)-rso.offset%int64(alignment), io.SeekCurrent) + offset, err := rs.Seek(0, io.SeekCurrent) + if err != nil { + return err + } + + if _, err := rs.Seek(int64(alignment)-offset%int64(alignment), io.SeekCurrent); err != nil { + return err + } + for _, tensor := range llm.Tensors { padded := (int64(tensor.Size()) + int64(alignment) - 1) & ^(int64(alignment) - 1) - rso.Seek(padded, io.SeekCurrent) + if _, err := rs.Seek(padded, io.SeekCurrent); err != nil { + return err + } } return nil