2024-04-15 11:26:42 -07:00
|
|
|
package convert
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bytes"
|
|
|
|
"encoding/binary"
|
|
|
|
"encoding/json"
|
|
|
|
"fmt"
|
|
|
|
"io"
|
|
|
|
"os"
|
|
|
|
"path/filepath"
|
|
|
|
"regexp"
|
|
|
|
"slices"
|
2024-05-15 14:55:57 -07:00
|
|
|
"strings"
|
2024-04-15 11:26:42 -07:00
|
|
|
|
|
|
|
"github.com/d4l3k/go-bfloat16"
|
|
|
|
"github.com/x448/float16"
|
|
|
|
|
|
|
|
"github.com/ollama/ollama/llm"
|
|
|
|
)
|
|
|
|
|
|
|
|
type safetensorWriterTo struct {
|
|
|
|
t *llm.Tensor
|
|
|
|
|
|
|
|
params *Params
|
|
|
|
bo ByteOrder
|
|
|
|
|
|
|
|
filename string
|
2024-05-17 12:11:49 -07:00
|
|
|
dtype string
|
2024-04-15 11:26:42 -07:00
|
|
|
|
2024-05-20 09:47:01 -07:00
|
|
|
offset, size int64
|
|
|
|
repacker func(string, []float32, []uint64) ([]float32, error)
|
2024-04-15 11:26:42 -07:00
|
|
|
}
|
|
|
|
|
2024-05-20 09:47:01 -07:00
|
|
|
// safetensorMetadata is a single tensor entry of a safetensors JSON header.
type safetensorMetadata struct {
	// Type is the element dtype string, e.g. "F32", "F16" or "BF16".
	Type string `json:"dtype"`

	// Shape lists the tensor's dimensions.
	Shape []uint64 `json:"shape"`

	// Offsets holds the begin and end byte offsets of the tensor data,
	// measured from the end of the JSON header.
	Offsets []int64 `json:"data_offsets"`
}
|
|
|
|
|
|
|
|
// SafetensorFormat reads a model directory made of *.safetensors weight
// files plus a config.json. It carries no state; the zero value is ready
// to use.
type SafetensorFormat struct{}
|
|
|
|
|
|
|
|
func (m *SafetensorFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor, error) {
|
|
|
|
var tensors []llm.Tensor
|
2024-05-20 09:47:01 -07:00
|
|
|
matches, err := filepath.Glob(filepath.Join(dirpath, "*.safetensors"))
|
2024-04-15 11:26:42 -07:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
var offset uint64
|
2024-05-20 09:47:01 -07:00
|
|
|
for _, f := range matches {
|
2024-04-15 11:26:42 -07:00
|
|
|
var t []llm.Tensor
|
|
|
|
var err error
|
|
|
|
t, offset, err = m.readTensors(f, offset, params)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
2024-05-20 09:47:01 -07:00
|
|
|
|
2024-04-15 11:26:42 -07:00
|
|
|
tensors = append(tensors, t...)
|
|
|
|
}
|
|
|
|
return tensors, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params) ([]llm.Tensor, uint64, error) {
|
|
|
|
f, err := os.Open(fn)
|
|
|
|
if err != nil {
|
|
|
|
return nil, 0, err
|
|
|
|
}
|
|
|
|
defer f.Close()
|
|
|
|
|
2024-05-20 09:47:01 -07:00
|
|
|
var n int64
|
|
|
|
if err := binary.Read(f, binary.LittleEndian, &n); err != nil {
|
2024-04-15 11:26:42 -07:00
|
|
|
return nil, 0, err
|
|
|
|
}
|
|
|
|
|
2024-05-20 09:47:01 -07:00
|
|
|
b := bytes.NewBuffer(make([]byte, 0, n))
|
|
|
|
if _, err = io.CopyN(b, f, n); err != nil {
|
2024-04-15 11:26:42 -07:00
|
|
|
return nil, 0, err
|
|
|
|
}
|
|
|
|
|
2024-05-20 09:47:01 -07:00
|
|
|
var headers map[string]safetensorMetadata
|
|
|
|
if err := json.NewDecoder(b).Decode(&headers); err != nil {
|
2024-04-15 11:26:42 -07:00
|
|
|
return nil, 0, err
|
|
|
|
}
|
|
|
|
|
|
|
|
var keys []string
|
2024-05-20 09:47:01 -07:00
|
|
|
for key := range headers {
|
|
|
|
if !strings.HasSuffix(key, "self_attn.rotary_embd.inv_freq") {
|
|
|
|
keys = append(keys, key)
|
|
|
|
}
|
2024-04-15 11:26:42 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
slices.Sort(keys)
|
|
|
|
|
|
|
|
var tensors []llm.Tensor
|
2024-05-20 09:47:01 -07:00
|
|
|
for _, key := range keys {
|
|
|
|
value := headers[key]
|
2024-04-15 11:26:42 -07:00
|
|
|
|
|
|
|
var kind uint32
|
2024-05-20 09:47:01 -07:00
|
|
|
switch len(value.Shape) {
|
2024-04-15 11:26:42 -07:00
|
|
|
case 0:
|
2024-05-20 09:47:01 -07:00
|
|
|
// valuedata
|
2024-04-15 11:26:42 -07:00
|
|
|
continue
|
|
|
|
case 2:
|
|
|
|
kind = 1
|
|
|
|
}
|
|
|
|
|
2024-05-20 09:47:01 -07:00
|
|
|
name, err := m.GetLayerName(key)
|
2024-04-15 11:26:42 -07:00
|
|
|
if err != nil {
|
|
|
|
return nil, 0, err
|
|
|
|
}
|
|
|
|
|
2024-05-20 09:47:01 -07:00
|
|
|
shape := make([]uint64, len(value.Shape))
|
|
|
|
copy(shape, value.Shape)
|
2024-04-15 11:26:42 -07:00
|
|
|
|
2024-05-20 09:47:01 -07:00
|
|
|
pad := func(s int64) int64 {
|
|
|
|
return 8 + n + s
|
|
|
|
}
|
2024-05-08 16:07:46 -07:00
|
|
|
|
2024-04-15 11:26:42 -07:00
|
|
|
t := llm.Tensor{
|
2024-05-20 09:47:01 -07:00
|
|
|
Name: name,
|
2024-04-15 11:26:42 -07:00
|
|
|
Kind: kind,
|
|
|
|
Offset: offset,
|
2024-05-21 22:07:57 -07:00
|
|
|
Shape: shape,
|
2024-04-15 11:26:42 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
t.WriterTo = safetensorWriterTo{
|
|
|
|
t: &t,
|
|
|
|
params: params,
|
|
|
|
bo: params.ByteOrder,
|
|
|
|
filename: fn,
|
2024-05-20 09:47:01 -07:00
|
|
|
dtype: value.Type,
|
|
|
|
offset: pad(value.Offsets[0]),
|
|
|
|
size: pad(value.Offsets[1]) - pad(value.Offsets[0]),
|
2024-04-15 11:26:42 -07:00
|
|
|
}
|
|
|
|
|
2024-05-20 09:47:01 -07:00
|
|
|
offset += t.Size()
|
2024-04-23 20:17:04 -07:00
|
|
|
tensors = append(tensors, t)
|
2024-04-15 11:26:42 -07:00
|
|
|
}
|
2024-04-23 20:17:04 -07:00
|
|
|
|
2024-04-15 11:26:42 -07:00
|
|
|
return tensors, offset, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *SafetensorFormat) GetParams(dirpath string) (*Params, error) {
|
|
|
|
f, err := os.Open(filepath.Join(dirpath, "config.json"))
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
defer f.Close()
|
|
|
|
|
|
|
|
var params Params
|
|
|
|
|
2024-05-20 09:47:01 -07:00
|
|
|
if err := json.NewDecoder(f).Decode(¶ms); err != nil {
|
2024-04-15 11:26:42 -07:00
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
params.ByteOrder = binary.LittleEndian
|
|
|
|
return ¶ms, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *SafetensorFormat) GetLayerName(n string) (string, error) {
|
|
|
|
directMap := map[string]string{
|
|
|
|
"model.embed_tokens.weight": "token_embd.weight",
|
|
|
|
"lm_head.weight": "output.weight",
|
|
|
|
"model.norm.weight": "output_norm.weight",
|
|
|
|
}
|
|
|
|
|
|
|
|
tMap := map[string]string{
|
2024-04-23 20:17:04 -07:00
|
|
|
"model.layers.(\\d+).input_layernorm.weight": "blk.$1.attn_norm.weight",
|
|
|
|
"model.layers.(\\d+).mlp.down_proj.weight": "blk.$1.ffn_down.weight",
|
|
|
|
"model.layers.(\\d+).mlp.gate_proj.weight": "blk.$1.ffn_gate.weight",
|
|
|
|
"model.layers.(\\d+).mlp.up_proj.weight": "blk.$1.ffn_up.weight",
|
|
|
|
"model.layers.(\\d+).post_attention_layernorm.weight": "blk.$1.ffn_norm.weight",
|
|
|
|
"model.layers.(\\d+).self_attn.k_proj.weight": "blk.$1.attn_k.weight",
|
|
|
|
"model.layers.(\\d+).self_attn.o_proj.weight": "blk.$1.attn_output.weight",
|
|
|
|
"model.layers.(\\d+).self_attn.q_proj.weight": "blk.$1.attn_q.weight",
|
|
|
|
"model.layers.(\\d+).self_attn.v_proj.weight": "blk.$1.attn_v.weight",
|
|
|
|
"model.layers.(\\d+).block_sparse_moe.gate.weight": "blk.$1.ffn_gate_inp.weight",
|
|
|
|
"model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w1.weight": "blk.$1.ffn_gate.$2.weight",
|
|
|
|
"model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w2.weight": "blk.$1.ffn_down.$2.weight",
|
|
|
|
"model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w3.weight": "blk.$1.ffn_up.$2.weight",
|
2024-04-15 11:26:42 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
v, ok := directMap[n]
|
|
|
|
if ok {
|
|
|
|
return v, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// quick hack to rename the layers to gguf format
|
|
|
|
for k, v := range tMap {
|
|
|
|
re := regexp.MustCompile(k)
|
|
|
|
newName := re.ReplaceAllString(n, v)
|
|
|
|
if newName != n {
|
|
|
|
return newName, nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return "", fmt.Errorf("couldn't find a layer name for '%s'", n)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (r safetensorWriterTo) WriteTo(w io.Writer) (n int64, err error) {
|
|
|
|
f, err := os.Open(r.filename)
|
|
|
|
if err != nil {
|
|
|
|
return 0, err
|
|
|
|
}
|
|
|
|
defer f.Close()
|
|
|
|
|
2024-05-20 09:47:01 -07:00
|
|
|
if _, err = f.Seek(r.offset, io.SeekStart); err != nil {
|
2024-04-15 11:26:42 -07:00
|
|
|
return 0, err
|
|
|
|
}
|
|
|
|
|
2024-05-17 12:11:49 -07:00
|
|
|
var f32s []float32
|
|
|
|
switch r.dtype {
|
|
|
|
case "F32":
|
2024-05-20 09:47:01 -07:00
|
|
|
f32s = make([]float32, r.size/4)
|
2024-05-17 12:11:49 -07:00
|
|
|
if err = binary.Read(f, r.bo, f32s); err != nil {
|
|
|
|
return 0, err
|
|
|
|
}
|
|
|
|
case "F16":
|
2024-05-20 09:47:01 -07:00
|
|
|
u16s := make([]uint16, r.size/2)
|
|
|
|
if err = binary.Read(f, r.bo, u16s); err != nil {
|
2024-05-17 12:11:49 -07:00
|
|
|
return 0, err
|
|
|
|
}
|
2024-04-15 11:26:42 -07:00
|
|
|
|
2024-05-20 09:47:01 -07:00
|
|
|
for _, b := range u16s {
|
2024-05-17 12:11:49 -07:00
|
|
|
f32s = append(f32s, float16.Frombits(b).Float32())
|
|
|
|
}
|
2024-04-15 11:26:42 -07:00
|
|
|
|
2024-05-17 12:11:49 -07:00
|
|
|
case "BF16":
|
2024-05-20 09:47:01 -07:00
|
|
|
u8s := make([]uint8, r.size)
|
|
|
|
if err = binary.Read(f, r.bo, u8s); err != nil {
|
2024-04-15 11:26:42 -07:00
|
|
|
return 0, err
|
|
|
|
}
|
|
|
|
|
2024-05-20 09:47:01 -07:00
|
|
|
f32s = bfloat16.DecodeFloat32(u8s)
|
2024-05-17 12:11:49 -07:00
|
|
|
default:
|
|
|
|
return 0, fmt.Errorf("unknown data type: %s", r.dtype)
|
|
|
|
}
|
2024-04-15 11:26:42 -07:00
|
|
|
|
2024-05-17 12:11:49 -07:00
|
|
|
if r.repacker != nil {
|
|
|
|
f32s, err = r.repacker(r.t.Name, f32s, r.t.Shape)
|
|
|
|
if err != nil {
|
|
|
|
return 0, err
|
2024-04-15 11:26:42 -07:00
|
|
|
}
|
2024-05-17 12:11:49 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
switch r.t.Kind {
|
|
|
|
case 0:
|
|
|
|
return 0, binary.Write(w, r.bo, f32s)
|
|
|
|
case 1:
|
|
|
|
f16s := make([]uint16, len(f32s))
|
|
|
|
for i := range f32s {
|
|
|
|
f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
|
2024-04-15 11:26:42 -07:00
|
|
|
}
|
2024-05-17 12:11:49 -07:00
|
|
|
|
|
|
|
return 0, binary.Write(w, r.bo, f16s)
|
|
|
|
default:
|
|
|
|
return 0, fmt.Errorf("unknown storage type: %d", r.t.Kind)
|
2024-04-15 11:26:42 -07:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *SafetensorFormat) GetModelArch(name, dirPath string, params *Params) (ModelArch, error) {
|
|
|
|
switch len(params.Architectures) {
|
|
|
|
case 0:
|
|
|
|
return nil, fmt.Errorf("No architecture specified to convert")
|
|
|
|
case 1:
|
|
|
|
switch params.Architectures[0] {
|
2024-04-24 18:32:01 -07:00
|
|
|
case "LlamaForCausalLM":
|
|
|
|
return &LlamaModel{
|
|
|
|
ModelData{
|
|
|
|
Name: name,
|
|
|
|
Path: dirPath,
|
|
|
|
Params: params,
|
|
|
|
Format: m,
|
|
|
|
},
|
|
|
|
}, nil
|
2024-04-15 11:26:42 -07:00
|
|
|
case "MistralForCausalLM":
|
|
|
|
return &MistralModel{
|
|
|
|
ModelData{
|
|
|
|
Name: name,
|
|
|
|
Path: dirPath,
|
|
|
|
Params: params,
|
|
|
|
Format: m,
|
|
|
|
},
|
|
|
|
}, nil
|
2024-04-23 20:17:04 -07:00
|
|
|
case "MixtralForCausalLM":
|
|
|
|
return &MixtralModel{
|
|
|
|
ModelData{
|
|
|
|
Name: name,
|
|
|
|
Path: dirPath,
|
|
|
|
Params: params,
|
|
|
|
Format: m,
|
|
|
|
},
|
|
|
|
}, nil
|
2024-04-15 11:26:42 -07:00
|
|
|
case "GemmaForCausalLM":
|
|
|
|
return &GemmaModel{
|
|
|
|
ModelData{
|
|
|
|
Name: name,
|
|
|
|
Path: dirPath,
|
|
|
|
Params: params,
|
|
|
|
Format: m,
|
|
|
|
},
|
|
|
|
}, nil
|
|
|
|
default:
|
|
|
|
return nil, fmt.Errorf("Models based on '%s' are not yet supported", params.Architectures[0])
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil, fmt.Errorf("Unknown error")
|
|
|
|
}
|