This commit is contained in:
Michael Yang 2024-04-24 16:12:56 -07:00
parent 238715037d
commit abe614c705
2 changed files with 118 additions and 58 deletions

View file

@ -26,7 +26,10 @@ const (
stateComment stateComment
) )
var errInvalidRole = errors.New("role must be one of \"system\", \"user\", or \"assistant\"") var (
errMissingFrom = errors.New("no FROM line")
errInvalidRole = errors.New("role must be one of \"system\", \"user\", or \"assistant\"")
)
func Parse(r io.Reader) (cmds []Command, err error) { func Parse(r io.Reader) (cmds []Command, err error) {
var cmd Command var cmd Command
@ -123,7 +126,7 @@ func Parse(r io.Reader) (cmds []Command, err error) {
} }
} }
return nil, errors.New("no FROM line") return nil, errMissingFrom
} }
func parseRuneForState(r rune, cs state) (state, rune, error) { func parseRuneForState(r rune, cs state) (state, rune, error) {

View file

@ -11,7 +11,6 @@ import (
) )
func TestParser(t *testing.T) { func TestParser(t *testing.T) {
input := ` input := `
FROM model1 FROM model1
ADAPTER adapter1 ADAPTER adapter1
@ -38,21 +37,62 @@ TEMPLATE template1
assert.Equal(t, expectedCommands, commands) assert.Equal(t, expectedCommands, commands)
} }
func TestParserNoFromLine(t *testing.T) { func TestParserFrom(t *testing.T) {
var cases = []struct {
input string
expected []Command
err error
}{
{
"FROM foo",
[]Command{{Name: "model", Args: "foo"}},
nil,
},
{
"FROM /path/to/model",
[]Command{{Name: "model", Args: "/path/to/model"}},
nil,
},
{
"FROM /path/to/model/fp16.bin",
[]Command{{Name: "model", Args: "/path/to/model/fp16.bin"}},
nil,
},
{
"FROM llama3:latest",
[]Command{{Name: "model", Args: "llama3:latest"}},
nil,
},
{
"FROM llama3:7b-instruct-q4_K_M",
[]Command{{Name: "model", Args: "llama3:7b-instruct-q4_K_M"}},
nil,
},
{
"", nil, errMissingFrom,
},
{
"PARAMETER param1 value1",
nil,
errMissingFrom,
},
{
"PARAMETER param1 value1\nFROM foo",
[]Command{{Name: "param1", Args: "value1"}, {Name: "model", Args: "foo"}},
nil,
},
}
input := ` for _, c := range cases {
PARAMETER param1 value1 t.Run("", func(t *testing.T) {
PARAMETER param2 value2 commands, err := Parse(strings.NewReader(c.input))
` assert.ErrorIs(t, err, c.err)
assert.Equal(t, c.expected, commands)
reader := strings.NewReader(input) })
}
_, err := Parse(reader)
assert.ErrorContains(t, err, "no FROM line")
} }
func TestParserParametersMissingValue(t *testing.T) { func TestParserParametersMissingValue(t *testing.T) {
input := ` input := `
FROM foo FROM foo
PARAMETER param1 PARAMETER param1
@ -261,6 +301,17 @@ TEMPLATE "'"
}, },
nil, nil,
}, },
{
`
FROM foo
TEMPLATE """''"'""'""'"'''''""'""'"""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "template", Args: `''"'""'""'"'''''""'""'`},
},
nil,
},
} }
for _, c := range cases { for _, c := range cases {
@ -273,59 +324,65 @@ TEMPLATE "'"
} }
func TestParserParameters(t *testing.T) { func TestParserParameters(t *testing.T) {
var cases = []string{ var cases = map[string]struct {
"numa true", name, value string
"num_ctx 1", }{
"num_batch 1", "numa true": {"numa", "true"},
"num_gqa 1", "num_ctx 1": {"num_ctx", "1"},
"num_gpu 1", "num_batch 1": {"num_batch", "1"},
"main_gpu 1", "num_gqa 1": {"num_gqa", "1"},
"low_vram true", "num_gpu 1": {"num_gpu", "1"},
"f16_kv true", "main_gpu 1": {"main_gpu", "1"},
"logits_all true", "low_vram true": {"low_vram", "true"},
"vocab_only true", "f16_kv true": {"f16_kv", "true"},
"use_mmap true", "logits_all true": {"logits_all", "true"},
"use_mlock true", "vocab_only true": {"vocab_only", "true"},
"num_thread 1", "use_mmap true": {"use_mmap", "true"},
"num_keep 1", "use_mlock true": {"use_mlock", "true"},
"seed 1", "num_thread 1": {"num_thread", "1"},
"num_predict 1", "num_keep 1": {"num_keep", "1"},
"top_k 1", "seed 1": {"seed", "1"},
"top_p 1.0", "num_predict 1": {"num_predict", "1"},
"tfs_z 1.0", "top_k 1": {"top_k", "1"},
"typical_p 1.0", "top_p 1.0": {"top_p", "1.0"},
"repeat_last_n 1", "tfs_z 1.0": {"tfs_z", "1.0"},
"temperature 1.0", "typical_p 1.0": {"typical_p", "1.0"},
"repeat_penalty 1.0", "repeat_last_n 1": {"repeat_last_n", "1"},
"presence_penalty 1.0", "temperature 1.0": {"temperature", "1.0"},
"frequency_penalty 1.0", "repeat_penalty 1.0": {"repeat_penalty", "1.0"},
"mirostat 1", "presence_penalty 1.0": {"presence_penalty", "1.0"},
"mirostat_tau 1.0", "frequency_penalty 1.0": {"frequency_penalty", "1.0"},
"mirostat_eta 1.0", "mirostat 1": {"mirostat", "1"},
"penalize_newline true", "mirostat_tau 1.0": {"mirostat_tau", "1.0"},
"stop foo", "mirostat_eta 1.0": {"mirostat_eta", "1.0"},
"penalize_newline true": {"penalize_newline", "true"},
"stop ### User:": {"stop", "### User:"},
"stop ### User: ": {"stop", "### User: "},
"stop \"### User:\"": {"stop", "### User:"},
"stop \"### User: \"": {"stop", "### User: "},
"stop \"\"\"### User:\"\"\"": {"stop", "### User:"},
"stop \"\"\"### User:\n\"\"\"": {"stop", "### User:\n"},
"stop <|endoftext|>": {"stop", "<|endoftext|>"},
"stop <|eot_id|>": {"stop", "<|eot_id|>"},
"stop </s>": {"stop", "</s>"},
} }
for _, c := range cases { for k, v := range cases {
t.Run(c, func(t *testing.T) { t.Run(k, func(t *testing.T) {
var b bytes.Buffer var b bytes.Buffer
fmt.Fprintln(&b, "FROM foo") fmt.Fprintln(&b, "FROM foo")
fmt.Fprintln(&b, "PARAMETER", c) fmt.Fprintln(&b, "PARAMETER", k)
t.Logf("input: %s", b.String()) commands, err := Parse(&b)
_, err := Parse(&b)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, []Command{
{Name: "model", Args: "foo"},
{Name: v.name, Args: v.value},
}, commands)
}) })
} }
} }
func TestParserOnlyFrom(t *testing.T) {
commands, err := Parse(strings.NewReader("FROM foo"))
assert.Nil(t, err)
expected := []Command{{Name: "model", Args: "foo"}}
assert.Equal(t, expected, commands)
}
func TestParserComments(t *testing.T) { func TestParserComments(t *testing.T) {
var cases = []struct { var cases = []struct {
input string input string