diff --git a/convert/convert.go b/convert/convert.go
index 9a05fb52..e9c2ef2d 100644
--- a/convert/convert.go
+++ b/convert/convert.go
@@ -37,6 +37,8 @@ type Params struct {
 	Experts     int `json:"num_local_experts"`
 	ExpertsUsed int `json:"num_experts_per_tok"`
 
+	PreTokenizer string
+
 	ByteOrder
 }
 
diff --git a/convert/llama.go b/convert/llama.go
index 9fdcd02b..83d942cb 100644
--- a/convert/llama.go
+++ b/convert/llama.go
@@ -2,9 +2,9 @@ package convert
 
 import (
 	"encoding/binary"
+	"errors"
 	"fmt"
 	"io"
-	"log/slog"
 	"os"
 	"path/filepath"
 	"regexp"
@@ -134,44 +134,27 @@ func (m *LlamaModel) GetTensors() error {
 }
 
 func (m *LlamaModel) LoadVocab() error {
-	v := &Vocab{
-		Tokens: []string{},
-		Types:  []int32{},
-		Merges: []string{},
-	}
+	v := &Vocab{}
 
 	tokpath := filepath.Join(m.Path, "tokenizer.json")
-	slog.Debug(fmt.Sprintf("looking for %s", tokpath))
-	if _, err := os.Stat(tokpath); !os.IsNotExist(err) {
-		t, err := newTokenizer(tokpath)
-		if err != nil {
-			return err
-		}
-
-		for _, tok := range t.Model.Tokens {
-			v.Tokens = append(v.Tokens, tok.Content)
-			var tokType int32
-			switch {
-			case tok.Special:
-				tokType = 3
-			case tok.UserDefined:
-				tokType = 4
-			default:
-				tokType = 1
-			}
-			v.Types = append(v.Types, tokType)
-		}
-		v.Merges = t.Model.Merges
-	} else {
-		slog.Debug("loading sentence piece vocab")
+	pre, ts, merges, err := parseTokens(tokpath)
+	if errors.Is(err, os.ErrNotExist) {
 		v, err = LoadSentencePieceTokens(m.Path, m.Params)
 		if err != nil {
 			return err
 		}
+	} else if err != nil {
+		return err
+	} else {
+		for _, t := range ts {
+			v.Tokens = append(v.Tokens, t.Content)
+			v.Types = append(v.Types, t.Type())
+		}
 
-		slog.Debug("vocab loaded")
-
+		m.Params.PreTokenizer = pre
+		v.Merges = merges
 	}
+
 	m.Vocab = v
 
 	return nil
@@ -194,6 +177,7 @@ func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error {
 		"general.file_type": uint32(2),
 
 		"tokenizer.ggml.model":      "gpt2",
+		"tokenizer.ggml.pre":        m.Params.PreTokenizer,
 		"tokenizer.ggml.tokens":     m.Vocab.Tokens,
 		"tokenizer.ggml.token_type": m.Vocab.Types,
diff --git a/convert/tokenizer.go b/convert/tokenizer.go
index a7da81e6..a847a84c 100644
--- a/convert/tokenizer.go
+++ b/convert/tokenizer.go
@@ -1,15 +1,30 @@
 package convert
 
 import (
+	"cmp"
+	"crypto/sha256"
 	"encoding/json"
-	"io/ioutil"
+	"fmt"
+	"log/slog"
 	"os"
+	"slices"
+
+	"golang.org/x/exp/maps"
 )
 
 type Tokenizer struct {
 	Version     string         `json:"version"`
 	AddedTokens []Token        `json:"added_tokens"`
 	Model       TokenizerModel `json:"model"`
+
+	PreTokenizer struct {
+		PreTokenziers []struct {
+			Type    string `json:"type"`
+			Pattern struct {
+				Regex string `json:"Regex"`
+			} `json:"pattern"`
+		} `json:"pretokenizers"`
+	} `json:"pre_tokenizer"`
 }
 
 type TokenizerModel struct {
@@ -26,47 +41,69 @@ type Token struct {
 	UserDefined bool
 }
 
-func (t *Tokenizer) getMaxID() int {
-	var maxID int
-	for _, v := range t.Model.Vocab {
-		maxID = max(maxID, v)
+func (t *Token) Type() int32 {
+	switch {
+	case t.Special:
+		return 3
+	case t.UserDefined:
+		return 4
+	default:
+		return 1
 	}
-
-	for _, v := range t.AddedTokens {
-		maxID = max(maxID, v.ID)
-	}
-	return maxID
 }
 
-func newTokenizer(dirpath string) (*Tokenizer, error) {
+func (t *Tokenizer) maxID() int {
+	return max(
+		slices.Max(maps.Values(t.Model.Vocab)),
+		slices.MaxFunc(t.AddedTokens, func(a, b Token) int {
+			return cmp.Compare(a.ID, b.ID)
+		}).ID,
+	)
+}
+
+func parseTokens(dirpath string) (pre string, tokens []Token, merges []string, err error) {
 	f, err := os.Open(dirpath)
 	if err != nil {
 		panic(err)
 	}
 	defer f.Close()
 
-	data, err := ioutil.ReadAll(f)
-	if err != nil {
-		return nil, err
+	var t Tokenizer
+	if err := json.NewDecoder(f).Decode(&t); err != nil {
+		return "", nil, nil, err
 	}
 
-	var tdata Tokenizer
-
-	if err := json.Unmarshal(data, &tdata); err != nil {
-		return nil, err
+	tokens = make([]Token, t.maxID()+1)
+	for k, v := range t.Model.Vocab {
+		tokens[v] = Token{ID: v, Content: k, Special: false, UserDefined: false}
 	}
 
-	maxID := tdata.getMaxID()
-	tdata.Model.Tokens = make([]Token, maxID+1)
-
-	for k, v := range tdata.Model.Vocab {
-		tdata.Model.Tokens[v] = Token{ID: v, Content: k, Special: false, UserDefined: false}
-	}
-
-	for _, v := range tdata.AddedTokens {
+	for _, v := range t.AddedTokens {
 		v.UserDefined = true
-		tdata.Model.Tokens[v.ID] = v
+		tokens[v.ID] = v
 	}
 
-	return &tdata, nil
+	sha256sum := sha256.New()
+	for _, pt := range t.PreTokenizer.PreTokenziers {
+		switch pt.Type {
+		case "Split":
+			if pt.Pattern.Regex != "" {
+				sha256sum.Write([]byte(pt.Pattern.Regex))
+			}
+		}
+	}
+
+	switch digest := fmt.Sprintf("%x", sha256sum.Sum(nil)); digest {
+	case "d98f9631be1e9607a9848c26c1f9eac1aa9fc21ac6ba82a2fc0741af9780a48f":
+		pre = "llama-bpe"
+	case "03df5c5863ad70781dcfdef491ead25140f895fe8010964be0daefe27be32b02":
+		pre = "deepseek-llm"
+	case "21cde974d587f0d54dc8d56b183cc1e6239600172035c68fbd6d4b9f8da0576e":
+		pre = "deepseek-coder"
+	default:
+		slog.Warn("unknown pretokenizer, using default", "digest", digest)
+		pre = "default"
+	}
+
+	return pre, tokens, t.Model.Merges, nil
 }
diff --git a/llm/gguf.go b/llm/gguf.go
index c3cc3d41..179b3255 100644
--- a/llm/gguf.go
+++ b/llm/gguf.go
@@ -480,6 +480,7 @@ var ggufKVOrder = map[string][]string{
 		"gemma.attention.key_length",
 		"gemma.attention.value_length",
 		"general.file_type",
+		"tokenizer.ggml.pre",
 		"tokenizer.ggml.model",
 		"tokenizer.ggml.tokens",
 		"tokenizer.ggml.scores",
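
Note on the detection step: parseTokens identifies the pretokenizer by hashing the Regex of each Split entry under pre_tokenizer in tokenizer.json and matching the digest against known values ("llama-bpe", "deepseek-llm", "deepseek-coder"), falling back to "default" otherwise. The standalone sketch below reproduces that digest computation outside the converter to show the JSON shape it relies on; the sample regex and the trimmed struct are illustrative and not part of this change.

package main

import (
	"crypto/sha256"
	"encoding/json"
	"fmt"
)

// Trimmed-down view of the pre_tokenizer section of tokenizer.json that
// parseTokens reads; only Split entries contribute their Regex to the digest.
// The sample regex here is illustrative.
const sample = `{
  "pre_tokenizer": {
    "type": "Sequence",
    "pretokenizers": [
      {"type": "Split", "pattern": {"Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)"}},
      {"type": "ByteLevel"}
    ]
  }
}`

type preTokenizer struct {
	PreTokenizers []struct {
		Type    string `json:"type"`
		Pattern struct {
			Regex string `json:"Regex"`
		} `json:"pattern"`
	} `json:"pretokenizers"`
}

func main() {
	var t struct {
		PreTokenizer preTokenizer `json:"pre_tokenizer"`
	}
	if err := json.Unmarshal([]byte(sample), &t); err != nil {
		panic(err)
	}

	// Hash the Split regexes in order; parseTokens maps known digests to
	// pretokenizer names and warns before using "default" for unknown ones.
	sha256sum := sha256.New()
	for _, pt := range t.PreTokenizer.PreTokenizers {
		if pt.Type == "Split" && pt.Pattern.Regex != "" {
			sha256sum.Write([]byte(pt.Pattern.Regex))
		}
	}
	fmt.Printf("pretokenizer digest: %x\n", sha256sum.Sum(nil))
}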