add mixtral 8x7b model conversion (#3859)

Patrick Devine 2024-04-23 20:17:04 -07:00 committed by GitHub
parent 4dc4f1be34
commit ce8ce82567
3 changed files with 138 additions and 25 deletions


@@ -31,6 +31,10 @@ type Params struct {
    EoSTokenID        int     `json:"eos_token_id"`
    HeadDimension     int     `json:"head_dim"`
    PaddingTokenID    int     `json:"pad_token_id"`
    RopeFrequencyBase float64 `json:"rope_theta"`
    Experts           int     `json:"num_local_experts"`
    ExpertsUsed       int     `json:"num_experts_per_tok"`

    ByteOrder ByteOrder
}
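
For reference, the three new JSON-tagged fields (rope_theta, num_local_experts, num_experts_per_tok) come straight from the checkpoint's config.json; ByteOrder has no JSON tag, so it is not read from the config. Below is a minimal, self-contained sketch of how those keys decode. The throwaway struct and the sample values (the published Mixtral 8x7B defaults) are illustrative only and not part of this change.

// Illustrative sketch only: decodes the new config.json keys into a
// standalone struct that mirrors the tags added to Params above.
package main

import (
    "encoding/json"
    "fmt"
)

type moeParams struct {
    RopeFrequencyBase float64 `json:"rope_theta"`
    Experts           int     `json:"num_local_experts"`
    ExpertsUsed       int     `json:"num_experts_per_tok"`
}

func main() {
    // Sample values as published for Mixtral 8x7B (assumed here for illustration).
    raw := []byte(`{"rope_theta": 1000000.0, "num_local_experts": 8, "num_experts_per_tok": 2}`)

    var p moeParams
    if err := json.Unmarshal(raw, &p); err != nil {
        panic(err)
    }
    fmt.Printf("%+v\n", p) // {RopeFrequencyBase:1e+06 Experts:8 ExpertsUsed:2}
}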

convert/mixtral.go (new file, 96 additions)

@@ -0,0 +1,96 @@
package convert

import (
    "os"
    "regexp"

    "github.com/ollama/ollama/llm"
)

type MixtralModel struct {
    ModelData
}

func (m *MixtralModel) GetTensors() error {
    t, err := m.Format.GetTensors(m.Path, m.Params)
    if err != nil {
        return err
    }

    m.Tensors = []llm.Tensor{}

    pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
    re, err := regexp.Compile(pattern)
    if err != nil {
        return err
    }

    for _, l := range t {
        matches := re.FindAllStringSubmatch(l.Name, -1)
        if len(matches) > 0 {
            wt := l.WriterTo.(safetensorWriterTo)
            wt.handler = mistralLayerHandler
            l.WriterTo = wt
        }
        m.Tensors = append(m.Tensors, l)
    }

    return nil
}

func (m *MixtralModel) LoadVocab() error {
    v, err := LoadSentencePieceTokens(m.Path, m.Params)
    if err != nil {
        return err
    }
    m.Vocab = v
    return nil
}

func (m *MixtralModel) WriteGGUF() (string, error) {
    kv := llm.KV{
        "general.architecture": "llama",
        "general.name":         m.Name,

        "llama.block_count":                      uint32(m.Params.HiddenLayers),
        "llama.context_length":                   uint32(m.Params.ContextSize),
        "llama.embedding_length":                 uint32(m.Params.HiddenSize),
        "llama.feed_forward_length":              uint32(m.Params.IntermediateSize),
        "llama.attention.head_count":             uint32(m.Params.AttentionHeads),
        "llama.attention.head_count_kv":          uint32(m.Params.KeyValHeads),
        "llama.rope.freq_base":                   float32(m.Params.RopeFrequencyBase),
        "llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
        "llama.expert_count":                     uint32(m.Params.Experts),
        "llama.expert_used_count":                uint32(m.Params.ExpertsUsed),
        "llama.vocab_size":                       uint32(len(m.Vocab.Tokens)),
        "llama.rope.dimension_count":             uint32(m.Params.HiddenSize / m.Params.AttentionHeads),

        "general.file_type":    uint32(1),
        "tokenizer.ggml.model": "llama",

        "tokenizer.ggml.tokens":     m.Vocab.Tokens,
        "tokenizer.ggml.scores":     m.Vocab.Scores,
        "tokenizer.ggml.token_type": m.Vocab.Types,

        "tokenizer.ggml.bos_token_id":     uint32(m.Params.BoSTokenID),
        "tokenizer.ggml.eos_token_id":     uint32(m.Params.EoSTokenID),
        "tokenizer.ggml.unknown_token_id": uint32(0),
        "tokenizer.ggml.add_bos_token":    true,
        "tokenizer.ggml.add_eos_token":    false,
    }

    f, err := os.CreateTemp("", "ollama-gguf")
    if err != nil {
        return "", err
    }
    defer f.Close()

    mod := llm.NewGGUFV3(m.Params.ByteOrder)
    if err := mod.Encode(f, kv, m.Tensors); err != nil {
        return "", err
    }

    return f.Name(), nil
}
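
Taken together, these three methods are the whole per-model surface the converter needs: GetTensors maps the safetensors layers into llm.Tensor entries (swapping in the mistral handler for attn_q/attn_k weights), LoadVocab pulls in the SentencePiece vocabulary, and WriteGGUF encodes everything into a temporary GGUF file and returns its path. A hypothetical caller sketch, assuming an already-populated *convert.MixtralModel (e.g. obtained via GetModelArch) and the module path github.com/ollama/ollama/convert; only the three method names come from the file above:

// Sketch only: shows the call order, not an API documented by this commit.
package example

import "github.com/ollama/ollama/convert"

func convertMixtral(m *convert.MixtralModel) (string, error) {
    if err := m.GetTensors(); err != nil { // collect tensors, patch attn_q/attn_k writers
        return "", err
    }
    if err := m.LoadVocab(); err != nil { // SentencePiece tokens, scores, types
        return "", err
    }
    return m.WriteGGUF() // path of the temporary GGUF file
}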


@@ -93,7 +93,6 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
    }
    slices.Sort(keys)
    slog.Info("converting layers")

    var tensors []llm.Tensor
@@ -105,7 +104,6 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
            return nil, 0, err
        }
        slog.Debug(fmt.Sprintf("metadata = %#v", data))

        var size uint64
        var kind uint32
        switch len(data.Shape) {
@@ -150,11 +148,13 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
            padding: 8 + jsonSize,
        }
        tensors = append(tensors, t)
        offset += size
        tensors = append(tensors, t)
    }

    slog.Debug(fmt.Sprintf("total tensors for file = %d", len(tensors)))
    slog.Debug(fmt.Sprintf("offset = %d", offset))
    return tensors, offset, nil
}
@@ -194,6 +194,10 @@ func (m *SafetensorFormat) GetLayerName(n string) (string, error) {
        "model.layers.(\\d+).self_attn.o_proj.weight":                   "blk.$1.attn_output.weight",
        "model.layers.(\\d+).self_attn.q_proj.weight":                   "blk.$1.attn_q.weight",
        "model.layers.(\\d+).self_attn.v_proj.weight":                   "blk.$1.attn_v.weight",
        "model.layers.(\\d+).block_sparse_moe.gate.weight":              "blk.$1.ffn_gate_inp.weight",
        "model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w1.weight": "blk.$1.ffn_gate.$2.weight",
        "model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w2.weight": "blk.$1.ffn_down.$2.weight",
        "model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w3.weight": "blk.$1.ffn_up.$2.weight",
    }

    v, ok := directMap[n]
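
The four new directMap entries translate Mixtral's block_sparse_moe tensor names into the GGUF naming scheme: ffn_gate_inp for the MoE gate (router) and ffn_gate/ffn_down/ffn_up for each expert's w1/w2/w3. A small illustrative sketch of that rewrite follows; the two map entries are copied from the diff, while the anchoring and lookup loop are demo scaffolding rather than the actual GetLayerName implementation.

// Demo only: applies two of the new patterns to a sample tensor name.
package main

import (
    "fmt"
    "regexp"
)

func main() {
    directMap := map[string]string{
        "model.layers.(\\d+).block_sparse_moe.gate.weight":              "blk.$1.ffn_gate_inp.weight",
        "model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w1.weight": "blk.$1.ffn_gate.$2.weight",
    }

    name := "model.layers.12.block_sparse_moe.experts.3.w1.weight"
    for pattern, repl := range directMap {
        if re := regexp.MustCompile("^" + pattern + "$"); re.MatchString(name) {
            fmt.Println(re.ReplaceAllString(name, repl)) // blk.12.ffn_gate.3.weight
        }
    }
}
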
@@ -286,6 +290,15 @@ func (m *SafetensorFormat) GetModelArch(name, dirPath string, params *Params) (M
                Format: m,
            },
        }, nil
    case "MixtralForCausalLM":
        return &MixtralModel{
            ModelData{
                Name:   name,
                Path:   dirPath,
                Params: params,
                Format: m,
            },
        }, nil
    case "GemmaForCausalLM":
        return &GemmaModel{
            ModelData{