From c8cf0d94edeae0c71e3a0877895d9519b5d4d5e3 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Sun, 28 Apr 2024 10:36:38 -0700 Subject: [PATCH] llama3 conversion --- convert/convert.go | 1 + convert/llama.go | 70 +++++++++++++++++++++++++++++++++++----------- llm/gguf.go | 1 + 3 files changed, 56 insertions(+), 16 deletions(-) diff --git a/convert/convert.go b/convert/convert.go index dbc26da1..899c8c44 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -93,6 +93,7 @@ type Vocab struct { Tokens []string Scores []float32 Types []int32 + Merges []string } func LoadSentencePieceTokens(dirpath string, params *Params) (*Vocab, error) { diff --git a/convert/llama.go b/convert/llama.go index 5dfb8d7d..8cb162e7 100644 --- a/convert/llama.go +++ b/convert/llama.go @@ -5,6 +5,8 @@ import ( "fmt" "io" "log/slog" + "os" + "path/filepath" "regexp" "strings" @@ -105,12 +107,12 @@ func (m *LlamaModel) GetTensors() error { matches := re.FindAllStringSubmatch(l.Name, -1) if len(matches) > 0 { slog.Debug(fmt.Sprintf("setting handler for: %s", l.Name)) - switch l.WriterTo.(type) { - case torchWriterTo: + switch m.Format.(type) { + case *TorchFormat: wt := l.WriterTo.(torchWriterTo) wt.handler = llamaTorchLayerHandler l.WriterTo = wt - case safetensorWriterTo: + case *SafetensorFormat: wt := l.WriterTo.(safetensorWriterTo) wt.handler = mistralLayerHandler l.WriterTo = wt @@ -123,18 +125,46 @@ func (m *LlamaModel) GetTensors() error { } func (m *LlamaModel) LoadVocab() error { - var v *Vocab - var err error - - slog.Debug("loading vocab") - v, err = LoadSentencePieceTokens(m.Path, m.Params) - if err != nil { - return err + v := &Vocab{ + Tokens: []string{}, + Types: []int32{}, + Merges: []string{}, } - slog.Debug("vocab loaded") + tokpath := filepath.Join(m.Path, "tokenizer.json") + slog.Debug(fmt.Sprintf("looking for %s", tokpath)) + if _, err := os.Stat(tokpath); !os.IsNotExist(err) { + t, err := newTokenizer(tokpath) + if err != nil { + return err + } + for _, tok := range t.Model.Tokens { + v.Tokens = append(v.Tokens, tok.Content) + var tokType int32 + switch { + case tok.Special: + tokType = 3 + case tok.UserDefined: + tokType = 4 + default: + tokType = 1 + } + v.Types = append(v.Types, tokType) + } + v.Merges = t.Model.Merges + } else { + slog.Debug("loading sentence piece vocab") + v, err = LoadSentencePieceTokens(m.Path, m.Params) + if err != nil { + return err + } + + slog.Debug("vocab loaded") + + } m.Vocab = v + return nil } @@ -147,22 +177,30 @@ func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error { "llama.embedding_length": uint32(m.Params.HiddenSize), "llama.block_count": uint32(m.Params.HiddenLayers), "llama.feed_forward_length": uint32(m.Params.IntermediateSize), + "llama.rope.freq_base": float32(m.Params.RopeFrequencyBase), "llama.rope.dimension_count": uint32(m.Params.HiddenSize / m.Params.AttentionHeads), "llama.attention.head_count": uint32(m.Params.AttentionHeads), "llama.attention.head_count_kv": uint32(m.Params.KeyValHeads), "llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS), - "general.file_type": uint32(1), - "tokenizer.ggml.model": "llama", + //"general.file_type": uint32(1), + "general.file_type": uint32(2), + //"tokenizer.ggml.model": "llama", + "tokenizer.ggml.model": "gpt2", "tokenizer.ggml.tokens": m.Vocab.Tokens, - "tokenizer.ggml.scores": m.Vocab.Scores, "tokenizer.ggml.token_type": m.Vocab.Types, "tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID), "tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID), "tokenizer.ggml.unknown_token_id": uint32(0), - "tokenizer.ggml.add_bos_token": true, - "tokenizer.ggml.add_eos_token": false, + //"tokenizer.ggml.add_bos_token": true, + //"tokenizer.ggml.add_eos_token": false, + } + + if len(m.Vocab.Merges) > 0 { + kv["tokenizer.ggml.merges"] = m.Vocab.Merges + } else { + kv["tokenizer.ggml.scores"] = m.Vocab.Scores } return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors) diff --git a/llm/gguf.go b/llm/gguf.go index 5f6e8004..c3cc3d41 100644 --- a/llm/gguf.go +++ b/llm/gguf.go @@ -483,6 +483,7 @@ var ggufKVOrder = map[string][]string{ "tokenizer.ggml.model", "tokenizer.ggml.tokens", "tokenizer.ggml.scores", + "tokenizer.ggml.merges", "tokenizer.ggml.token_type", "tokenizer.ggml.bos_token_id", "tokenizer.ggml.eos_token_id",