llama3 conversion
This commit is contained in:
parent
4730762e5c
commit
c8cf0d94ed
3 changed files with 56 additions and 16 deletions
|
@ -93,6 +93,7 @@ type Vocab struct {
|
||||||
Tokens []string
|
Tokens []string
|
||||||
Scores []float32
|
Scores []float32
|
||||||
Types []int32
|
Types []int32
|
||||||
|
Merges []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadSentencePieceTokens(dirpath string, params *Params) (*Vocab, error) {
|
func LoadSentencePieceTokens(dirpath string, params *Params) (*Vocab, error) {
|
||||||
|
|
|
@ -5,6 +5,8 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
@ -105,12 +107,12 @@ func (m *LlamaModel) GetTensors() error {
|
||||||
matches := re.FindAllStringSubmatch(l.Name, -1)
|
matches := re.FindAllStringSubmatch(l.Name, -1)
|
||||||
if len(matches) > 0 {
|
if len(matches) > 0 {
|
||||||
slog.Debug(fmt.Sprintf("setting handler for: %s", l.Name))
|
slog.Debug(fmt.Sprintf("setting handler for: %s", l.Name))
|
||||||
switch l.WriterTo.(type) {
|
switch m.Format.(type) {
|
||||||
case torchWriterTo:
|
case *TorchFormat:
|
||||||
wt := l.WriterTo.(torchWriterTo)
|
wt := l.WriterTo.(torchWriterTo)
|
||||||
wt.handler = llamaTorchLayerHandler
|
wt.handler = llamaTorchLayerHandler
|
||||||
l.WriterTo = wt
|
l.WriterTo = wt
|
||||||
case safetensorWriterTo:
|
case *SafetensorFormat:
|
||||||
wt := l.WriterTo.(safetensorWriterTo)
|
wt := l.WriterTo.(safetensorWriterTo)
|
||||||
wt.handler = mistralLayerHandler
|
wt.handler = mistralLayerHandler
|
||||||
l.WriterTo = wt
|
l.WriterTo = wt
|
||||||
|
@ -123,10 +125,36 @@ func (m *LlamaModel) GetTensors() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *LlamaModel) LoadVocab() error {
|
func (m *LlamaModel) LoadVocab() error {
|
||||||
var v *Vocab
|
v := &Vocab{
|
||||||
var err error
|
Tokens: []string{},
|
||||||
|
Types: []int32{},
|
||||||
|
Merges: []string{},
|
||||||
|
}
|
||||||
|
|
||||||
slog.Debug("loading vocab")
|
tokpath := filepath.Join(m.Path, "tokenizer.json")
|
||||||
|
slog.Debug(fmt.Sprintf("looking for %s", tokpath))
|
||||||
|
if _, err := os.Stat(tokpath); !os.IsNotExist(err) {
|
||||||
|
t, err := newTokenizer(tokpath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tok := range t.Model.Tokens {
|
||||||
|
v.Tokens = append(v.Tokens, tok.Content)
|
||||||
|
var tokType int32
|
||||||
|
switch {
|
||||||
|
case tok.Special:
|
||||||
|
tokType = 3
|
||||||
|
case tok.UserDefined:
|
||||||
|
tokType = 4
|
||||||
|
default:
|
||||||
|
tokType = 1
|
||||||
|
}
|
||||||
|
v.Types = append(v.Types, tokType)
|
||||||
|
}
|
||||||
|
v.Merges = t.Model.Merges
|
||||||
|
} else {
|
||||||
|
slog.Debug("loading sentence piece vocab")
|
||||||
v, err = LoadSentencePieceTokens(m.Path, m.Params)
|
v, err = LoadSentencePieceTokens(m.Path, m.Params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -134,7 +162,9 @@ func (m *LlamaModel) LoadVocab() error {
|
||||||
|
|
||||||
slog.Debug("vocab loaded")
|
slog.Debug("vocab loaded")
|
||||||
|
|
||||||
|
}
|
||||||
m.Vocab = v
|
m.Vocab = v
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -147,22 +177,30 @@ func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error {
|
||||||
"llama.embedding_length": uint32(m.Params.HiddenSize),
|
"llama.embedding_length": uint32(m.Params.HiddenSize),
|
||||||
"llama.block_count": uint32(m.Params.HiddenLayers),
|
"llama.block_count": uint32(m.Params.HiddenLayers),
|
||||||
"llama.feed_forward_length": uint32(m.Params.IntermediateSize),
|
"llama.feed_forward_length": uint32(m.Params.IntermediateSize),
|
||||||
|
"llama.rope.freq_base": float32(m.Params.RopeFrequencyBase),
|
||||||
"llama.rope.dimension_count": uint32(m.Params.HiddenSize / m.Params.AttentionHeads),
|
"llama.rope.dimension_count": uint32(m.Params.HiddenSize / m.Params.AttentionHeads),
|
||||||
"llama.attention.head_count": uint32(m.Params.AttentionHeads),
|
"llama.attention.head_count": uint32(m.Params.AttentionHeads),
|
||||||
"llama.attention.head_count_kv": uint32(m.Params.KeyValHeads),
|
"llama.attention.head_count_kv": uint32(m.Params.KeyValHeads),
|
||||||
"llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
|
"llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
|
||||||
"general.file_type": uint32(1),
|
//"general.file_type": uint32(1),
|
||||||
"tokenizer.ggml.model": "llama",
|
"general.file_type": uint32(2),
|
||||||
|
//"tokenizer.ggml.model": "llama",
|
||||||
|
"tokenizer.ggml.model": "gpt2",
|
||||||
|
|
||||||
"tokenizer.ggml.tokens": m.Vocab.Tokens,
|
"tokenizer.ggml.tokens": m.Vocab.Tokens,
|
||||||
"tokenizer.ggml.scores": m.Vocab.Scores,
|
|
||||||
"tokenizer.ggml.token_type": m.Vocab.Types,
|
"tokenizer.ggml.token_type": m.Vocab.Types,
|
||||||
|
|
||||||
"tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
|
"tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
|
||||||
"tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
|
"tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
|
||||||
"tokenizer.ggml.unknown_token_id": uint32(0),
|
"tokenizer.ggml.unknown_token_id": uint32(0),
|
||||||
"tokenizer.ggml.add_bos_token": true,
|
//"tokenizer.ggml.add_bos_token": true,
|
||||||
"tokenizer.ggml.add_eos_token": false,
|
//"tokenizer.ggml.add_eos_token": false,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(m.Vocab.Merges) > 0 {
|
||||||
|
kv["tokenizer.ggml.merges"] = m.Vocab.Merges
|
||||||
|
} else {
|
||||||
|
kv["tokenizer.ggml.scores"] = m.Vocab.Scores
|
||||||
}
|
}
|
||||||
|
|
||||||
return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
|
return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
|
||||||
|
|
|
@ -483,6 +483,7 @@ var ggufKVOrder = map[string][]string{
|
||||||
"tokenizer.ggml.model",
|
"tokenizer.ggml.model",
|
||||||
"tokenizer.ggml.tokens",
|
"tokenizer.ggml.tokens",
|
||||||
"tokenizer.ggml.scores",
|
"tokenizer.ggml.scores",
|
||||||
|
"tokenizer.ggml.merges",
|
||||||
"tokenizer.ggml.token_type",
|
"tokenizer.ggml.token_type",
|
||||||
"tokenizer.ggml.bos_token_id",
|
"tokenizer.ggml.bos_token_id",
|
||||||
"tokenizer.ggml.eos_token_id",
|
"tokenizer.ggml.eos_token_id",
|
||||||
|
|
Loading…
Reference in a new issue