simplify safetensors reading

This commit is contained in:
Michael Yang 2024-05-20 09:47:01 -07:00
parent 3591bbe56f
commit 171eb040fc
6 changed files with 49 additions and 81 deletions

View file

@ -6,7 +6,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log/slog"
"os" "os"
"path/filepath" "path/filepath"
"regexp" "regexp"
@ -14,7 +13,6 @@ import (
"strings" "strings"
"github.com/d4l3k/go-bfloat16" "github.com/d4l3k/go-bfloat16"
"github.com/mitchellh/mapstructure"
"github.com/x448/float16" "github.com/x448/float16"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
@ -29,38 +27,36 @@ type safetensorWriterTo struct {
filename string filename string
dtype string dtype string
start, end, padding uint64 offset, size int64
repacker func(string, []float32, []uint64) ([]float32, error) repacker func(string, []float32, []uint64) ([]float32, error)
} }
type tensorMetaData struct { type safetensorMetadata struct {
Type string `mapstructure:"dtype"` Type string `json:"dtype"`
Shape []int `mapstructure:"shape"` Shape []uint64 `json:"shape"`
Offsets []int `mapstructure:"data_offsets"` Offsets []int64 `json:"data_offsets"`
} }
type SafetensorFormat struct{} type SafetensorFormat struct{}
func (m *SafetensorFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor, error) { func (m *SafetensorFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor, error) {
slog.Debug("getting tensor data")
var tensors []llm.Tensor var tensors []llm.Tensor
files, err := filepath.Glob(filepath.Join(dirpath, "/model-*.safetensors")) matches, err := filepath.Glob(filepath.Join(dirpath, "*.safetensors"))
if err != nil { if err != nil {
return nil, err return nil, err
} }
var offset uint64 var offset uint64
for _, f := range files { for _, f := range matches {
var t []llm.Tensor var t []llm.Tensor
var err error var err error
t, offset, err = m.readTensors(f, offset, params) t, offset, err = m.readTensors(f, offset, params)
if err != nil { if err != nil {
slog.Error(err.Error())
return nil, err return nil, err
} }
tensors = append(tensors, t...) tensors = append(tensors, t...)
} }
slog.Debug(fmt.Sprintf("all tensors = %d", len(tensors)))
return tensors, nil return tensors, nil
} }
@ -71,76 +67,57 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
} }
defer f.Close() defer f.Close()
var jsonSize uint64 var n int64
if err := binary.Read(f, binary.LittleEndian, &jsonSize); err != nil { if err := binary.Read(f, binary.LittleEndian, &n); err != nil {
return nil, 0, err return nil, 0, err
} }
buf := make([]byte, jsonSize) b := bytes.NewBuffer(make([]byte, 0, n))
_, err = io.ReadFull(f, buf) if _, err = io.CopyN(b, f, n); err != nil {
if err != nil {
return nil, 0, err return nil, 0, err
} }
d := json.NewDecoder(bytes.NewBuffer(buf)) var headers map[string]safetensorMetadata
d.UseNumber() if err := json.NewDecoder(b).Decode(&headers); err != nil {
var parsed map[string]interface{}
if err = d.Decode(&parsed); err != nil {
return nil, 0, err return nil, 0, err
} }
var keys []string var keys []string
for k := range parsed { for key := range headers {
keys = append(keys, k) if !strings.HasSuffix(key, "self_attn.rotary_embd.inv_freq") {
keys = append(keys, key)
}
} }
slices.Sort(keys) slices.Sort(keys)
slog.Info("converting layers")
var tensors []llm.Tensor var tensors []llm.Tensor
for _, k := range keys { for _, key := range keys {
if strings.HasSuffix(k, "self_attn.rotary_emb.inv_freq") { value := headers[key]
continue
}
vals := parsed[k].(map[string]interface{})
var data tensorMetaData
if err = mapstructure.Decode(vals, &data); err != nil {
slog.Error("couldn't decode properly")
return nil, 0, err
}
var size uint64
var kind uint32 var kind uint32
switch len(data.Shape) { switch len(value.Shape) {
case 0: case 0:
// metadata // valuedata
continue continue
case 1:
// convert to float32
kind = 0
size = uint64(data.Shape[0] * 4)
case 2: case 2:
// convert to float16
kind = 1 kind = 1
size = uint64(data.Shape[0] * data.Shape[1] * 2)
} }
ggufName, err := m.GetLayerName(k) name, err := m.GetLayerName(key)
if err != nil { if err != nil {
slog.Error(err.Error())
return nil, 0, err return nil, 0, err
} }
shape := []uint64{0, 0, 0, 0} shape := make([]uint64, len(value.Shape))
for i := range data.Shape { copy(shape, value.Shape)
shape[i] = uint64(data.Shape[i])
}
slog.Debug(fmt.Sprintf("'%45s': '%30s' %10d [%#v]", k, ggufName, size, data.Shape)) pad := func(s int64) int64 {
return 8 + n + s
}
t := llm.Tensor{ t := llm.Tensor{
Name: ggufName, Name: name,
Kind: kind, Kind: kind,
Offset: offset, Offset: offset,
Shape: shape[:], Shape: shape[:],
@ -151,19 +128,15 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
params: params, params: params,
bo: params.ByteOrder, bo: params.ByteOrder,
filename: fn, filename: fn,
dtype: data.Type, dtype: value.Type,
start: uint64(data.Offsets[0]), offset: pad(value.Offsets[0]),
end: uint64(data.Offsets[1]), size: pad(value.Offsets[1]) - pad(value.Offsets[0]),
padding: 8 + jsonSize,
} }
offset += size offset += t.Size()
tensors = append(tensors, t) tensors = append(tensors, t)
} }
slog.Debug(fmt.Sprintf("total tensors for file = %d", len(tensors)))
slog.Debug(fmt.Sprintf("offset = %d", offset))
return tensors, offset, nil return tensors, offset, nil
} }
@ -176,9 +149,7 @@ func (m *SafetensorFormat) GetParams(dirpath string) (*Params, error) {
var params Params var params Params
d := json.NewDecoder(f) if err := json.NewDecoder(f).Decode(&params); err != nil {
err = d.Decode(&params)
if err != nil {
return nil, err return nil, err
} }
@ -233,34 +204,34 @@ func (r safetensorWriterTo) WriteTo(w io.Writer) (n int64, err error) {
} }
defer f.Close() defer f.Close()
if _, err = f.Seek(int64(r.padding+r.start), 0); err != nil { if _, err = f.Seek(r.offset, io.SeekStart); err != nil {
return 0, err return 0, err
} }
var f32s []float32 var f32s []float32
switch r.dtype { switch r.dtype {
case "F32": case "F32":
f32s = make([]float32, (r.end-r.start)/4) f32s = make([]float32, r.size/4)
if err = binary.Read(f, r.bo, f32s); err != nil { if err = binary.Read(f, r.bo, f32s); err != nil {
return 0, err return 0, err
} }
case "F16": case "F16":
bts := make([]uint16, (r.end-r.start)/2) u16s := make([]uint16, r.size/2)
if err = binary.Read(f, r.bo, bts); err != nil { if err = binary.Read(f, r.bo, u16s); err != nil {
return 0, err return 0, err
} }
for _, b := range bts { for _, b := range u16s {
f32s = append(f32s, float16.Frombits(b).Float32()) f32s = append(f32s, float16.Frombits(b).Float32())
} }
case "BF16": case "BF16":
bts := make([]byte, r.end-r.start) u8s := make([]uint8, r.size)
if err = binary.Read(f, r.bo, bts); err != nil { if err = binary.Read(f, r.bo, u8s); err != nil {
return 0, err return 0, err
} }
f32s = bfloat16.DecodeFloat32(bts) f32s = bfloat16.DecodeFloat32(u8s)
default: default:
return 0, fmt.Errorf("unknown data type: %s", r.dtype) return 0, fmt.Errorf("unknown data type: %s", r.dtype)
} }

1
go.mod
View file

@ -8,7 +8,6 @@ require (
github.com/gin-gonic/gin v1.10.0 github.com/gin-gonic/gin v1.10.0
github.com/golang/protobuf v1.5.4 // indirect github.com/golang/protobuf v1.5.4 // indirect
github.com/google/uuid v1.1.2 github.com/google/uuid v1.1.2
github.com/mitchellh/mapstructure v1.5.0
github.com/olekukonko/tablewriter v0.0.5 github.com/olekukonko/tablewriter v0.0.5
github.com/spf13/cobra v1.7.0 github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0

2
go.sum
View file

@ -135,8 +135,6 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU= github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=

View file

@ -119,7 +119,7 @@ func (llm *ggla) decode(rs io.ReadSeeker) error {
t.Offset = uint64(offset) t.Offset = uint64(offset)
if _, err := rs.Seek(int64(t.size()), io.SeekCurrent); err != nil { if _, err := rs.Seek(int64(t.Size()), io.SeekCurrent); err != nil {
return err return err
} }

View file

@ -106,7 +106,7 @@ type Layer map[string]*Tensor
func (l Layer) size() (size uint64) { func (l Layer) size() (size uint64) {
for _, t := range l { for _, t := range l {
size += t.size() size += t.Size()
} }
return size return size
@ -185,7 +185,7 @@ func (t Tensor) parameters() uint64 {
return count return count
} }
func (t Tensor) size() uint64 { func (t Tensor) Size() uint64 {
return t.parameters() * t.typeSize() / t.blockSize() return t.parameters() * t.typeSize() / t.blockSize()
} }
@ -288,7 +288,7 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
// mixtral 8x22b // mixtral 8x22b
ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32)) ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32))
partialOffload = max( partialOffload = max(
3*ffnGateExpsWeight.size()+4*batch*(2*ff+headsKV+embedding+context+embedding/heads*headsKV), 3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embedding/heads*headsKV),
4*(context*batch*heads+context*embedding/heads*headsKV+batch*1024+embedding/heads*headsKV*batch), 4*(context*batch*heads+context*embedding/heads*headsKV+batch*1024+embedding/heads*headsKV*batch),
) )
} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok { } else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {

View file

@ -241,11 +241,11 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
} }
for _, tensor := range llm.tensors { for _, tensor := range llm.tensors {
if _, err := rs.Seek(int64(tensor.size()), io.SeekCurrent); err != nil { if _, err := rs.Seek(int64(tensor.Size()), io.SeekCurrent); err != nil {
return err return err
} }
padding := llm.padding(int64(tensor.size()), int64(alignment)) padding := llm.padding(int64(tensor.Size()), int64(alignment))
if _, err := rs.Seek(padding, io.SeekCurrent); err != nil { if _, err := rs.Seek(padding, io.SeekCurrent); err != nil {
return err return err
} }