2024-08-26 16:36:50 -07:00
|
|
|
package convert
|
|
|
|
|
|
|
|
import (
|
|
|
|
"io"
|
|
|
|
"io/fs"
|
|
|
|
"os"
|
|
|
|
"path/filepath"
|
|
|
|
"strings"
|
|
|
|
"testing"
|
|
|
|
|
|
|
|
"github.com/google/go-cmp/cmp"
|
|
|
|
)
|
|
|
|
|
|
|
|
// createTokenizerFS writes every entry of files into dir, using the map key
// as the file name (joined onto dir) and streaming the reader as the file's
// contents, then returns dir as an fs.FS.
//
// Any failure to create, write, or close a file aborts the calling test via
// t.Fatalf.
func createTokenizerFS(t *testing.T, dir string, files map[string]io.Reader) fs.FS {
	t.Helper()

	for k, v := range files {
		// The closure scopes the deferred Close to a single file, so each
		// handle is released before the next file is created.
		if err := func() (err error) {
			f, err := os.Create(filepath.Join(dir, k))
			if err != nil {
				return err
			}
			defer func() {
				// A failed Close on a file opened for writing can mean the
				// data was not persisted, so report it unless an earlier
				// error is already being returned.
				if cerr := f.Close(); err == nil {
					err = cerr
				}
			}()

			_, err = io.Copy(f, v)
			return err
		}(); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
	}

	return os.DirFS(dir)
}
|
|
|
|
|
|
|
|
func TestParseTokenizer(t *testing.T) {
|
|
|
|
cases := []struct {
|
|
|
|
name string
|
|
|
|
fsys fs.FS
|
|
|
|
specialTokenTypes []string
|
|
|
|
want *Tokenizer
|
|
|
|
}{
|
|
|
|
{
|
|
|
|
name: "string chat template",
|
|
|
|
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
|
|
|
|
"tokenizer.json": strings.NewReader(`{}`),
|
|
|
|
"tokenizer_config.json": strings.NewReader(`{
|
|
|
|
"chat_template": "<default template>"
|
|
|
|
}`),
|
|
|
|
}),
|
|
|
|
want: &Tokenizer{
|
|
|
|
Vocabulary: &Vocabulary{Model: "gpt2"},
|
|
|
|
Pre: "default",
|
|
|
|
Template: "<default template>",
|
|
|
|
},
|
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "list chat template",
|
|
|
|
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
|
|
|
|
"tokenizer.json": strings.NewReader(`{}`),
|
|
|
|
"tokenizer_config.json": strings.NewReader(`{
|
|
|
|
"chat_template": [
|
|
|
|
{
|
|
|
|
"name": "default",
|
|
|
|
"template": "<default template>"
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"name": "tools",
|
|
|
|
"template": "<tools template>"
|
|
|
|
}
|
|
|
|
]
|
|
|
|
}`),
|
|
|
|
}),
|
|
|
|
want: &Tokenizer{
|
|
|
|
Vocabulary: &Vocabulary{Model: "gpt2"},
|
|
|
|
Pre: "default",
|
|
|
|
Template: "<default template>",
|
|
|
|
},
|
|
|
|
},
|
2024-08-27 11:11:53 -07:00
|
|
|
{
|
|
|
|
name: "added tokens",
|
|
|
|
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
|
|
|
|
"tokenizer.json": strings.NewReader(`{
|
|
|
|
"added_tokens": [
|
|
|
|
{
|
|
|
|
"id": 999,
|
|
|
|
"content": "<unused999>",
|
|
|
|
"special": false
|
|
|
|
}
|
|
|
|
]
|
|
|
|
}`),
|
|
|
|
}),
|
|
|
|
want: &Tokenizer{
|
|
|
|
Vocabulary: &Vocabulary{
|
|
|
|
Model: "gpt2",
|
|
|
|
Tokens: []string{"<unused999>"},
|
|
|
|
Scores: []float32{999},
|
|
|
|
Types: []int32{4},
|
|
|
|
},
|
|
|
|
Pre: "default",
|
|
|
|
},
|
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "added tokens overlap vocab",
|
|
|
|
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
|
|
|
|
"tokenizer.json": strings.NewReader(`{
|
|
|
|
"added_tokens": [
|
|
|
|
{
|
|
|
|
"id": 0,
|
|
|
|
"content": "<pad>",
|
|
|
|
"special": true
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"model": {
|
|
|
|
"vocab": {
|
|
|
|
"<pad>": 0
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}`),
|
|
|
|
}),
|
|
|
|
want: &Tokenizer{
|
|
|
|
Vocabulary: &Vocabulary{
|
|
|
|
Model: "gpt2",
|
|
|
|
Tokens: []string{"<pad>"},
|
|
|
|
Scores: []float32{0},
|
|
|
|
Types: []int32{3},
|
|
|
|
},
|
|
|
|
Pre: "default",
|
|
|
|
},
|
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "special token types",
|
|
|
|
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
|
|
|
|
"tokenizer.json": strings.NewReader(`{
|
|
|
|
"added_tokens": [
|
|
|
|
{
|
|
|
|
"id": 0,
|
|
|
|
"content": "<pad>",
|
|
|
|
"special": true
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"id": 1,
|
|
|
|
"content": "<eos>",
|
|
|
|
"special": true
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"id": 2,
|
|
|
|
"content": "<bos>",
|
|
|
|
"special": true
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"id": 3,
|
|
|
|
"content": "<unk>",
|
|
|
|
"special": true
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"model": {
|
|
|
|
"vocab": {
|
|
|
|
"<pad>": 0,
|
|
|
|
"<eos>": 1,
|
|
|
|
"<bos>": 2,
|
|
|
|
"<unk>": 3
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}`),
|
|
|
|
"tokenizer_config.json": strings.NewReader(`{
|
|
|
|
"add_bos_token": true,
|
|
|
|
"add_eos_token": false,
|
|
|
|
"bos_token": "<bos>",
|
|
|
|
"eos_token": "<eos>",
|
|
|
|
"pad_token": "<pad>",
|
|
|
|
"unk_token": "<unk>"
|
|
|
|
}`),
|
|
|
|
}),
|
|
|
|
specialTokenTypes: []string{"pad", "eos", "bos", "unk"},
|
|
|
|
want: &Tokenizer{
|
|
|
|
Vocabulary: &Vocabulary{
|
|
|
|
Model: "gpt2",
|
|
|
|
Tokens: []string{"<pad>", "<eos>", "<bos>", "<unk>"},
|
|
|
|
Scores: []float32{0, 1, 2, 3},
|
|
|
|
Types: []int32{3, 3, 3, 3},
|
|
|
|
},
|
|
|
|
SpecialVocabulary: []*SpecialVocabulary{
|
|
|
|
{Type: "pad", Content: "<pad>", ID: 0, AddToken: false},
|
|
|
|
{Type: "eos", Content: "<eos>", ID: 1, AddToken: false},
|
|
|
|
{Type: "bos", Content: "<bos>", ID: 2, AddToken: true},
|
|
|
|
{Type: "unk", Content: "<unk>", ID: 3, AddToken: false},
|
|
|
|
},
|
|
|
|
Pre: "default",
|
|
|
|
},
|
|
|
|
},
|
2024-08-26 16:36:50 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
for _, tt := range cases {
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
tokenizer, err := parseTokenizer(tt.fsys, tt.specialTokenTypes)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("unexpected error: %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if diff := cmp.Diff(tt.want, tokenizer); diff != "" {
|
|
|
|
t.Errorf("unexpected tokenizer (-want +got):\n%s", diff)
|
|
|
|
}
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|