package convert

import (
	"io"
	"io/fs"
	"os"
	"path/filepath"
	"strings"
	"testing"

	"github.com/google/go-cmp/cmp"
)

// createTokenizerFS writes each reader in files to dir and returns the
// directory as an fs.FS, so tests can feed synthetic tokenizer files to
// parseTokenizer.
func createTokenizerFS(t *testing.T, dir string, files map[string]io.Reader) fs.FS {
	t.Helper()

	for k, v := range files {
		if err := func() error {
			f, err := os.Create(filepath.Join(dir, k))
			if err != nil {
				return err
			}
			defer f.Close()

			if _, err := io.Copy(f, v); err != nil {
				return err
			}

			return nil
		}(); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
	}

	return os.DirFS(dir)
}

func TestParseTokenizer(t *testing.T) {
	cases := []struct {
		name              string
		fsys              fs.FS
		specialTokenTypes []string
		want              *Tokenizer
	}{
		{
			name: "string chat template",
			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
				"tokenizer.json": strings.NewReader(`{}`),
				"tokenizer_config.json": strings.NewReader(`{
					"chat_template": "<default template>"
				}`),
			}),
			want: &Tokenizer{
				Vocabulary: &Vocabulary{Model: "gpt2"},
				Pre:        "default",
				Template:   "<default template>",
			},
		},
		{
			name: "list chat template",
			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
				"tokenizer.json": strings.NewReader(`{}`),
				"tokenizer_config.json": strings.NewReader(`{
					"chat_template": [
						{
							"name": "default",
							"template": "<default template>"
						},
						{
							"name": "tools",
							"template": "<tools template>"
						}
					]
				}`),
			}),
			want: &Tokenizer{
				Vocabulary: &Vocabulary{Model: "gpt2"},
				Pre:        "default",
				Template:   "<default template>",
			},
		},
		{
			name: "added tokens",
			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
				"tokenizer.json": strings.NewReader(`{
					"added_tokens": [
						{
							"id": 999,
							"content": "<unused999>",
							"special": false
						}
					]
				}`),
			}),
			want: &Tokenizer{
				Vocabulary: &Vocabulary{
					Model:  "gpt2",
					Tokens: []string{"<unused999>"},
					Scores: []float32{999},
					Types:  []int32{4},
				},
				Pre: "default",
			},
		},
		{
			name: "added tokens overlap vocab",
			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
				"tokenizer.json": strings.NewReader(`{
					"added_tokens": [
						{
							"id": 0,
							"content": "<pad>",
							"special": true
						}
					],
					"model": {
						"vocab": {
							"<pad>": 0
						}
					}
				}`),
			}),
			want: &Tokenizer{
				Vocabulary: &Vocabulary{
					Model:  "gpt2",
					Tokens: []string{"<pad>"},
					Scores: []float32{0},
					Types:  []int32{3},
				},
				Pre: "default",
			},
		},
		{
			name: "special token types",
			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
				"tokenizer.json": strings.NewReader(`{
					"added_tokens": [
						{
							"id": 0,
							"content": "<pad>",
							"special": true
						},
						{
							"id": 1,
							"content": "<eos>",
							"special": true
						},
						{
							"id": 2,
							"content": "<bos>",
							"special": true
						},
						{
							"id": 3,
							"content": "<unk>",
							"special": true
						}
					],
					"model": {
						"vocab": {
							"<pad>": 0,
							"<eos>": 1,
							"<bos>": 2,
							"<unk>": 3
						}
					}
				}`),
				"tokenizer_config.json": strings.NewReader(`{
					"add_bos_token": true,
					"add_eos_token": false,
					"bos_token": "<bos>",
					"eos_token": "<eos>",
					"pad_token": "<pad>",
					"unk_token": "<unk>"
				}`),
			}),
			specialTokenTypes: []string{"pad", "eos", "bos", "unk"},
			want: &Tokenizer{
				Vocabulary: &Vocabulary{
					Model:  "gpt2",
					Tokens: []string{"<pad>", "<eos>", "<bos>", "<unk>"},
					Scores: []float32{0, 1, 2, 3},
					Types:  []int32{3, 3, 3, 3},
				},
				SpecialVocabulary: []*SpecialVocabulary{
					{Type: "pad", Content: "<pad>", ID: 0, AddToken: false},
					{Type: "eos", Content: "<eos>", ID: 1, AddToken: false},
					{Type: "bos", Content: "<bos>", ID: 2, AddToken: true},
					{Type: "unk", Content: "<unk>", ID: 3, AddToken: false},
				},
				Pre: "default",
			},
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			tokenizer, err := parseTokenizer(tt.fsys, tt.specialTokenTypes)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			if diff := cmp.Diff(tt.want, tokenizer); diff != "" {
				t.Errorf("unexpected tokenizer (-want +got):\n%s", diff)
			}
		})
	}
}