ollama/convert/tokenizer.go

107 lines
2.3 KiB
Go
Raw Normal View History

2024-05-08 16:56:18 -07:00
package convert
import (
2024-05-15 11:53:14 -07:00
"cmp"
"crypto/sha256"
2024-05-08 16:56:18 -07:00
"encoding/json"
2024-05-15 11:53:14 -07:00
"fmt"
"log/slog"
2024-05-08 16:56:18 -07:00
"os"
2024-05-15 11:53:14 -07:00
"slices"
"golang.org/x/exp/maps"
2024-05-08 16:56:18 -07:00
)
type Tokenizer struct {
Version string `json:"version"`
AddedTokens []Token `json:"added_tokens"`
Model TokenizerModel `json:"model"`
2024-05-15 11:53:14 -07:00
PreTokenizer struct {
PreTokenizers []struct {
2024-05-15 11:53:14 -07:00
Type string `json:"type"`
Pattern struct {
Regex string `json:"Regex"`
} `json:"pattern"`
} `json:"pretokenizers"`
} `json:"pre_tokenizer"`
2024-05-08 16:56:18 -07:00
}
// TokenizerModel is the "model" section of a tokenizer JSON document.
type TokenizerModel struct {
	// Type is the value of the "type" key.
	Type string `json:"type"`

	// Vocab maps each token's text to its integer ID.
	Vocab map[string]int `json:"vocab"`

	// Merges is the list decoded from the "merges" key.
	Merges []string `json:"merges"`

	// Tokens has no JSON tag, so encoding/json matches it by field name.
	Tokens []Token
}
// Token is a single vocabulary entry: its numeric ID, its text, and flags
// describing what kind of token it is.
type Token struct {
	ID      int    `json:"id"`
	Content string `json:"content"`
	Special bool   `json:"special"`

	// UserDefined has no JSON tag; parseTokens sets it to true for every
	// entry that comes from the added-tokens list.
	UserDefined bool
}
2024-05-15 11:53:14 -07:00
func (t *Token) Type() int32 {
switch {
case t.Special:
2024-05-15 14:55:57 -07:00
return tokenTypeControl
2024-05-15 11:53:14 -07:00
case t.UserDefined:
2024-05-15 14:55:57 -07:00
return tokenTypeUserDefined
2024-05-15 11:53:14 -07:00
default:
2024-05-15 14:55:57 -07:00
return tokenTypeNormal
2024-05-08 16:56:18 -07:00
}
2024-05-15 11:53:14 -07:00
}
2024-05-08 16:56:18 -07:00
2024-05-15 11:53:14 -07:00
func (t *Tokenizer) maxID() int {
return max(
slices.Max(maps.Values(t.Model.Vocab)),
slices.MaxFunc(t.AddedTokens, func(a, b Token) int {
return cmp.Compare(a.ID, b.ID)
}).ID,
)
2024-05-08 16:56:18 -07:00
}
2024-05-15 11:53:14 -07:00
func parseTokens(dirpath string) (pre string, tokens []Token, merges []string, err error) {
2024-05-08 16:56:18 -07:00
f, err := os.Open(dirpath)
if err != nil {
panic(err)
}
defer f.Close()
2024-05-15 11:53:14 -07:00
var t Tokenizer
if err := json.NewDecoder(f).Decode(&t); err != nil {
return "", nil, nil, err
2024-05-08 16:56:18 -07:00
}
2024-05-15 11:53:14 -07:00
tokens = make([]Token, t.maxID()+1)
for k, v := range t.Model.Vocab {
tokens[v] = Token{ID: v, Content: k, Special: false, UserDefined: false}
2024-05-08 16:56:18 -07:00
}
2024-05-15 11:53:14 -07:00
for _, v := range t.AddedTokens {
v.UserDefined = true
tokens[v.ID] = v
}
2024-05-08 16:56:18 -07:00
2024-05-15 11:53:14 -07:00
sha256sum := sha256.New()
for _, pt := range t.PreTokenizer.PreTokenizers {
2024-05-21 22:07:57 -07:00
if pt.Type == "Split" && pt.Pattern.Regex != "" {
sha256sum.Write([]byte(pt.Pattern.Regex))
2024-05-15 11:53:14 -07:00
}
2024-05-08 16:56:18 -07:00
}
2024-05-15 11:53:14 -07:00
switch digest := fmt.Sprintf("%x", sha256sum.Sum(nil)); digest {
case "d98f9631be1e9607a9848c26c1f9eac1aa9fc21ac6ba82a2fc0741af9780a48f":
pre = "llama-bpe"
case "03df5c5863ad70781dcfdef491ead25140f895fe8010964be0daefe27be32b02":
pre = "deepseek-llm"
case "21cde974d587f0d54dc8d56b183cc1e6239600172035c68fbd6d4b9f8da0576e":
pre = "deepseek-coder"
default:
slog.Warn("unknown pretokenizer, using default", "digest", digest)
pre = "default"
2024-05-08 16:56:18 -07:00
}
2024-05-15 11:53:14 -07:00
return pre, tokens, t.Model.Merges, nil
2024-05-08 16:56:18 -07:00
}