bpe pretokenizer
This commit is contained in:
parent
2d315ba9a9
commit
547132e820
4 changed files with 83 additions and 59 deletions
|
@ -37,6 +37,8 @@ type Params struct {
|
||||||
Experts int `json:"num_local_experts"`
|
Experts int `json:"num_local_experts"`
|
||||||
ExpertsUsed int `json:"num_experts_per_tok"`
|
ExpertsUsed int `json:"num_experts_per_tok"`
|
||||||
|
|
||||||
|
PreTokenizer string
|
||||||
|
|
||||||
ByteOrder
|
ByteOrder
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,9 +2,9 @@ package convert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
@ -134,44 +134,27 @@ func (m *LlamaModel) GetTensors() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *LlamaModel) LoadVocab() error {
|
func (m *LlamaModel) LoadVocab() error {
|
||||||
v := &Vocab{
|
v := &Vocab{}
|
||||||
Tokens: []string{},
|
|
||||||
Types: []int32{},
|
|
||||||
Merges: []string{},
|
|
||||||
}
|
|
||||||
|
|
||||||
tokpath := filepath.Join(m.Path, "tokenizer.json")
|
tokpath := filepath.Join(m.Path, "tokenizer.json")
|
||||||
slog.Debug(fmt.Sprintf("looking for %s", tokpath))
|
pre, ts, merges, err := parseTokens(tokpath)
|
||||||
if _, err := os.Stat(tokpath); !os.IsNotExist(err) {
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
t, err := newTokenizer(tokpath)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tok := range t.Model.Tokens {
|
|
||||||
v.Tokens = append(v.Tokens, tok.Content)
|
|
||||||
var tokType int32
|
|
||||||
switch {
|
|
||||||
case tok.Special:
|
|
||||||
tokType = 3
|
|
||||||
case tok.UserDefined:
|
|
||||||
tokType = 4
|
|
||||||
default:
|
|
||||||
tokType = 1
|
|
||||||
}
|
|
||||||
v.Types = append(v.Types, tokType)
|
|
||||||
}
|
|
||||||
v.Merges = t.Model.Merges
|
|
||||||
} else {
|
|
||||||
slog.Debug("loading sentence piece vocab")
|
|
||||||
v, err = LoadSentencePieceTokens(m.Path, m.Params)
|
v, err = LoadSentencePieceTokens(m.Path, m.Params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
} else if err != nil {
|
||||||
|
return err
|
||||||
|
} else {
|
||||||
|
for _, t := range ts {
|
||||||
|
v.Tokens = append(v.Tokens, t.Content)
|
||||||
|
v.Types = append(v.Types, t.Type())
|
||||||
|
}
|
||||||
|
|
||||||
slog.Debug("vocab loaded")
|
m.Params.PreTokenizer = pre
|
||||||
|
v.Merges = merges
|
||||||
}
|
}
|
||||||
|
|
||||||
m.Vocab = v
|
m.Vocab = v
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -194,6 +177,7 @@ func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error {
|
||||||
"general.file_type": uint32(2),
|
"general.file_type": uint32(2),
|
||||||
"tokenizer.ggml.model": "gpt2",
|
"tokenizer.ggml.model": "gpt2",
|
||||||
|
|
||||||
|
"tokenizer.ggml.pre": m.Params.PreTokenizer,
|
||||||
"tokenizer.ggml.tokens": m.Vocab.Tokens,
|
"tokenizer.ggml.tokens": m.Vocab.Tokens,
|
||||||
"tokenizer.ggml.token_type": m.Vocab.Types,
|
"tokenizer.ggml.token_type": m.Vocab.Types,
|
||||||
|
|
||||||
|
|
|
@ -1,15 +1,30 @@
|
||||||
package convert
|
package convert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"cmp"
|
||||||
|
"crypto/sha256"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io/ioutil"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Tokenizer struct {
|
type Tokenizer struct {
|
||||||
Version string `json:"version"`
|
Version string `json:"version"`
|
||||||
AddedTokens []Token `json:"added_tokens"`
|
AddedTokens []Token `json:"added_tokens"`
|
||||||
Model TokenizerModel `json:"model"`
|
Model TokenizerModel `json:"model"`
|
||||||
|
|
||||||
|
PreTokenizer struct {
|
||||||
|
PreTokenziers []struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Pattern struct {
|
||||||
|
Regex string `json:"Regex"`
|
||||||
|
} `json:"pattern"`
|
||||||
|
} `json:"pretokenizers"`
|
||||||
|
} `json:"pre_tokenizer"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type TokenizerModel struct {
|
type TokenizerModel struct {
|
||||||
|
@ -26,47 +41,69 @@ type Token struct {
|
||||||
UserDefined bool
|
UserDefined bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tokenizer) getMaxID() int {
|
func (t *Token) Type() int32 {
|
||||||
var maxID int
|
switch {
|
||||||
for _, v := range t.Model.Vocab {
|
case t.Special:
|
||||||
maxID = max(maxID, v)
|
return 3
|
||||||
|
case t.UserDefined:
|
||||||
|
return 4
|
||||||
|
default:
|
||||||
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, v := range t.AddedTokens {
|
|
||||||
maxID = max(maxID, v.ID)
|
|
||||||
}
|
|
||||||
return maxID
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTokenizer(dirpath string) (*Tokenizer, error) {
|
func (t *Tokenizer) maxID() int {
|
||||||
|
return max(
|
||||||
|
slices.Max(maps.Values(t.Model.Vocab)),
|
||||||
|
slices.MaxFunc(t.AddedTokens, func(a, b Token) int {
|
||||||
|
return cmp.Compare(a.ID, b.ID)
|
||||||
|
}).ID,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseTokens(dirpath string) (pre string, tokens []Token, merges []string, err error) {
|
||||||
f, err := os.Open(dirpath)
|
f, err := os.Open(dirpath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
data, err := ioutil.ReadAll(f)
|
var t Tokenizer
|
||||||
if err != nil {
|
if err := json.NewDecoder(f).Decode(&t); err != nil {
|
||||||
return nil, err
|
return "", nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var tdata Tokenizer
|
tokens = make([]Token, t.maxID()+1)
|
||||||
|
for k, v := range t.Model.Vocab {
|
||||||
if err := json.Unmarshal(data, &tdata); err != nil {
|
tokens[v] = Token{ID: v, Content: k, Special: false, UserDefined: false}
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
maxID := tdata.getMaxID()
|
for _, v := range t.AddedTokens {
|
||||||
tdata.Model.Tokens = make([]Token, maxID+1)
|
|
||||||
|
|
||||||
for k, v := range tdata.Model.Vocab {
|
|
||||||
tdata.Model.Tokens[v] = Token{ID: v, Content: k, Special: false, UserDefined: false}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, v := range tdata.AddedTokens {
|
|
||||||
v.UserDefined = true
|
v.UserDefined = true
|
||||||
tdata.Model.Tokens[v.ID] = v
|
tokens[v.ID] = v
|
||||||
}
|
}
|
||||||
|
|
||||||
return &tdata, nil
|
sha256sum := sha256.New()
|
||||||
|
for _, pt := range t.PreTokenizer.PreTokenziers {
|
||||||
|
switch pt.Type {
|
||||||
|
case "Split":
|
||||||
|
if pt.Pattern.Regex != "" {
|
||||||
|
sha256sum.Write([]byte(pt.Pattern.Regex))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch digest := fmt.Sprintf("%x", sha256sum.Sum(nil)); digest {
|
||||||
|
case "d98f9631be1e9607a9848c26c1f9eac1aa9fc21ac6ba82a2fc0741af9780a48f":
|
||||||
|
pre = "llama-bpe"
|
||||||
|
case "03df5c5863ad70781dcfdef491ead25140f895fe8010964be0daefe27be32b02":
|
||||||
|
pre = "deepseek-llm"
|
||||||
|
case "21cde974d587f0d54dc8d56b183cc1e6239600172035c68fbd6d4b9f8da0576e":
|
||||||
|
pre = "deepseek-coder"
|
||||||
|
default:
|
||||||
|
slog.Warn("unknown pretokenizer, using default", "digest", digest)
|
||||||
|
pre = "default"
|
||||||
|
}
|
||||||
|
|
||||||
|
return pre, tokens, t.Model.Merges, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -480,6 +480,7 @@ var ggufKVOrder = map[string][]string{
|
||||||
"gemma.attention.key_length",
|
"gemma.attention.key_length",
|
||||||
"gemma.attention.value_length",
|
"gemma.attention.value_length",
|
||||||
"general.file_type",
|
"general.file_type",
|
||||||
|
"tokenizer.ggml.pre",
|
||||||
"tokenizer.ggml.model",
|
"tokenizer.ggml.model",
|
||||||
"tokenizer.ggml.tokens",
|
"tokenizer.ggml.tokens",
|
||||||
"tokenizer.ggml.scores",
|
"tokenizer.ggml.scores",
|
||||||
|
|
Loading…
Reference in a new issue