diff --git a/convert/tokenizer.go b/convert/tokenizer.go
index 653df6d2..429d36e7 100644
--- a/convert/tokenizer.go
+++ b/convert/tokenizer.go
@@ -100,8 +100,21 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
 	}
 
 	if template, ok := p["chat_template"]; ok {
-		if err := json.Unmarshal(template, &t.Template); err != nil {
-			return nil, err
+		var s []struct {
+			Name     string `json:"name"`
+			Template string `json:"template"`
+		}
+		if err := json.Unmarshal(template, &t.Template); err == nil {
+			// noop
+		} else if err := json.Unmarshal(template, &s); err == nil {
+			for _, e := range s {
+				if e.Name == "default" {
+					t.Template = e.Template
+					break
+				}
+			}
+		} else {
+			return nil, fmt.Errorf("invalid chat_template: %w", err)
 		}
 	}
 
diff --git a/convert/tokenizer_test.go b/convert/tokenizer_test.go
new file mode 100644
index 00000000..ed0175a4
--- /dev/null
+++ b/convert/tokenizer_test.go
@@ -0,0 +1,96 @@
+package convert
+
+import (
+	"io"
+	"io/fs"
+	"os"
+	"path/filepath"
+	"strings"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+)
+
+func createTokenizerFS(t *testing.T, dir string, files map[string]io.Reader) fs.FS {
+	t.Helper()
+
+	for k, v := range files {
+		if err := func() error {
+			f, err := os.Create(filepath.Join(dir, k))
+			if err != nil {
+				return err
+			}
+			defer f.Close()
+
+			if _, err := io.Copy(f, v); err != nil {
+				return err
+			}
+
+			return nil
+		}(); err != nil {
+			t.Fatalf("unexpected error: %v", err)
+		}
+	}
+
+	return os.DirFS(dir)
+}
+
+func TestParseTokenizer(t *testing.T) {
+	cases := []struct {
+		name              string
+		fsys              fs.FS
+		specialTokenTypes []string
+		want              *Tokenizer
+	}{
+		{
+			name: "string chat template",
+			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
+				"tokenizer.json": strings.NewReader(`{}`),
+				"tokenizer_config.json": strings.NewReader(`{
+					"chat_template": ""
+				}`),
+			}),
+			want: &Tokenizer{
+				Vocabulary: &Vocabulary{Model: "gpt2"},
+				Pre:        "default",
+				Template:   "",
+			},
+		},
+		{
+			name: "list chat template",
+			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
+				"tokenizer.json": strings.NewReader(`{}`),
+				"tokenizer_config.json": strings.NewReader(`{
+					"chat_template": [
+						{
+							"name": "default",
+							"template": ""
+						},
+						{
+							"name": "tools",
+							"template": ""
+						}
+					]
+				}`),
+			}),
+			want: &Tokenizer{
+				Vocabulary: &Vocabulary{Model: "gpt2"},
+				Pre:        "default",
+				Template:   "",
+			},
+		},
+	}
+
+	for _, tt := range cases {
+		t.Run(tt.name, func(t *testing.T) {
+			tokenizer, err := parseTokenizer(tt.fsys, tt.specialTokenTypes)
+			if err != nil {
+				t.Fatalf("unexpected error: %v", err)
+			}
+
+			if diff := cmp.Diff(tt.want, tokenizer); diff != "" {
+				t.Errorf("unexpected tokenizer (-want +got):\n%s", diff)
+			}
+		})
+	}
+}