bpe pretokenizer

This commit is contained in:
Michael Yang 2024-05-15 11:53:14 -07:00
parent 2d315ba9a9
commit 547132e820
4 changed files with 83 additions and 59 deletions

View file

@ -37,6 +37,8 @@ type Params struct {
Experts int `json:"num_local_experts"` Experts int `json:"num_local_experts"`
ExpertsUsed int `json:"num_experts_per_tok"` ExpertsUsed int `json:"num_experts_per_tok"`
PreTokenizer string
ByteOrder ByteOrder
} }

View file

@ -2,9 +2,9 @@ package convert
import ( import (
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"io" "io"
"log/slog"
"os" "os"
"path/filepath" "path/filepath"
"regexp" "regexp"
@ -134,44 +134,27 @@ func (m *LlamaModel) GetTensors() error {
} }
func (m *LlamaModel) LoadVocab() error { func (m *LlamaModel) LoadVocab() error {
v := &Vocab{ v := &Vocab{}
Tokens: []string{},
Types: []int32{},
Merges: []string{},
}
tokpath := filepath.Join(m.Path, "tokenizer.json") tokpath := filepath.Join(m.Path, "tokenizer.json")
slog.Debug(fmt.Sprintf("looking for %s", tokpath)) pre, ts, merges, err := parseTokens(tokpath)
if _, err := os.Stat(tokpath); !os.IsNotExist(err) { if errors.Is(err, os.ErrNotExist) {
t, err := newTokenizer(tokpath)
if err != nil {
return err
}
for _, tok := range t.Model.Tokens {
v.Tokens = append(v.Tokens, tok.Content)
var tokType int32
switch {
case tok.Special:
tokType = 3
case tok.UserDefined:
tokType = 4
default:
tokType = 1
}
v.Types = append(v.Types, tokType)
}
v.Merges = t.Model.Merges
} else {
slog.Debug("loading sentence piece vocab")
v, err = LoadSentencePieceTokens(m.Path, m.Params) v, err = LoadSentencePieceTokens(m.Path, m.Params)
if err != nil { if err != nil {
return err return err
} }
} else if err != nil {
slog.Debug("vocab loaded") return err
} else {
for _, t := range ts {
v.Tokens = append(v.Tokens, t.Content)
v.Types = append(v.Types, t.Type())
} }
m.Params.PreTokenizer = pre
v.Merges = merges
}
m.Vocab = v m.Vocab = v
return nil return nil
@ -194,6 +177,7 @@ func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error {
"general.file_type": uint32(2), "general.file_type": uint32(2),
"tokenizer.ggml.model": "gpt2", "tokenizer.ggml.model": "gpt2",
"tokenizer.ggml.pre": m.Params.PreTokenizer,
"tokenizer.ggml.tokens": m.Vocab.Tokens, "tokenizer.ggml.tokens": m.Vocab.Tokens,
"tokenizer.ggml.token_type": m.Vocab.Types, "tokenizer.ggml.token_type": m.Vocab.Types,

View file

@ -1,15 +1,30 @@
package convert package convert
import ( import (
"cmp"
"crypto/sha256"
"encoding/json" "encoding/json"
"io/ioutil" "fmt"
"log/slog"
"os" "os"
"slices"
"golang.org/x/exp/maps"
) )
// Tokenizer holds the parts of a tokenizer JSON document that this package
// decodes; parseTokens fills it with json.NewDecoder(...).Decode.
type Tokenizer struct {
	Version     string         `json:"version"`
	AddedTokens []Token        `json:"added_tokens"`
	Model       TokenizerModel `json:"model"`

	// PreTokenizer is decoded from the "pre_tokenizer" object. parseTokens
	// hashes the Regex of every entry whose Type is "Split" to pick the
	// pre-tokenizer name.
	PreTokenizer struct {
		// NOTE(review): "PreTokenziers" is a misspelling of "PreTokenizers".
		// It is left as is because parseTokens refers to the field by this
		// name; rename both together.
		PreTokenziers []struct {
			Type    string `json:"type"`
			Pattern struct {
				// The JSON key is capitalized ("Regex") in the tag below.
				Regex string `json:"Regex"`
			} `json:"pattern"`
		} `json:"pretokenizers"`
	} `json:"pre_tokenizer"`
}
type TokenizerModel struct { type TokenizerModel struct {
@ -26,47 +41,69 @@ type Token struct {
UserDefined bool UserDefined bool
} }
func (t *Tokenizer) getMaxID() int { func (t *Token) Type() int32 {
var maxID int switch {
for _, v := range t.Model.Vocab { case t.Special:
maxID = max(maxID, v) return 3
case t.UserDefined:
return 4
default:
return 1
}
} }
for _, v := range t.AddedTokens { func (t *Tokenizer) maxID() int {
maxID = max(maxID, v.ID) return max(
} slices.Max(maps.Values(t.Model.Vocab)),
return maxID slices.MaxFunc(t.AddedTokens, func(a, b Token) int {
return cmp.Compare(a.ID, b.ID)
}).ID,
)
} }
func newTokenizer(dirpath string) (*Tokenizer, error) { func parseTokens(dirpath string) (pre string, tokens []Token, merges []string, err error) {
f, err := os.Open(dirpath) f, err := os.Open(dirpath)
if err != nil { if err != nil {
panic(err) panic(err)
} }
defer f.Close() defer f.Close()
data, err := ioutil.ReadAll(f) var t Tokenizer
if err != nil { if err := json.NewDecoder(f).Decode(&t); err != nil {
return nil, err return "", nil, nil, err
} }
var tdata Tokenizer tokens = make([]Token, t.maxID()+1)
for k, v := range t.Model.Vocab {
if err := json.Unmarshal(data, &tdata); err != nil { tokens[v] = Token{ID: v, Content: k, Special: false, UserDefined: false}
return nil, err
} }
maxID := tdata.getMaxID() for _, v := range t.AddedTokens {
tdata.Model.Tokens = make([]Token, maxID+1)
for k, v := range tdata.Model.Vocab {
tdata.Model.Tokens[v] = Token{ID: v, Content: k, Special: false, UserDefined: false}
}
for _, v := range tdata.AddedTokens {
v.UserDefined = true v.UserDefined = true
tdata.Model.Tokens[v.ID] = v tokens[v.ID] = v
} }
return &tdata, nil sha256sum := sha256.New()
for _, pt := range t.PreTokenizer.PreTokenziers {
switch pt.Type {
case "Split":
if pt.Pattern.Regex != "" {
sha256sum.Write([]byte(pt.Pattern.Regex))
}
}
}
switch digest := fmt.Sprintf("%x", sha256sum.Sum(nil)); digest {
case "d98f9631be1e9607a9848c26c1f9eac1aa9fc21ac6ba82a2fc0741af9780a48f":
pre = "llama-bpe"
case "03df5c5863ad70781dcfdef491ead25140f895fe8010964be0daefe27be32b02":
pre = "deepseek-llm"
case "21cde974d587f0d54dc8d56b183cc1e6239600172035c68fbd6d4b9f8da0576e":
pre = "deepseek-coder"
default:
slog.Warn("unknown pretokenizer, using default", "digest", digest)
pre = "default"
}
return pre, tokens, t.Model.Merges, nil
} }

View file

@ -480,6 +480,7 @@ var ggufKVOrder = map[string][]string{
"gemma.attention.key_length", "gemma.attention.key_length",
"gemma.attention.value_length", "gemma.attention.value_length",
"general.file_type", "general.file_type",
"tokenizer.ggml.pre",
"tokenizer.ggml.model", "tokenizer.ggml.model",
"tokenizer.ggml.tokens", "tokenizer.ggml.tokens",
"tokenizer.ggml.scores", "tokenizer.ggml.scores",