diff --git a/parser/parser.go b/parser/parser.go index 1b80ebec..a8133d78 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -26,7 +26,10 @@ const ( stateComment ) -var errInvalidRole = errors.New("role must be one of \"system\", \"user\", or \"assistant\"") +var ( + errMissingFrom = errors.New("no FROM line") + errInvalidRole = errors.New("role must be one of \"system\", \"user\", or \"assistant\"") +) func Parse(r io.Reader) (cmds []Command, err error) { var cmd Command @@ -123,7 +126,7 @@ func Parse(r io.Reader) (cmds []Command, err error) { } } - return nil, errors.New("no FROM line") + return nil, errMissingFrom } func parseRuneForState(r rune, cs state) (state, rune, error) { diff --git a/parser/parser_test.go b/parser/parser_test.go index 09ed2b92..94b4e8ad 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -11,7 +11,6 @@ import ( ) func TestParser(t *testing.T) { - input := ` FROM model1 ADAPTER adapter1 @@ -38,21 +37,62 @@ TEMPLATE template1 assert.Equal(t, expectedCommands, commands) } -func TestParserNoFromLine(t *testing.T) { +func TestParserFrom(t *testing.T) { + var cases = []struct { + input string + expected []Command + err error + }{ + { + "FROM foo", + []Command{{Name: "model", Args: "foo"}}, + nil, + }, + { + "FROM /path/to/model", + []Command{{Name: "model", Args: "/path/to/model"}}, + nil, + }, + { + "FROM /path/to/model/fp16.bin", + []Command{{Name: "model", Args: "/path/to/model/fp16.bin"}}, + nil, + }, + { + "FROM llama3:latest", + []Command{{Name: "model", Args: "llama3:latest"}}, + nil, + }, + { + "FROM llama3:7b-instruct-q4_K_M", + []Command{{Name: "model", Args: "llama3:7b-instruct-q4_K_M"}}, + nil, + }, + { + "", nil, errMissingFrom, + }, + { + "PARAMETER param1 value1", + nil, + errMissingFrom, + }, + { + "PARAMETER param1 value1\nFROM foo", + []Command{{Name: "param1", Args: "value1"}, {Name: "model", Args: "foo"}}, + nil, + }, + } - input := ` -PARAMETER param1 value1 -PARAMETER param2 value2 -` - - reader := strings.NewReader(input) - - _, err := Parse(reader) - assert.ErrorContains(t, err, "no FROM line") + for _, c := range cases { + t.Run("", func(t *testing.T) { + commands, err := Parse(strings.NewReader(c.input)) + assert.ErrorIs(t, err, c.err) + assert.Equal(t, c.expected, commands) + }) + } } func TestParserParametersMissingValue(t *testing.T) { - input := ` FROM foo PARAMETER param1 @@ -261,6 +301,17 @@ TEMPLATE "'" }, nil, }, + { + ` +FROM foo +TEMPLATE """''"'""'""'"'''''""'""'""" +`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "template", Args: `''"'""'""'"'''''""'""'`}, + }, + nil, + }, } for _, c := range cases { @@ -273,59 +324,65 @@ TEMPLATE "'" } func TestParserParameters(t *testing.T) { - var cases = []string{ - "numa true", - "num_ctx 1", - "num_batch 1", - "num_gqa 1", - "num_gpu 1", - "main_gpu 1", - "low_vram true", - "f16_kv true", - "logits_all true", - "vocab_only true", - "use_mmap true", - "use_mlock true", - "num_thread 1", - "num_keep 1", - "seed 1", - "num_predict 1", - "top_k 1", - "top_p 1.0", - "tfs_z 1.0", - "typical_p 1.0", - "repeat_last_n 1", - "temperature 1.0", - "repeat_penalty 1.0", - "presence_penalty 1.0", - "frequency_penalty 1.0", - "mirostat 1", - "mirostat_tau 1.0", - "mirostat_eta 1.0", - "penalize_newline true", - "stop foo", + var cases = map[string]struct { + name, value string + }{ + "numa true": {"numa", "true"}, + "num_ctx 1": {"num_ctx", "1"}, + "num_batch 1": {"num_batch", "1"}, + "num_gqa 1": {"num_gqa", "1"}, + "num_gpu 1": {"num_gpu", "1"}, + "main_gpu 1": {"main_gpu", "1"}, + "low_vram true": {"low_vram", "true"}, + "f16_kv true": {"f16_kv", "true"}, + "logits_all true": {"logits_all", "true"}, + "vocab_only true": {"vocab_only", "true"}, + "use_mmap true": {"use_mmap", "true"}, + "use_mlock true": {"use_mlock", "true"}, + "num_thread 1": {"num_thread", "1"}, + "num_keep 1": {"num_keep", "1"}, + "seed 1": {"seed", "1"}, + "num_predict 1": {"num_predict", "1"}, + "top_k 1": {"top_k", "1"}, + "top_p 1.0": {"top_p", "1.0"}, + "tfs_z 1.0": {"tfs_z", "1.0"}, + "typical_p 1.0": {"typical_p", "1.0"}, + "repeat_last_n 1": {"repeat_last_n", "1"}, + "temperature 1.0": {"temperature", "1.0"}, + "repeat_penalty 1.0": {"repeat_penalty", "1.0"}, + "presence_penalty 1.0": {"presence_penalty", "1.0"}, + "frequency_penalty 1.0": {"frequency_penalty", "1.0"}, + "mirostat 1": {"mirostat", "1"}, + "mirostat_tau 1.0": {"mirostat_tau", "1.0"}, + "mirostat_eta 1.0": {"mirostat_eta", "1.0"}, + "penalize_newline true": {"penalize_newline", "true"}, + "stop ### User:": {"stop", "### User:"}, + "stop ### User: ": {"stop", "### User: "}, + "stop \"### User:\"": {"stop", "### User:"}, + "stop \"### User: \"": {"stop", "### User: "}, + "stop \"\"\"### User:\"\"\"": {"stop", "### User:"}, + "stop \"\"\"### User:\n\"\"\"": {"stop", "### User:\n"}, + "stop <|endoftext|>": {"stop", "<|endoftext|>"}, + "stop <|eot_id|>": {"stop", "<|eot_id|>"}, + "stop ": {"stop", ""}, } - for _, c := range cases { - t.Run(c, func(t *testing.T) { + for k, v := range cases { + t.Run(k, func(t *testing.T) { var b bytes.Buffer fmt.Fprintln(&b, "FROM foo") - fmt.Fprintln(&b, "PARAMETER", c) - t.Logf("input: %s", b.String()) - _, err := Parse(&b) + fmt.Fprintln(&b, "PARAMETER", k) + commands, err := Parse(&b) assert.Nil(t, err) + + assert.Equal(t, []Command{ + {Name: "model", Args: "foo"}, + {Name: v.name, Args: v.value}, + }, commands) }) } } -func TestParserOnlyFrom(t *testing.T) { - commands, err := Parse(strings.NewReader("FROM foo")) - assert.Nil(t, err) - - expected := []Command{{Name: "model", Args: "foo"}} - assert.Equal(t, expected, commands) -} - func TestParserComments(t *testing.T) { var cases = []struct { input string