ollama/convert/tokenizer.go

267 lines
5.5 KiB
Go
Raw Normal View History

2024-05-08 16:56:18 -07:00
package convert
import (
2024-05-15 11:53:14 -07:00
"crypto/sha256"
2024-05-31 20:00:49 -07:00
"encoding/hex"
2024-05-08 16:56:18 -07:00
"encoding/json"
2024-05-31 20:00:49 -07:00
"errors"
2024-05-15 11:53:14 -07:00
"fmt"
2024-06-29 16:53:59 -07:00
"io/fs"
2024-05-15 11:53:14 -07:00
"log/slog"
2024-05-08 16:56:18 -07:00
"os"
2024-05-15 11:53:14 -07:00
"slices"
2024-06-06 08:59:04 -07:00
"golang.org/x/exp/maps"
2024-05-31 20:00:49 -07:00
)
2024-05-15 11:53:14 -07:00
2024-05-31 20:00:49 -07:00
// Token type identifiers recorded for every vocabulary entry. Numbering
// starts at 1; the value 0 is not assigned to any token type.
const (
	tokenTypeNormal int32 = iota + 1
	tokenTypeUnknown
	tokenTypeControl
	tokenTypeUserDefined
	tokenTypeUnused
	tokenTypeByte
)
type Tokenizer struct {
2024-05-31 20:00:49 -07:00
*Vocabulary
SpecialVocabulary []*SpecialVocabulary
Merges []string
Pre string
Template string
}
2024-06-29 16:53:59 -07:00
func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error) {
v, err := parseVocabulary(fsys)
2024-05-31 20:00:49 -07:00
if err != nil {
return nil, err
}
t := &Tokenizer{
Vocabulary: v,
Pre: "default",
}
addedTokens := make(map[string]token)
2024-06-29 16:53:59 -07:00
if f, err := fsys.Open("tokenizer.json"); errors.Is(err, os.ErrNotExist) {
2024-05-31 20:00:49 -07:00
} else if err != nil {
return nil, err
} else {
defer f.Close()
var tt tokenizer
if err := json.NewDecoder(f).Decode(&tt); err != nil {
return nil, err
}
for _, t := range tt.AddedTokens {
addedTokens[t.Content] = t
}
t.Merges = tt.Model.Merges
sha256sum := sha256.New()
for _, pt := range tt.PreTokenizer.PreTokenizers {
switch pt.Type {
case "Split":
if pt.Pattern.Regex != "" {
2024-07-08 16:59:48 -07:00
// create a checksum of all Split pretokenizers which should be sufficient
// to identify the pretokenizer
2024-05-31 20:00:49 -07:00
sha256sum.Write([]byte(pt.Pattern.Regex))
}
}
}
switch digest := hex.EncodeToString(sha256sum.Sum(nil)); digest {
case "d98f9631be1e9607a9848c26c1f9eac1aa9fc21ac6ba82a2fc0741af9780a48f":
t.Pre = "llama-bpe"
case "03df5c5863ad70781dcfdef491ead25140f895fe8010964be0daefe27be32b02":
t.Pre = "deepseek-llm"
case "21cde974d587f0d54dc8d56b183cc1e6239600172035c68fbd6d4b9f8da0576e":
t.Pre = "deepseek-coder"
case "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855":
// noop, empty pretokenizer
default:
slog.Warn("unknown pretokenizer, using default", "digest", digest)
}
}
2024-06-29 16:53:59 -07:00
if f, err := fsys.Open("tokenizer_config.json"); errors.Is(err, os.ErrNotExist) {
2024-05-31 20:00:49 -07:00
} else if err != nil {
return nil, err
} else {
defer f.Close()
var p map[string]json.RawMessage
if err := json.NewDecoder(f).Decode(&p); err != nil {
return nil, err
}
if template, ok := p["chat_template"]; ok {
if err := json.Unmarshal(template, &t.Template); err != nil {
return nil, err
}
}
2024-07-08 16:59:48 -07:00
for _, st := range specialTokenTypes {
2024-05-31 20:00:49 -07:00
sv := SpecialVocabulary{Type: st}
if bts, ok := p[fmt.Sprintf("add_%s_token", st)]; ok {
if err := json.Unmarshal(bts, &sv.AddToken); err != nil {
return nil, err
}
}
if bts, ok := p[fmt.Sprintf("%s_token", st)]; ok {
var content string
if err := json.Unmarshal(bts, &content); err != nil {
var mm map[string]any
if err := json.Unmarshal(bts, &mm); err != nil {
continue
}
content, ok = mm["content"].(string)
if !ok {
continue
}
}
sv.Content = content
}
if id, ok := addedTokens[sv.Content]; ok {
sv.ID = id.ID
t.SpecialVocabulary = append(t.SpecialVocabulary, &sv)
}
}
}
return t, nil
}
type tokenizer struct {
Version string `json:"version"`
AddedTokens []token `json:"added_tokens"`
Model struct {
Type string `json:"type"`
Vocab map[string]int `json:"vocab"`
Merges []string `json:"merges"`
} `json:"model"`
2024-05-15 11:53:14 -07:00
PreTokenizer struct {
PreTokenizers []struct {
2024-05-15 11:53:14 -07:00
Type string `json:"type"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
} `json:"pretokenizers"`
} `json:"pre_tokenizer"`
2024-05-08 16:56:18 -07:00
}
2024-05-31 20:00:49 -07:00
// token is a single vocabulary entry as decoded from tokenizer.json.
type token struct {
	ID      int    `json:"id"`
	Content string `json:"content"`
	Special bool   `json:"special"`

	// UserDefined has no JSON tag; parseVocabularyFromTokenizer sets it on
	// every entry that comes from the added tokens list.
	UserDefined bool
}
2024-05-31 20:00:49 -07:00
// Vocabulary is a token list together with per-token metadata.
//
// parseVocabularyFromTokenizer appends exactly one element to Tokens, Scores
// and Types for each token, so index i refers to the same token in all three.
type Vocabulary struct {
	// Model names the tokenizer model, e.g. "gpt2".
	Model string

	// Tokens holds the token contents.
	Tokens []string

	// Scores holds one score per token.
	Scores []float32

	// Types holds one tokenType* value per token.
	Types []int32
}
2024-05-08 16:56:18 -07:00
2024-06-29 16:53:59 -07:00
func parseVocabularyFromTokenizer(fsys fs.FS) (*Vocabulary, error) {
f, err := fsys.Open("tokenizer.json")
2024-05-08 16:56:18 -07:00
if err != nil {
2024-05-31 20:00:49 -07:00
return nil, err
2024-05-08 16:56:18 -07:00
}
defer f.Close()
2024-05-31 20:00:49 -07:00
var t tokenizer
2024-05-15 11:53:14 -07:00
if err := json.NewDecoder(f).Decode(&t); err != nil {
2024-05-31 20:00:49 -07:00
return nil, err
2024-05-08 16:56:18 -07:00
}
2024-06-06 08:59:04 -07:00
tokens := make(map[int]token, len(t.Model.Vocab))
2024-05-15 11:53:14 -07:00
for k, v := range t.Model.Vocab {
2024-06-06 08:59:04 -07:00
tokens[v] = token{
2024-05-31 20:00:49 -07:00
ID: v,
Content: k,
2024-06-06 08:59:04 -07:00
}
2024-05-08 16:56:18 -07:00
}
2024-06-06 08:59:04 -07:00
for _, token := range t.AddedTokens {
token.UserDefined = true
tokens[token.ID] = token
2024-05-15 11:53:14 -07:00
}
2024-05-08 16:56:18 -07:00
2024-06-06 08:59:04 -07:00
keys := maps.Keys(tokens)
slices.Sort(keys)
2024-05-31 20:00:49 -07:00
v := Vocabulary{Model: "gpt2"}
2024-06-06 08:59:04 -07:00
for _, k := range keys {
token := tokens[k]
v.Tokens = append(v.Tokens, token.Content)
v.Scores = append(v.Scores, float32(token.ID))
2024-05-31 20:00:49 -07:00
switch {
2024-06-06 08:59:04 -07:00
case token.Special:
2024-05-31 20:00:49 -07:00
v.Types = append(v.Types, tokenTypeControl)
2024-06-06 08:59:04 -07:00
case token.UserDefined:
2024-05-31 20:00:49 -07:00
v.Types = append(v.Types, tokenTypeUserDefined)
default:
v.Types = append(v.Types, tokenTypeNormal)
2024-05-15 11:53:14 -07:00
}
2024-05-08 16:56:18 -07:00
}
2024-05-31 20:00:49 -07:00
return &v, nil
}
2024-06-29 16:53:59 -07:00
func parseVocabulary(fsys fs.FS) (*Vocabulary, error) {
2024-07-31 15:39:11 -07:00
patterns := []struct {
Pattern string
Func func(fs.FS) (*Vocabulary, error)
}{
{"tokenizer.model", parseSentencePiece},
{"tokenizer.json", parseVocabularyFromTokenizer},
2024-05-31 20:00:49 -07:00
}
2024-07-31 15:39:11 -07:00
for _, pattern := range patterns {
if _, err := fs.Stat(fsys, pattern.Pattern); errors.Is(err, os.ErrNotExist) {
2024-07-08 16:59:48 -07:00
continue
} else if err != nil {
2024-05-31 20:00:49 -07:00
return nil, err
}
2024-07-31 15:39:11 -07:00
return pattern.Func(fsys)
2024-05-31 20:00:49 -07:00
}
return nil, errors.New("unknown tensor format")
}
// SpecialVocabulary describes one special token of a tokenizer.
type SpecialVocabulary struct {
	// Type is the short name of the special token kind, e.g. "bos" or "pad".
	Type string

	// ID is the token's vocabulary ID.
	ID int

	// Content is the token's text.
	Content string

	// AddToken is the value of the "add_<type>_token" setting.
	AddToken bool
}

// Key returns the long-form name for the special token type: "unk" becomes
// "unknown", "sep" becomes "seperator", "pad" becomes "padding", and "bos",
// "eos", "cls" and "mask" are returned as they are.
//
// Key panics for any other Type.
func (sv SpecialVocabulary) Key() string {
	switch sv.Type {
	case "unk":
		return "unknown"
	case "pad":
		return "padding"
	case "sep":
		//nolint:misspell // this is an upstream typo
		return "seperator"
	case "bos", "eos", "cls", "mask":
		return sv.Type
	default:
		panic("unknown special vocabulary type")
	}
}