Simplify model conversion (#3422)

This commit is contained in:
Patrick Devine 2024-04-01 16:14:53 -07:00 committed by GitHub
parent d6dd2ff839
commit 3b6a9154dd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 366 additions and 251 deletions

View file

@ -12,12 +12,9 @@ import (
"path/filepath" "path/filepath"
"regexp" "regexp"
"slices" "slices"
"strings"
"github.com/d4l3k/go-bfloat16" "github.com/d4l3k/go-bfloat16"
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/x448/float16" "github.com/x448/float16"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
@ -55,6 +52,20 @@ type MetaData struct {
Offsets []int `mapstructure:"data_offsets"` Offsets []int `mapstructure:"data_offsets"`
} }
type ModelArch interface {
GetTensors() error
LoadVocab() error
WriteGGUF() (string, error)
}
type ModelData struct {
Path string
Name string
Params *Params
Vocab *Vocab
Tensors []llm.Tensor
}
func ReadSafeTensors(fn string, offset uint64, params *Params) ([]llm.Tensor, uint64, error) { func ReadSafeTensors(fn string, offset uint64, params *Params) ([]llm.Tensor, uint64, error) {
f, err := os.Open(fn) f, err := os.Open(fn)
if err != nil { if err != nil {
@ -132,15 +143,13 @@ func ReadSafeTensors(fn string, offset uint64, params *Params) ([]llm.Tensor, ui
} }
t.WriterTo = safetensorWriterTo{ t.WriterTo = safetensorWriterTo{
t: &t, t: &t,
params: params, params: params,
bo: params.ByteOrder, bo: params.ByteOrder,
headCount: uint32(params.AttentionHeads), filename: fn,
headCountKV: uint32(params.KeyValHeads), start: uint64(data.Offsets[0]),
filename: fn, end: uint64(data.Offsets[1]),
start: uint64(data.Offsets[0]), padding: 8 + jsonSize,
end: uint64(data.Offsets[1]),
padding: 8 + jsonSize,
} }
slog.Debug(fmt.Sprintf("%v", t)) slog.Debug(fmt.Sprintf("%v", t))
@ -198,7 +207,7 @@ type Vocab struct {
Types []int32 Types []int32
} }
func LoadTokens(dirpath string, params *Params) (*Vocab, error) { func LoadSentencePieceTokens(dirpath string, vocabSize int) (*Vocab, error) {
slog.Info(fmt.Sprintf("reading vocab from %s", filepath.Join(dirpath, "tokenizer.model"))) slog.Info(fmt.Sprintf("reading vocab from %s", filepath.Join(dirpath, "tokenizer.model")))
in, err := os.ReadFile(filepath.Join(dirpath, "tokenizer.model")) in, err := os.ReadFile(filepath.Join(dirpath, "tokenizer.model"))
if err != nil { if err != nil {
@ -278,8 +287,8 @@ func LoadTokens(dirpath string, params *Params) (*Vocab, error) {
} }
slog.Info(fmt.Sprintf("vocab size w/ extra tokens: %d", len(v.Tokens))) slog.Info(fmt.Sprintf("vocab size w/ extra tokens: %d", len(v.Tokens)))
if params.VocabSize > len(v.Tokens) { if vocabSize > len(v.Tokens) {
missingTokens := params.VocabSize - len(v.Tokens) missingTokens := vocabSize - len(v.Tokens)
slog.Warn(fmt.Sprintf("vocab is missing %d tokens", missingTokens)) slog.Warn(fmt.Sprintf("vocab is missing %d tokens", missingTokens))
for cnt := 0; cnt < missingTokens; cnt++ { for cnt := 0; cnt < missingTokens; cnt++ {
v.Tokens = append(v.Tokens, fmt.Sprintf("<dummy%05d>", cnt+1)) v.Tokens = append(v.Tokens, fmt.Sprintf("<dummy%05d>", cnt+1))
@ -327,77 +336,16 @@ func GetTensorName(n string) (string, error) {
type safetensorWriterTo struct { type safetensorWriterTo struct {
t *llm.Tensor t *llm.Tensor
params *Params params *Params
bo ByteOrder bo ByteOrder
headCount uint32
headCountKV uint32
filename string filename string
start, end, padding uint64 start, end, padding uint64
} handler func(w io.Writer, r safetensorWriterTo, f *os.File) error
func (r safetensorWriterTo) addOnes(data []float32) ([]float32, error) {
n := tensor.New(tensor.WithShape(int(r.t.Shape[0])), tensor.WithBacking(data))
ones := tensor.Ones(tensor.Float32, int(r.t.Shape[0]))
var err error
n, err = n.Add(ones)
if err != nil {
return []float32{}, err
}
newN, err := native.SelectF32(n, 0)
if err != nil {
return []float32{}, err
}
var fullTensor []float32
for _, v := range newN {
fullTensor = append(fullTensor, v...)
}
return fullTensor, nil
}
func (r safetensorWriterTo) repack(data []uint16, heads int) ([]uint16, error) {
n := tensor.New(tensor.WithShape(int(r.t.Shape[0]), int(r.t.Shape[1])), tensor.WithBacking(data))
origShape := n.Shape().Clone()
// reshape the tensor and swap axes 1 and 2 to unpack the layer for gguf
if err := n.Reshape(heads, 2, origShape[0]/heads/2, origShape[1]); err != nil {
return nil, err
}
if err := n.T(0, 2, 1, 3); err != nil {
return nil, err
}
if err := n.Reshape(origShape...); err != nil {
return nil, err
}
if err := n.Transpose(); err != nil {
return nil, err
}
newN, err := native.SelectU16(n, 1)
if err != nil {
return nil, err
}
var fullTensor []uint16
for _, v := range newN {
fullTensor = append(fullTensor, v...)
}
return fullTensor, nil
} }
func (r safetensorWriterTo) WriteTo(w io.Writer) (n int64, err error) { func (r safetensorWriterTo) WriteTo(w io.Writer) (n int64, err error) {
arch, err := getArchFromParams(r.params)
if err != nil {
return 0, err
}
f, err := os.Open(r.filename) f, err := os.Open(r.filename)
if err != nil { if err != nil {
return 0, err return 0, err
@ -408,83 +356,9 @@ func (r safetensorWriterTo) WriteTo(w io.Writer) (n int64, err error) {
return 0, err return 0, err
} }
switch arch { // use the handler if one is present
case "llama": if r.handler != nil {
return 0, r.handler(w, r, f)
pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
re, err := regexp.Compile(pattern)
if err != nil {
return 0, err
}
matches := re.FindAllStringSubmatch(r.t.Name, -1)
if len(matches) > 0 {
layerSize := r.end - r.start
var err error
tData := make([]uint16, layerSize/2)
if err = binary.Read(f, r.bo, tData); err != nil {
return 0, err
}
layerType := matches[0][re.SubexpIndex("layer")]
var heads uint32
switch layerType {
case "q":
heads = r.headCount
case "k":
heads = r.headCountKV
if heads == 0 {
heads = r.headCount
}
}
tData, err = r.repack(tData, int(heads))
if err != nil {
return 0, err
}
var buf []byte
for _, n := range tData {
buf = r.bo.AppendUint16(buf, n)
}
tempBuf := make([]uint16, len(tData))
tDataF32 := bfloat16.DecodeFloat32(buf)
for cnt, v := range tDataF32 {
tDataF16 := float16.Fromfloat32(v)
tempBuf[cnt] = uint16(tDataF16)
}
if err = binary.Write(w, r.bo, tempBuf); err != nil {
return 0, err
}
return 0, nil
}
case "gemma":
if strings.HasSuffix(r.t.Name, "norm.weight") {
slog.Debug(fmt.Sprintf("converting '%s'", r.t.Name))
data := make([]byte, r.end-r.start)
if err = binary.Read(f, r.bo, data); err != nil {
return 0, err
}
tDataF32 := bfloat16.DecodeFloat32(data)
var err error
tDataF32, err = r.addOnes(tDataF32)
if err != nil {
return 0, err
}
if err := binary.Write(w, r.bo, tDataF32); err != nil {
return 0, err
}
return 0, nil
}
} }
remaining := r.end - r.start remaining := r.end - r.start
@ -529,93 +403,32 @@ func (r safetensorWriterTo) WriteTo(w io.Writer) (n int64, err error) {
return 0, nil return 0, nil
} }
func getArchFromParams(params *Params) (string, error) { func GetModelArchFromParams(name, dirPath string, params *Params) (ModelArch, error) {
var arch string
switch len(params.Architectures) { switch len(params.Architectures) {
case 0: case 0:
return "", fmt.Errorf("No architecture specified to convert") return nil, fmt.Errorf("No architecture specified to convert")
case 1: case 1:
switch params.Architectures[0] { switch params.Architectures[0] {
case "MistralForCausalLM": case "MistralForCausalLM":
arch = "llama" return &MistralModel{
ModelData{
Name: name,
Path: dirPath,
Params: params,
},
}, nil
case "GemmaForCausalLM": case "GemmaForCausalLM":
arch = "gemma" return &GemmaModel{
ModelData{
Name: name,
Path: dirPath,
Params: params,
},
}, nil
default: default:
return "", fmt.Errorf("Models based on '%s' are not yet supported", params.Architectures[0]) return nil, fmt.Errorf("Models based on '%s' are not yet supported", params.Architectures[0])
} }
default:
return "", fmt.Errorf("Multimodal models are not yet supported")
} }
return arch, nil return nil, fmt.Errorf("Unknown error")
}
func WriteGGUF(name string, tensors []llm.Tensor, params *Params, vocab *Vocab) (string, error) {
arch, err := getArchFromParams(params)
if err != nil {
return "", err
}
kv := llm.KV{
"general.architecture": arch,
"general.name": name,
}
switch arch {
case "llama":
kv["llama.context_length"] = uint32(params.ContextSize)
kv["llama.embedding_length"] = uint32(params.HiddenSize)
kv["llama.block_count"] = uint32(params.HiddenLayers)
kv["llama.feed_forward_length"] = uint32(params.IntermediateSize)
kv["llama.rope.dimension_count"] = uint32(params.HiddenSize / params.AttentionHeads)
slog.Debug(fmt.Sprintf("rope dim count = %d", kv["llama.rope.dimension_count"]))
kv["llama.attention.head_count"] = uint32(params.AttentionHeads)
kv["llama.attention.head_count_kv"] = uint32(params.KeyValHeads)
kv["llama.attention.layer_norm_rms_epsilon"] = float32(params.NormEPS)
kv["llama.rope.freq_base"] = float32(params.RopeFreqBase)
case "gemma":
kv["gemma.context_length"] = uint32(params.ContextSize)
kv["gemma.embedding_length"] = uint32(params.HiddenSize)
kv["gemma.block_count"] = uint32(params.HiddenLayers)
kv["gemma.feed_forward_length"] = uint32(params.IntermediateSize)
kv["gemma.attention.head_count"] = uint32(params.AttentionHeads)
kv["gemma.attention.head_count_kv"] = uint32(params.KeyValHeads)
kv["gemma.attention.layer_norm_rms_epsilon"] = float32(params.NormEPS)
kv["gemma.attention.key_length"] = uint32(params.HeadDimension)
kv["gemma.attention.value_length"] = uint32(params.HeadDimension)
}
kv["general.file_type"] = uint32(1)
kv["tokenizer.ggml.model"] = "llama"
kv["tokenizer.ggml.tokens"] = vocab.Tokens
kv["tokenizer.ggml.scores"] = vocab.Scores
kv["tokenizer.ggml.token_type"] = vocab.Types
kv["tokenizer.ggml.bos_token_id"] = uint32(params.BoSTokenID)
kv["tokenizer.ggml.eos_token_id"] = uint32(params.EoSTokenID)
switch arch {
case "llama":
kv["tokenizer.ggml.unknown_token_id"] = uint32(0)
case "gemma":
kv["tokenizer.ggml.padding_token_id"] = uint32(params.PaddingTokenID)
kv["tokenizer.ggml.unknown_token_id"] = uint32(3)
}
kv["tokenizer.ggml.add_bos_token"] = true
kv["tokenizer.ggml.add_eos_token"] = false
f, err := os.CreateTemp("", "ollama-gguf")
if err != nil {
return "", err
}
defer f.Close()
m := llm.NewGGUFV3(params.ByteOrder)
if err := m.Encode(f, kv, tensors); err != nil {
return "", err
}
return f.Name(), nil
} }

136
convert/gemma.go Normal file
View file

@ -0,0 +1,136 @@
package convert
import (
"encoding/binary"
"fmt"
"io"
"log/slog"
"os"
"strings"
"github.com/d4l3k/go-bfloat16"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/llm"
)
type GemmaModel struct {
ModelData
}
func gemmaLayerHandler(w io.Writer, r safetensorWriterTo, f *os.File) error {
slog.Debug(fmt.Sprintf("converting '%s'", r.t.Name))
data := make([]byte, r.end-r.start)
if err := binary.Read(f, r.bo, data); err != nil {
return err
}
tDataF32 := bfloat16.DecodeFloat32(data)
var err error
tDataF32, err = addOnes(tDataF32, int(r.t.Shape[0]))
if err != nil {
return err
}
if err := binary.Write(w, r.bo, tDataF32); err != nil {
return err
}
return nil
}
func addOnes(data []float32, vectorSize int) ([]float32, error) {
n := tensor.New(tensor.WithShape(vectorSize), tensor.WithBacking(data))
ones := tensor.Ones(tensor.Float32, vectorSize)
var err error
n, err = n.Add(ones)
if err != nil {
return []float32{}, err
}
newN, err := native.SelectF32(n, 0)
if err != nil {
return []float32{}, err
}
var fullTensor []float32
for _, v := range newN {
fullTensor = append(fullTensor, v...)
}
return fullTensor, nil
}
func (m *GemmaModel) GetTensors() error {
t, err := GetSafeTensors(m.Path, m.Params)
if err != nil {
return err
}
m.Tensors = []llm.Tensor{}
for _, l := range t {
if strings.HasSuffix(l.Name, "norm.weight") {
wt := l.WriterTo.(safetensorWriterTo)
wt.handler = gemmaLayerHandler
l.WriterTo = wt
}
m.Tensors = append(m.Tensors, l)
}
return nil
}
func (m *GemmaModel) LoadVocab() error {
v, err := LoadSentencePieceTokens(m.Path, m.Params.VocabSize)
if err != nil {
return err
}
m.Vocab = v
return nil
}
func (m *GemmaModel) WriteGGUF() (string, error) {
kv := llm.KV{
"general.architecture": "gemma",
"general.name": m.Name,
"gemma.context_length": uint32(m.Params.ContextSize),
"gemma.embedding_length": uint32(m.Params.HiddenSize),
"gemma.block_count": uint32(m.Params.HiddenLayers),
"gemma.feed_forward_length": uint32(m.Params.IntermediateSize),
"gemma.attention.head_count": uint32(m.Params.AttentionHeads),
"gemma.attention.head_count_kv": uint32(m.Params.KeyValHeads),
"gemma.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
"gemma.attention.key_length": uint32(m.Params.HeadDimension),
"gemma.attention.value_length": uint32(m.Params.HeadDimension),
"general.file_type": uint32(1),
"tokenizer.ggml.model": "llama",
"tokenizer.ggml.tokens": m.Vocab.Tokens,
"tokenizer.ggml.scores": m.Vocab.Scores,
"tokenizer.ggml.token_type": m.Vocab.Types,
"tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
"tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
"tokenizer.ggml.padding_token_id": uint32(m.Params.PaddingTokenID),
"tokenizer.ggml.unknown_token_id": uint32(3),
"tokenizer.ggml.add_bos_token": true,
"tokenizer.ggml.add_eos_token": false,
}
f, err := os.CreateTemp("", "ollama-gguf")
if err != nil {
return "", err
}
defer f.Close()
mod := llm.NewGGUFV3(m.Params.ByteOrder)
if err := mod.Encode(f, kv, m.Tensors); err != nil {
return "", err
}
return f.Name(), nil
}

174
convert/mistral.go Normal file
View file

@ -0,0 +1,174 @@
package convert
import (
"encoding/binary"
"fmt"
"io"
"os"
"regexp"
"strings"
"github.com/d4l3k/go-bfloat16"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/x448/float16"
"github.com/ollama/ollama/llm"
)
type MistralModel struct {
ModelData
}
func mistralLayerHandler(w io.Writer, r safetensorWriterTo, f *os.File) error {
layerSize := r.end - r.start
var err error
tData := make([]uint16, layerSize/2)
if err = binary.Read(f, r.bo, tData); err != nil {
return err
}
var heads uint32
if strings.Contains(r.t.Name, "attn_q") {
heads = uint32(r.params.AttentionHeads)
} else if strings.Contains(r.t.Name, "attn_k") {
heads = uint32(r.params.KeyValHeads)
if heads == 0 {
heads = uint32(r.params.AttentionHeads)
}
} else {
return fmt.Errorf("unknown layer type")
}
tData, err = repack(tData, int(heads), r.t.Shape)
if err != nil {
return err
}
var buf []byte
for _, n := range tData {
buf = r.bo.AppendUint16(buf, n)
}
tempBuf := make([]uint16, len(tData))
tDataF32 := bfloat16.DecodeFloat32(buf)
for cnt, v := range tDataF32 {
tDataF16 := float16.Fromfloat32(v)
tempBuf[cnt] = uint16(tDataF16)
}
if err = binary.Write(w, r.bo, tempBuf); err != nil {
return err
}
return nil
}
func repack(data []uint16, heads int, shape []uint64) ([]uint16, error) {
n := tensor.New(tensor.WithShape(int(shape[0]), int(shape[1])), tensor.WithBacking(data))
origShape := n.Shape().Clone()
// reshape the tensor and swap axes 1 and 2 to unpack the layer for gguf
if err := n.Reshape(heads, 2, origShape[0]/heads/2, origShape[1]); err != nil {
return nil, err
}
if err := n.T(0, 2, 1, 3); err != nil {
return nil, err
}
if err := n.Reshape(origShape...); err != nil {
return nil, err
}
if err := n.Transpose(); err != nil {
return nil, err
}
newN, err := native.SelectU16(n, 1)
if err != nil {
return nil, err
}
var fullTensor []uint16
for _, v := range newN {
fullTensor = append(fullTensor, v...)
}
return fullTensor, nil
}
func (m *MistralModel) GetTensors() error {
t, err := GetSafeTensors(m.Path, m.Params)
if err != nil {
return err
}
m.Tensors = []llm.Tensor{}
pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
re, err := regexp.Compile(pattern)
if err != nil {
return err
}
for _, l := range t {
matches := re.FindAllStringSubmatch(l.Name, -1)
if len(matches) > 0 {
wt := l.WriterTo.(safetensorWriterTo)
wt.handler = mistralLayerHandler
l.WriterTo = wt
}
m.Tensors = append(m.Tensors, l)
}
return nil
}
func (m *MistralModel) LoadVocab() error {
v, err := LoadSentencePieceTokens(m.Path, m.Params.VocabSize)
if err != nil {
return err
}
m.Vocab = v
return nil
}
func (m *MistralModel) WriteGGUF() (string, error) {
kv := llm.KV{
"general.architecture": "llama",
"general.name": m.Name,
"llama.context_length": uint32(m.Params.ContextSize),
"llama.embedding_length": uint32(m.Params.HiddenSize),
"llama.block_count": uint32(m.Params.HiddenLayers),
"llama.feed_forward_length": uint32(m.Params.IntermediateSize),
"llama.rope.dimension_count": uint32(m.Params.HiddenSize / m.Params.AttentionHeads),
"llama.attention.head_count": uint32(m.Params.AttentionHeads),
"llama.attention.head_count_kv": uint32(m.Params.KeyValHeads),
"llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
"llama.rope.freq_base": float32(m.Params.RopeFreqBase),
"general.file_type": uint32(1),
"tokenizer.ggml.model": "llama",
"tokenizer.ggml.tokens": m.Vocab.Tokens,
"tokenizer.ggml.scores": m.Vocab.Scores,
"tokenizer.ggml.token_type": m.Vocab.Types,
"tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
"tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
"tokenizer.ggml.add_bos_token": true,
"tokenizer.ggml.add_eos_token": false,
"tokenizer.ggml.unknown_token_id": uint32(0),
}
f, err := os.CreateTemp("", "ollama-gguf")
if err != nil {
return "", err
}
defer f.Close()
mod := llm.NewGGUFV3(m.Params.ByteOrder)
if err := mod.Encode(f, kv, m.Tensors); err != nil {
return "", err
}
return f.Name(), nil
}

View file

@ -654,30 +654,22 @@ func convertSafetensors(name, path string, fn func(resp api.ProgressResponse)) (
return "", err return "", err
} }
SupportedArchs := []string{ mArch, err := convert.GetModelArchFromParams(name, tempDir, params)
"MistralForCausalLM",
"GemmaForCausalLM",
}
for _, arch := range params.Architectures {
if !slices.Contains(SupportedArchs, arch) {
return "", fmt.Errorf("this safetensors model is not yet supported")
}
}
fn(api.ProgressResponse{Status: "processing safetensors"})
t, err := convert.GetSafeTensors(tempDir, params)
if err != nil { if err != nil {
return "", err return "", err
} }
vocab, err := convert.LoadTokens(tempDir, params) fn(api.ProgressResponse{Status: "processing safetensors"})
if err != nil { if err := mArch.GetTensors(); err != nil {
return "", err
}
if err := mArch.LoadVocab(); err != nil {
return "", err return "", err
} }
fn(api.ProgressResponse{Status: "converting model"}) fn(api.ProgressResponse{Status: "converting model"})
path, err = convert.WriteGGUF(name, t, params, vocab) path, err = mArch.WriteGGUF()
if err != nil { if err != nil {
return "", err return "", err
} }