fix conversion for f16 or f32 inputs

This commit is contained in:
Michael Yang 2024-05-17 12:11:49 -07:00
parent bbbd9f20f3
commit 34d5ef29b3
7 changed files with 152 additions and 294 deletions

View file

@ -1,14 +1,11 @@
package convert package convert
import ( import (
"encoding/binary"
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"os"
"strings" "strings"
"github.com/d4l3k/go-bfloat16"
"github.com/pdevine/tensor" "github.com/pdevine/tensor"
"github.com/pdevine/tensor/native" "github.com/pdevine/tensor/native"
@ -19,49 +16,27 @@ type GemmaModel struct {
ModelData ModelData
} }
func gemmaLayerHandler(w io.Writer, r safetensorWriterTo, f *os.File) error {
slog.Debug(fmt.Sprintf("converting '%s'", r.t.Name))
data := make([]byte, r.end-r.start)
if err := binary.Read(f, r.bo, data); err != nil {
return err
}
tDataF32 := bfloat16.DecodeFloat32(data)
var err error
tDataF32, err = addOnes(tDataF32, int(r.t.Shape[0]))
if err != nil {
return err
}
if err := binary.Write(w, r.bo, tDataF32); err != nil {
return err
}
return nil
}
func addOnes(data []float32, vectorSize int) ([]float32, error) { func addOnes(data []float32, vectorSize int) ([]float32, error) {
n := tensor.New(tensor.WithShape(vectorSize), tensor.WithBacking(data)) n := tensor.New(tensor.WithShape(vectorSize), tensor.WithBacking(data))
ones := tensor.Ones(tensor.Float32, vectorSize) ones := tensor.Ones(tensor.Float32, vectorSize)
var err error n, err := n.Add(ones)
n, err = n.Add(ones)
if err != nil { if err != nil {
return []float32{}, err return nil, err
} }
newN, err := native.SelectF32(n, 0) ts, err := native.SelectF32(n, 0)
if err != nil { if err != nil {
return []float32{}, err return nil, err
} }
var fullTensor []float32 var f32s []float32
for _, v := range newN { for _, t := range ts {
fullTensor = append(fullTensor, v...) f32s = append(f32s, t...)
} }
return fullTensor, nil
return f32s, nil
} }
func (m *GemmaModel) GetTensors() error { func (m *GemmaModel) GetTensors() error {
@ -74,7 +49,7 @@ func (m *GemmaModel) GetTensors() error {
for _, l := range t { for _, l := range t {
if strings.HasSuffix(l.Name, "norm.weight") { if strings.HasSuffix(l.Name, "norm.weight") {
wt := l.WriterTo.(safetensorWriterTo) wt := l.WriterTo.(safetensorWriterTo)
wt.handler = gemmaLayerHandler wt.repacker = m.Repack
l.WriterTo = wt l.WriterTo = wt
} }
m.Tensors = append(m.Tensors, l) m.Tensors = append(m.Tensors, l)
@ -92,6 +67,10 @@ func (m *GemmaModel) LoadVocab() error {
return nil return nil
} }
func (m *GemmaModel) Repack(_ string, data []float32, shape []uint64) ([]float32, error) {
return addOnes(data, int(shape[0]))
}
func (m *GemmaModel) WriteGGUF(ws io.WriteSeeker) error { func (m *GemmaModel) WriteGGUF(ws io.WriteSeeker) error {
kv := llm.KV{ kv := llm.KV{
"general.architecture": "gemma", "general.architecture": "gemma",

View file

@ -1,7 +1,7 @@
package convert package convert
import ( import (
"encoding/binary" "cmp"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -10,10 +10,8 @@ import (
"regexp" "regexp"
"strings" "strings"
"github.com/nlpodyssey/gopickle/pytorch"
"github.com/pdevine/tensor" "github.com/pdevine/tensor"
"github.com/pdevine/tensor/native" "github.com/pdevine/tensor/native"
"github.com/x448/float16"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
) )
@ -22,83 +20,6 @@ type LlamaModel struct {
ModelData ModelData
} }
func llamaTorchLayerHandler(w io.Writer, r torchWriterTo) error {
var tData []uint16
switch r.storage.(type) {
case *pytorch.HalfStorage:
data := r.storage.(*pytorch.HalfStorage).Data
tData = make([]uint16, len(data))
for cnt, v := range data {
tData[cnt] = uint16(float16.Fromfloat32(v))
}
case *pytorch.BFloat16Storage:
data := r.storage.(*pytorch.BFloat16Storage).Data
tData = make([]uint16, len(data))
for cnt, v := range data {
tData[cnt] = uint16(float16.Fromfloat32(v))
}
default:
return fmt.Errorf("unknown storage type for torch")
}
var err error
var heads uint32
if strings.Contains(r.t.Name, "attn_q") {
heads = uint32(r.params.AttentionHeads)
} else if strings.Contains(r.t.Name, "attn_k") {
heads = uint32(r.params.KeyValHeads)
if heads == 0 {
heads = uint32(r.params.AttentionHeads)
}
} else {
return fmt.Errorf("unknown layer type")
}
tData, err = llamaRepack(tData, int(heads), r.t.Shape)
if err != nil {
return err
}
if err = binary.Write(w, r.bo, tData); err != nil {
return err
}
return nil
}
func llamaRepack(data []uint16, heads int, shape []uint64) ([]uint16, error) {
n := tensor.New(tensor.WithShape(int(shape[0]), int(shape[1])), tensor.WithBacking(data))
origShape := n.Shape().Clone()
// reshape the tensor and swap axes 1 and 2 to unpack the layer for gguf
if err := n.Reshape(heads, 2, origShape[0]/heads/2, origShape[1]); err != nil {
return nil, err
}
if err := n.T(0, 2, 1, 3); err != nil {
return nil, err
}
if err := n.Reshape(origShape...); err != nil {
return nil, err
}
if err := n.Transpose(); err != nil {
return nil, err
}
newN, err := native.SelectU16(n, 1)
if err != nil {
return nil, err
}
var fullTensor []uint16
for _, v := range newN {
fullTensor = append(fullTensor, v...)
}
return fullTensor, nil
}
func (m *LlamaModel) GetTensors() error { func (m *LlamaModel) GetTensors() error {
t, err := m.Format.GetTensors(m.Path, m.Params) t, err := m.Format.GetTensors(m.Path, m.Params)
if err != nil { if err != nil {
@ -117,11 +38,11 @@ func (m *LlamaModel) GetTensors() error {
switch m.Format.(type) { switch m.Format.(type) {
case *TorchFormat: case *TorchFormat:
wt := l.WriterTo.(torchWriterTo) wt := l.WriterTo.(torchWriterTo)
wt.handler = llamaTorchLayerHandler wt.repacker = m.Repack
l.WriterTo = wt l.WriterTo = wt
case *SafetensorFormat: case *SafetensorFormat:
wt := l.WriterTo.(safetensorWriterTo) wt := l.WriterTo.(safetensorWriterTo)
wt.handler = mistralLayerHandler wt.repacker = m.Repack
l.WriterTo = wt l.WriterTo = wt
} }
} }
@ -184,3 +105,54 @@ func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error {
return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors) return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
} }
func (m *LlamaModel) Repack(name string, data []float32, shape []uint64) ([]float32, error) {
return llamaRepack(name, m.Params, data, shape)
}
func llamaRepack(name string, params *Params, data []float32, shape []uint64) ([]float32, error) {
var dims []int
for _, dim := range shape {
if dim != 0 {
dims = append(dims, int(dim))
}
}
var heads int
if strings.HasSuffix(name, "attn_q.weight") {
heads = params.AttentionHeads
} else if strings.HasSuffix(name, "attn_k.weight") {
heads = cmp.Or(params.KeyValHeads, params.AttentionHeads)
} else {
return nil, fmt.Errorf("unknown tensor name: %s", name)
}
n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
if err := n.Reshape(append([]int{heads, 2, dims[0] / heads / 2}, dims[1:]...)...); err != nil {
return nil, err
}
if err := n.T(0, 2, 1, 3); err != nil {
return nil, err
}
if err := n.Reshape(dims...); err != nil {
return nil, err
}
if err := n.Transpose(); err != nil {
return nil, err
}
ts, err := native.SelectF32(n, 1)
if err != nil {
return nil, err
}
var f32s []float32
for _, t := range ts {
f32s = append(f32s, t...)
}
return f32s, nil
}

View file

@ -1,17 +1,8 @@
package convert package convert
import ( import (
"encoding/binary"
"fmt"
"io" "io"
"os"
"regexp" "regexp"
"strings"
"github.com/d4l3k/go-bfloat16"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/x448/float16"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
) )
@ -20,82 +11,6 @@ type MistralModel struct {
ModelData ModelData
} }
func mistralLayerHandler(w io.Writer, r safetensorWriterTo, f *os.File) error {
layerSize := r.end - r.start
var err error
tData := make([]uint16, layerSize/2)
if err = binary.Read(f, r.bo, tData); err != nil {
return err
}
var heads uint32
if strings.Contains(r.t.Name, "attn_q") {
heads = uint32(r.params.AttentionHeads)
} else if strings.Contains(r.t.Name, "attn_k") {
heads = uint32(r.params.KeyValHeads)
if heads == 0 {
heads = uint32(r.params.AttentionHeads)
}
} else {
return fmt.Errorf("unknown layer type")
}
tData, err = repack(tData, int(heads), r.t.Shape)
if err != nil {
return err
}
var buf []byte
for _, n := range tData {
buf = r.bo.AppendUint16(buf, n)
}
tempBuf := make([]uint16, len(tData))
tDataF32 := bfloat16.DecodeFloat32(buf)
for cnt, v := range tDataF32 {
tDataF16 := float16.Fromfloat32(v)
tempBuf[cnt] = uint16(tDataF16)
}
if err = binary.Write(w, r.bo, tempBuf); err != nil {
return err
}
return nil
}
func repack(data []uint16, heads int, shape []uint64) ([]uint16, error) {
n := tensor.New(tensor.WithShape(int(shape[0]), int(shape[1])), tensor.WithBacking(data))
origShape := n.Shape().Clone()
// reshape the tensor and swap axes 1 and 2 to unpack the layer for gguf
if err := n.Reshape(heads, 2, origShape[0]/heads/2, origShape[1]); err != nil {
return nil, err
}
if err := n.T(0, 2, 1, 3); err != nil {
return nil, err
}
if err := n.Reshape(origShape...); err != nil {
return nil, err
}
if err := n.Transpose(); err != nil {
return nil, err
}
newN, err := native.SelectU16(n, 1)
if err != nil {
return nil, err
}
var fullTensor []uint16
for _, v := range newN {
fullTensor = append(fullTensor, v...)
}
return fullTensor, nil
}
func (m *MistralModel) GetTensors() error { func (m *MistralModel) GetTensors() error {
t, err := m.Format.GetTensors(m.Path, m.Params) t, err := m.Format.GetTensors(m.Path, m.Params)
if err != nil { if err != nil {
@ -112,7 +27,7 @@ func (m *MistralModel) GetTensors() error {
matches := re.FindAllStringSubmatch(l.Name, -1) matches := re.FindAllStringSubmatch(l.Name, -1)
if len(matches) > 0 { if len(matches) > 0 {
wt := l.WriterTo.(safetensorWriterTo) wt := l.WriterTo.(safetensorWriterTo)
wt.handler = mistralLayerHandler wt.repacker = m.Repack
l.WriterTo = wt l.WriterTo = wt
} }
m.Tensors = append(m.Tensors, l) m.Tensors = append(m.Tensors, l)
@ -158,3 +73,7 @@ func (m *MistralModel) WriteGGUF(ws io.WriteSeeker) error {
return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors) return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
} }
func (m *MistralModel) Repack(name string, data []float32, shape []uint64) ([]float32, error) {
return llamaRepack(name, m.Params, data, shape)
}

View file

@ -27,7 +27,7 @@ func (m *MixtralModel) GetTensors() error {
matches := re.FindAllStringSubmatch(l.Name, -1) matches := re.FindAllStringSubmatch(l.Name, -1)
if len(matches) > 0 { if len(matches) > 0 {
wt := l.WriterTo.(safetensorWriterTo) wt := l.WriterTo.(safetensorWriterTo)
wt.handler = mistralLayerHandler wt.repacker = m.Repack
l.WriterTo = wt l.WriterTo = wt
} }
m.Tensors = append(m.Tensors, l) m.Tensors = append(m.Tensors, l)
@ -81,3 +81,7 @@ func (m *MixtralModel) WriteGGUF(ws io.WriteSeeker) error {
return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors) return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
} }
func (m *MixtralModel) Repack(name string, data []float32, shape []uint64) ([]float32, error) {
return llamaRepack(name, m.Params, data, shape)
}

View file

@ -27,9 +27,10 @@ type safetensorWriterTo struct {
bo ByteOrder bo ByteOrder
filename string filename string
dtype string
start, end, padding uint64 start, end, padding uint64
handler func(w io.Writer, r safetensorWriterTo, f *os.File) error repacker func(string, []float32, []uint64) ([]float32, error)
} }
type tensorMetaData struct { type tensorMetaData struct {
@ -150,6 +151,7 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
params: params, params: params,
bo: params.ByteOrder, bo: params.ByteOrder,
filename: fn, filename: fn,
dtype: data.Type,
start: uint64(data.Offsets[0]), start: uint64(data.Offsets[0]),
end: uint64(data.Offsets[1]), end: uint64(data.Offsets[1]),
padding: 8 + jsonSize, padding: 8 + jsonSize,
@ -235,51 +237,54 @@ func (r safetensorWriterTo) WriteTo(w io.Writer) (n int64, err error) {
return 0, err return 0, err
} }
// use the handler if one is present var f32s []float32
if r.handler != nil { switch r.dtype {
return 0, r.handler(w, r, f) case "F32":
} f32s = make([]float32, (r.end-r.start)/4)
if err = binary.Read(f, r.bo, f32s); err != nil {
remaining := r.end - r.start return 0, err
}
bufSize := uint64(10240) case "F16":
var finished bool bts := make([]uint16, (r.end-r.start)/2)
for { if err = binary.Read(f, r.bo, bts); err != nil {
data := make([]byte, min(bufSize, remaining))
b, err := io.ReadFull(f, data)
remaining -= uint64(b)
if err == io.EOF || remaining <= 0 {
finished = true
} else if err != nil {
return 0, err return 0, err
} }
// convert bfloat16 -> ieee float32 for _, b := range bts {
tDataF32 := bfloat16.DecodeFloat32(data) f32s = append(f32s, float16.Frombits(b).Float32())
switch r.t.Kind {
case 0:
if err := binary.Write(w, r.bo, tDataF32); err != nil {
return 0, err
}
case 1:
// convert float32 -> float16
tempBuf := make([]uint16, len(data)/2)
for cnt, v := range tDataF32 {
tDataF16 := float16.Fromfloat32(v)
tempBuf[cnt] = uint16(tDataF16)
}
if err := binary.Write(w, r.bo, tempBuf); err != nil {
return 0, err
}
} }
if finished {
break case "BF16":
bts := make([]byte, r.end-r.start)
if err = binary.Read(f, r.bo, bts); err != nil {
return 0, err
}
f32s = bfloat16.DecodeFloat32(bts)
default:
return 0, fmt.Errorf("unknown data type: %s", r.dtype)
}
if r.repacker != nil {
f32s, err = r.repacker(r.t.Name, f32s, r.t.Shape)
if err != nil {
return 0, err
} }
} }
return 0, nil
switch r.t.Kind {
case 0:
return 0, binary.Write(w, r.bo, f32s)
case 1:
f16s := make([]uint16, len(f32s))
for i := range f32s {
f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
}
return 0, binary.Write(w, r.bo, f16s)
default:
return 0, fmt.Errorf("unknown storage type: %d", r.t.Kind)
}
} }
func (m *SafetensorFormat) GetModelArch(name, dirPath string, params *Params) (ModelArch, error) { func (m *SafetensorFormat) GetModelArch(name, dirPath string, params *Params) (ModelArch, error) {

View file

@ -24,8 +24,8 @@ type torchWriterTo struct {
params *Params params *Params
bo ByteOrder bo ByteOrder
storage pytorch.StorageInterface storage pytorch.StorageInterface
handler func(w io.Writer, r torchWriterTo) error repacker func(string, []float32, []uint64) ([]float32, error)
} }
type TorchFormat struct{} type TorchFormat struct{}
@ -230,59 +230,38 @@ func (m *TorchFormat) GetLayerName(n string) (string, error) {
} }
func (r torchWriterTo) WriteTo(w io.Writer) (n int64, err error) { func (r torchWriterTo) WriteTo(w io.Writer) (n int64, err error) {
// use the handler if one is present var f32s []float32
if r.handler != nil { switch s := r.storage.(type) {
return 0, r.handler(w, r)
}
switch storage := r.storage.(type) {
case *pytorch.FloatStorage: case *pytorch.FloatStorage:
slog.Warn(fmt.Sprintf("unexpected storage found for layer '%s'; skipping", r.t.Name)) f32s = s.Data
return 0, nil
case *pytorch.HalfStorage: case *pytorch.HalfStorage:
switch r.t.Kind { f32s = s.Data
case 0:
data := r.storage.(*pytorch.HalfStorage).Data
slog.Debug(fmt.Sprintf("%35s F32 (%d)", r.t.Name, len(data)))
if err := binary.Write(w, r.bo, data); err != nil {
return 0, err
}
case 1:
data := r.storage.(*pytorch.HalfStorage).Data
tData := make([]uint16, len(data))
for cnt, v := range data {
tData[cnt] = uint16(float16.Fromfloat32(v))
}
slog.Debug(fmt.Sprintf("%35s F16 (%d)", r.t.Name, len(tData)))
if err := binary.Write(w, r.bo, tData); err != nil {
return 0, err
}
}
case *pytorch.BFloat16Storage: case *pytorch.BFloat16Storage:
data := r.storage.(*pytorch.BFloat16Storage).Data f32s = s.Data
switch r.t.Kind {
case 0:
if err = binary.Write(w, r.bo, data); err != nil {
return 0, err
}
case 1:
tData := make([]uint16, len(data))
for cnt, v := range data {
tData[cnt] = uint16(float16.Fromfloat32(v))
}
if err = binary.Write(w, r.bo, tData); err != nil {
return 0, err
}
default:
return 0, fmt.Errorf("unknown storage kind: %d", r.t.Kind)
}
default: default:
return 0, fmt.Errorf("unknown storage type: %T", storage) return 0, fmt.Errorf("unknown data type: %T", s)
} }
return 0, nil if r.repacker != nil {
f32s, err = r.repacker(r.t.Name, f32s, r.t.Shape)
if err != nil {
return 0, err
}
}
switch r.t.Kind {
case 0:
return 0, binary.Write(w, r.bo, f32s)
case 1:
f16s := make([]uint16, len(f32s))
for i := range f32s {
f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
}
return 0, binary.Write(w, r.bo, f16s)
default:
return 0, fmt.Errorf("unknown storage type: %d", r.t.Kind)
}
} }
func (m *TorchFormat) GetModelArch(name, dirPath string, params *Params) (ModelArch, error) { func (m *TorchFormat) GetModelArch(name, dirPath string, params *Params) (ModelArch, error) {

2
go.mod
View file

@ -4,7 +4,6 @@ go 1.22.0
require ( require (
github.com/containerd/console v1.0.3 github.com/containerd/console v1.0.3
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/emirpasic/gods v1.18.1 github.com/emirpasic/gods v1.18.1
github.com/gin-gonic/gin v1.10.0 github.com/gin-gonic/gin v1.10.0
github.com/golang/protobuf v1.5.4 // indirect github.com/golang/protobuf v1.5.4 // indirect
@ -18,6 +17,7 @@ require (
) )
require ( require (
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/mattn/go-runewidth v0.0.14 github.com/mattn/go-runewidth v0.0.14
github.com/nlpodyssey/gopickle v0.3.0 github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c