151 lines
3.1 KiB
Go
151 lines
3.1 KiB
Go
package convert
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"io/fs"
|
|
"slices"
|
|
"strings"
|
|
|
|
"github.com/d4l3k/go-bfloat16"
|
|
"github.com/x448/float16"
|
|
"golang.org/x/exp/maps"
|
|
)
|
|
|
|
// safetensorMetadata is a single per-tensor entry of a safetensors JSON header.
type safetensorMetadata struct {
	// Type is the on-disk element type string (e.g. "F32", "F16", "BF16").
	Type string `json:"dtype"`
	// Shape lists the tensor's dimensions.
	Shape []uint64 `json:"shape"`
	// Offsets holds the [begin, end) byte range of the tensor's data, relative
	// to the start of the data section that follows the header.
	Offsets []int64 `json:"data_offsets"`
}
|
|
|
|
func parseSafetensors(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]Tensor, error) {
|
|
var ts []Tensor
|
|
for _, p := range ps {
|
|
f, err := fsys.Open(p)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer f.Close()
|
|
|
|
var n int64
|
|
if err := binary.Read(f, binary.LittleEndian, &n); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
b := bytes.NewBuffer(make([]byte, 0, n))
|
|
if _, err = io.CopyN(b, f, n); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var headers map[string]safetensorMetadata
|
|
if err := json.NewDecoder(b).Decode(&headers); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
keys := maps.Keys(headers)
|
|
slices.Sort(keys)
|
|
|
|
for _, key := range keys {
|
|
if value := headers[key]; value.Type != "" {
|
|
ts = append(ts, safetensor{
|
|
fs: fsys,
|
|
path: p,
|
|
dtype: value.Type,
|
|
offset: safetensorsPad(n, value.Offsets[0]),
|
|
size: safetensorsPad(n, value.Offsets[1]) - safetensorsPad(n, value.Offsets[0]),
|
|
tensorBase: &tensorBase{
|
|
name: replacer.Replace(key),
|
|
shape: value.Shape,
|
|
},
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
return ts, nil
|
|
}
|
|
|
|
// safetensorsPad converts a tensor data offset into an absolute position
// within a safetensors file.
//
// n is the length in bytes of the file's JSON header. offset is one of a
// tensor's data_offsets values, which are relative to the start of the data
// section. The data section begins after the 8-byte header-length prefix and
// the n-byte header, so the absolute position is 8 + n + offset.
func safetensorsPad(n, offset int64) int64 {
	return 8 + n + offset
}
|
|
|
|
type safetensor struct {
|
|
fs fs.FS
|
|
path string
|
|
dtype string
|
|
offset int64
|
|
size int64
|
|
*tensorBase
|
|
}
|
|
|
|
func (st safetensor) WriteTo(w io.Writer) (int64, error) {
|
|
f, err := st.fs.Open(st.path)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
defer f.Close()
|
|
|
|
if seeker, ok := f.(io.Seeker); ok {
|
|
if _, err := seeker.Seek(st.offset, io.SeekStart); err != nil {
|
|
return 0, err
|
|
}
|
|
} else {
|
|
if _, err := io.CopyN(io.Discard, f, st.offset); err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
|
|
var f32s []float32
|
|
switch st.dtype {
|
|
case "F32":
|
|
f32s = make([]float32, st.size/4)
|
|
if err = binary.Read(f, binary.LittleEndian, f32s); err != nil {
|
|
return 0, err
|
|
}
|
|
case "F16":
|
|
u16s := make([]uint16, st.size/2)
|
|
if err = binary.Read(f, binary.LittleEndian, u16s); err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
f32s = make([]float32, len(u16s))
|
|
for i := range u16s {
|
|
f32s[i] = float16.Frombits(u16s[i]).Float32()
|
|
}
|
|
|
|
case "BF16":
|
|
u8s := make([]uint8, st.size)
|
|
if err = binary.Read(f, binary.LittleEndian, u8s); err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
f32s = bfloat16.DecodeFloat32(u8s)
|
|
default:
|
|
return 0, fmt.Errorf("unknown data type: %s", st.dtype)
|
|
}
|
|
|
|
if st.repacker != nil {
|
|
f32s, err = st.repacker(st.Name(), f32s, st.Shape())
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
|
|
switch st.Kind() {
|
|
case tensorKindF32:
|
|
return 0, binary.Write(w, binary.LittleEndian, f32s)
|
|
case tensorKindF16:
|
|
f16s := make([]uint16, len(f32s))
|
|
for i := range f32s {
|
|
f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
|
|
}
|
|
|
|
return 0, binary.Write(w, binary.LittleEndian, f16s)
|
|
default:
|
|
return 0, fmt.Errorf("unknown storage type: %d", st.Kind())
|
|
}
|
|
}
|