llama3 conversion

This commit is contained in:
Patrick Devine 2024-04-28 10:36:38 -07:00 committed by Michael Yang
parent 4730762e5c
commit c8cf0d94ed
3 changed files with 56 additions and 16 deletions

View file

@ -93,6 +93,7 @@ type Vocab struct {
Tokens []string Tokens []string
Scores []float32 Scores []float32
Types []int32 Types []int32
Merges []string
} }
func LoadSentencePieceTokens(dirpath string, params *Params) (*Vocab, error) { func LoadSentencePieceTokens(dirpath string, params *Params) (*Vocab, error) {

View file

@ -5,6 +5,8 @@ import (
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"os"
"path/filepath"
"regexp" "regexp"
"strings" "strings"
@ -105,12 +107,12 @@ func (m *LlamaModel) GetTensors() error {
matches := re.FindAllStringSubmatch(l.Name, -1) matches := re.FindAllStringSubmatch(l.Name, -1)
if len(matches) > 0 { if len(matches) > 0 {
slog.Debug(fmt.Sprintf("setting handler for: %s", l.Name)) slog.Debug(fmt.Sprintf("setting handler for: %s", l.Name))
switch l.WriterTo.(type) { switch m.Format.(type) {
case torchWriterTo: case *TorchFormat:
wt := l.WriterTo.(torchWriterTo) wt := l.WriterTo.(torchWriterTo)
wt.handler = llamaTorchLayerHandler wt.handler = llamaTorchLayerHandler
l.WriterTo = wt l.WriterTo = wt
case safetensorWriterTo: case *SafetensorFormat:
wt := l.WriterTo.(safetensorWriterTo) wt := l.WriterTo.(safetensorWriterTo)
wt.handler = mistralLayerHandler wt.handler = mistralLayerHandler
l.WriterTo = wt l.WriterTo = wt
@ -123,10 +125,36 @@ func (m *LlamaModel) GetTensors() error {
} }
func (m *LlamaModel) LoadVocab() error { func (m *LlamaModel) LoadVocab() error {
var v *Vocab v := &Vocab{
var err error Tokens: []string{},
Types: []int32{},
Merges: []string{},
}
slog.Debug("loading vocab") tokpath := filepath.Join(m.Path, "tokenizer.json")
slog.Debug(fmt.Sprintf("looking for %s", tokpath))
if _, err := os.Stat(tokpath); !os.IsNotExist(err) {
t, err := newTokenizer(tokpath)
if err != nil {
return err
}
for _, tok := range t.Model.Tokens {
v.Tokens = append(v.Tokens, tok.Content)
var tokType int32
switch {
case tok.Special:
tokType = 3
case tok.UserDefined:
tokType = 4
default:
tokType = 1
}
v.Types = append(v.Types, tokType)
}
v.Merges = t.Model.Merges
} else {
slog.Debug("loading sentence piece vocab")
v, err = LoadSentencePieceTokens(m.Path, m.Params) v, err = LoadSentencePieceTokens(m.Path, m.Params)
if err != nil { if err != nil {
return err return err
@ -134,7 +162,9 @@ func (m *LlamaModel) LoadVocab() error {
slog.Debug("vocab loaded") slog.Debug("vocab loaded")
}
m.Vocab = v m.Vocab = v
return nil return nil
} }
@ -147,22 +177,30 @@ func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error {
"llama.embedding_length": uint32(m.Params.HiddenSize), "llama.embedding_length": uint32(m.Params.HiddenSize),
"llama.block_count": uint32(m.Params.HiddenLayers), "llama.block_count": uint32(m.Params.HiddenLayers),
"llama.feed_forward_length": uint32(m.Params.IntermediateSize), "llama.feed_forward_length": uint32(m.Params.IntermediateSize),
"llama.rope.freq_base": float32(m.Params.RopeFrequencyBase),
"llama.rope.dimension_count": uint32(m.Params.HiddenSize / m.Params.AttentionHeads), "llama.rope.dimension_count": uint32(m.Params.HiddenSize / m.Params.AttentionHeads),
"llama.attention.head_count": uint32(m.Params.AttentionHeads), "llama.attention.head_count": uint32(m.Params.AttentionHeads),
"llama.attention.head_count_kv": uint32(m.Params.KeyValHeads), "llama.attention.head_count_kv": uint32(m.Params.KeyValHeads),
"llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS), "llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
"general.file_type": uint32(1), //"general.file_type": uint32(1),
"tokenizer.ggml.model": "llama", "general.file_type": uint32(2),
//"tokenizer.ggml.model": "llama",
"tokenizer.ggml.model": "gpt2",
"tokenizer.ggml.tokens": m.Vocab.Tokens, "tokenizer.ggml.tokens": m.Vocab.Tokens,
"tokenizer.ggml.scores": m.Vocab.Scores,
"tokenizer.ggml.token_type": m.Vocab.Types, "tokenizer.ggml.token_type": m.Vocab.Types,
"tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID), "tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
"tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID), "tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
"tokenizer.ggml.unknown_token_id": uint32(0), "tokenizer.ggml.unknown_token_id": uint32(0),
"tokenizer.ggml.add_bos_token": true, //"tokenizer.ggml.add_bos_token": true,
"tokenizer.ggml.add_eos_token": false, //"tokenizer.ggml.add_eos_token": false,
}
if len(m.Vocab.Merges) > 0 {
kv["tokenizer.ggml.merges"] = m.Vocab.Merges
} else {
kv["tokenizer.ggml.scores"] = m.Vocab.Scores
} }
return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors) return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)

View file

@ -483,6 +483,7 @@ var ggufKVOrder = map[string][]string{
"tokenizer.ggml.model", "tokenizer.ggml.model",
"tokenizer.ggml.tokens", "tokenizer.ggml.tokens",
"tokenizer.ggml.scores", "tokenizer.ggml.scores",
"tokenizer.ggml.merges",
"tokenizer.ggml.token_type", "tokenizer.ggml.token_type",
"tokenizer.ggml.bos_token_id", "tokenizer.ggml.bos_token_id",
"tokenizer.ggml.eos_token_id", "tokenizer.ggml.eos_token_id",