Merge pull request #6522 from ollama/mxyng/detect-chat
detect chat template from configs that contain lists
This commit is contained in:
commit
9cfd2dd3e3
3 changed files with 225 additions and 5 deletions
|
@ -89,7 +89,7 @@ func TestMain(m *testing.M) {
|
||||||
os.Exit(m.Run())
|
os.Exit(m.Run())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConvertFull(t *testing.T) {
|
func TestConvertModel(t *testing.T) {
|
||||||
cases := []string{
|
cases := []string{
|
||||||
"Meta-Llama-3-8B-Instruct",
|
"Meta-Llama-3-8B-Instruct",
|
||||||
"Meta-Llama-3.1-8B-Instruct",
|
"Meta-Llama-3.1-8B-Instruct",
|
||||||
|
|
|
@ -100,8 +100,21 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
if template, ok := p["chat_template"]; ok {
|
if template, ok := p["chat_template"]; ok {
|
||||||
if err := json.Unmarshal(template, &t.Template); err != nil {
|
var s []struct {
|
||||||
return nil, err
|
Name string `json:"name"`
|
||||||
|
Template string `json:"template"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(template, &t.Template); err == nil {
|
||||||
|
// noop
|
||||||
|
} else if err := json.Unmarshal(template, &s); err == nil {
|
||||||
|
for _, e := range s {
|
||||||
|
if e.Name == "default" {
|
||||||
|
t.Template = e.Template
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("invalid chat_template: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -141,7 +154,6 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type tokenizer struct {
|
type tokenizer struct {
|
||||||
Version string `json:"version"`
|
|
||||||
AddedTokens []token `json:"added_tokens"`
|
AddedTokens []token `json:"added_tokens"`
|
||||||
Model struct {
|
Model struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
|
@ -239,7 +251,7 @@ func parseVocabulary(fsys fs.FS) (*Vocabulary, error) {
|
||||||
return pattern.Func(fsys)
|
return pattern.Func(fsys)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, errors.New("unknown tensor format")
|
return nil, errors.New("unknown tokenizer format")
|
||||||
}
|
}
|
||||||
|
|
||||||
type SpecialVocabulary struct {
|
type SpecialVocabulary struct {
|
||||||
|
|
208
convert/tokenizer_test.go
Normal file
208
convert/tokenizer_test.go
Normal file
|
@ -0,0 +1,208 @@
|
||||||
|
package convert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"io/fs"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func createTokenizerFS(t *testing.T, dir string, files map[string]io.Reader) fs.FS {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
for k, v := range files {
|
||||||
|
if err := func() error {
|
||||||
|
f, err := os.Create(filepath.Join(dir, k))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
if _, err := io.Copy(f, v); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}(); err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return os.DirFS(dir)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTokenizer(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
fsys fs.FS
|
||||||
|
specialTokenTypes []string
|
||||||
|
want *Tokenizer
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "string chat template",
|
||||||
|
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
|
||||||
|
"tokenizer.json": strings.NewReader(`{}`),
|
||||||
|
"tokenizer_config.json": strings.NewReader(`{
|
||||||
|
"chat_template": "<default template>"
|
||||||
|
}`),
|
||||||
|
}),
|
||||||
|
want: &Tokenizer{
|
||||||
|
Vocabulary: &Vocabulary{Model: "gpt2"},
|
||||||
|
Pre: "default",
|
||||||
|
Template: "<default template>",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "list chat template",
|
||||||
|
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
|
||||||
|
"tokenizer.json": strings.NewReader(`{}`),
|
||||||
|
"tokenizer_config.json": strings.NewReader(`{
|
||||||
|
"chat_template": [
|
||||||
|
{
|
||||||
|
"name": "default",
|
||||||
|
"template": "<default template>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "tools",
|
||||||
|
"template": "<tools template>"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`),
|
||||||
|
}),
|
||||||
|
want: &Tokenizer{
|
||||||
|
Vocabulary: &Vocabulary{Model: "gpt2"},
|
||||||
|
Pre: "default",
|
||||||
|
Template: "<default template>",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "added tokens",
|
||||||
|
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
|
||||||
|
"tokenizer.json": strings.NewReader(`{
|
||||||
|
"added_tokens": [
|
||||||
|
{
|
||||||
|
"id": 999,
|
||||||
|
"content": "<unused999>",
|
||||||
|
"special": false
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`),
|
||||||
|
}),
|
||||||
|
want: &Tokenizer{
|
||||||
|
Vocabulary: &Vocabulary{
|
||||||
|
Model: "gpt2",
|
||||||
|
Tokens: []string{"<unused999>"},
|
||||||
|
Scores: []float32{999},
|
||||||
|
Types: []int32{4},
|
||||||
|
},
|
||||||
|
Pre: "default",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "added tokens overlap vocab",
|
||||||
|
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
|
||||||
|
"tokenizer.json": strings.NewReader(`{
|
||||||
|
"added_tokens": [
|
||||||
|
{
|
||||||
|
"id": 0,
|
||||||
|
"content": "<pad>",
|
||||||
|
"special": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"model": {
|
||||||
|
"vocab": {
|
||||||
|
"<pad>": 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`),
|
||||||
|
}),
|
||||||
|
want: &Tokenizer{
|
||||||
|
Vocabulary: &Vocabulary{
|
||||||
|
Model: "gpt2",
|
||||||
|
Tokens: []string{"<pad>"},
|
||||||
|
Scores: []float32{0},
|
||||||
|
Types: []int32{3},
|
||||||
|
},
|
||||||
|
Pre: "default",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "special token types",
|
||||||
|
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
|
||||||
|
"tokenizer.json": strings.NewReader(`{
|
||||||
|
"added_tokens": [
|
||||||
|
{
|
||||||
|
"id": 0,
|
||||||
|
"content": "<pad>",
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"content": "<eos>",
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"content": "<bos>",
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3,
|
||||||
|
"content": "<unk>",
|
||||||
|
"special": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"model": {
|
||||||
|
"vocab": {
|
||||||
|
"<pad>": 0,
|
||||||
|
"<eos>": 1,
|
||||||
|
"<bos>": 2,
|
||||||
|
"<unk>": 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`),
|
||||||
|
"tokenizer_config.json": strings.NewReader(`{
|
||||||
|
"add_bos_token": true,
|
||||||
|
"add_eos_token": false,
|
||||||
|
"bos_token": "<bos>",
|
||||||
|
"eos_token": "<eos>",
|
||||||
|
"pad_token": "<pad>",
|
||||||
|
"unk_token": "<unk>"
|
||||||
|
}`),
|
||||||
|
}),
|
||||||
|
specialTokenTypes: []string{"pad", "eos", "bos", "unk"},
|
||||||
|
want: &Tokenizer{
|
||||||
|
Vocabulary: &Vocabulary{
|
||||||
|
Model: "gpt2",
|
||||||
|
Tokens: []string{"<pad>", "<eos>", "<bos>", "<unk>"},
|
||||||
|
Scores: []float32{0, 1, 2, 3},
|
||||||
|
Types: []int32{3, 3, 3, 3},
|
||||||
|
},
|
||||||
|
SpecialVocabulary: []*SpecialVocabulary{
|
||||||
|
{Type: "pad", Content: "<pad>", ID: 0, AddToken: false},
|
||||||
|
{Type: "eos", Content: "<eos>", ID: 1, AddToken: false},
|
||||||
|
{Type: "bos", Content: "<bos>", ID: 2, AddToken: true},
|
||||||
|
{Type: "unk", Content: "<unk>", ID: 3, AddToken: false},
|
||||||
|
},
|
||||||
|
Pre: "default",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tokenizer, err := parseTokenizer(tt.fsys, tt.specialTokenTypes)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.want, tokenizer); diff != "" {
|
||||||
|
t.Errorf("unexpected tokenizer (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue