diff --git a/parser/parser.go b/parser/parser.go index 947848b2..edb81615 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -6,8 +6,9 @@ import ( "errors" "fmt" "io" - "log/slog" "slices" + "strconv" + "strings" ) type Command struct { @@ -15,118 +16,219 @@ type Command struct { Args string } -func (c *Command) Reset() { - c.Name = "" - c.Args = "" -} +type state int -func Parse(reader io.Reader) ([]Command, error) { - var commands []Command - var command, modelCommand Command +const ( + stateNil state = iota + stateName + stateValue + stateParameter + stateMessage + stateComment +) - scanner := bufio.NewScanner(reader) - scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), bufio.MaxScanTokenSize) - scanner.Split(scanModelfile) - for scanner.Scan() { - line := scanner.Bytes() +var errInvalidRole = errors.New("role must be one of \"system\", \"user\", or \"assistant\"") - fields := bytes.SplitN(line, []byte(" "), 2) - if len(fields) == 0 || len(fields[0]) == 0 { - continue +func Parse(r io.Reader) (cmds []Command, err error) { + var cmd Command + var curr state + var b bytes.Buffer + var role string + + br := bufio.NewReader(r) + for { + r, _, err := br.ReadRune() + if errors.Is(err, io.EOF) { + break + } else if err != nil { + return nil, err } - switch string(bytes.ToUpper(fields[0])) { - case "FROM": - command.Name = "model" - command.Args = string(bytes.TrimSpace(fields[1])) - // copy command for validation - modelCommand = command - case "ADAPTER": - command.Name = string(bytes.ToLower(fields[0])) - command.Args = string(bytes.TrimSpace(fields[1])) - case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT": - command.Name = string(bytes.ToLower(fields[0])) - command.Args = string(fields[1]) - case "PARAMETER": - fields = bytes.SplitN(fields[1], []byte(" "), 2) - if len(fields) < 2 { - return nil, fmt.Errorf("missing value for %s", fields) + next, r, err := parseRuneForState(r, curr) + if errors.Is(err, io.ErrUnexpectedEOF) { + return nil, fmt.Errorf("%w: %s", err, b.String()) + } else if err != nil { + return nil, err + } + + if next != curr { + switch curr { + case stateName, stateParameter: + switch s := strings.ToLower(b.String()); s { + case "from": + cmd.Name = "model" + case "parameter": + next = stateParameter + case "message": + next = stateMessage + fallthrough + default: + cmd.Name = s + } + case stateMessage: + if !slices.Contains([]string{"system", "user", "assistant"}, b.String()) { + return nil, errInvalidRole + } + + role = b.String() + case stateComment, stateNil: + // pass + case stateValue: + s := b.String() + + s, ok := unquote(b.String()) + if !ok || isSpace(r) { + if _, err := b.WriteRune(r); err != nil { + return nil, err + } + + continue + } + + if role != "" { + s = role + ": " + s + role = "" + } + + cmd.Args = s + cmds = append(cmds, cmd) } - command.Name = string(fields[0]) - command.Args = string(bytes.TrimSpace(fields[1])) - case "EMBED": - return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead") - case "MESSAGE": - command.Name = string(bytes.ToLower(fields[0])) - fields = bytes.SplitN(fields[1], []byte(" "), 2) - if len(fields) < 2 { - return nil, fmt.Errorf("should be in the format ") + b.Reset() + curr = next + } + + if strconv.IsPrint(r) { + if _, err := b.WriteRune(r); err != nil { + return nil, err } - if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) { - return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"") - } - command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1])) + } + } + + // flush the buffer + switch curr { + case stateComment, stateNil: + // pass; nothing to flush + case stateValue: + if _, ok := unquote(b.String()); !ok { + return nil, io.ErrUnexpectedEOF + } + + cmd.Args = b.String() + cmds = append(cmds, cmd) + default: + return nil, io.ErrUnexpectedEOF + } + + for _, cmd := range cmds { + if cmd.Name == "model" { + return cmds, nil + } + } + + return nil, errors.New("no FROM line") +} + +func parseRuneForState(r rune, cs state) (state, rune, error) { + switch cs { + case stateNil: + switch { + case r == '#': + return stateComment, 0, nil + case isSpace(r), isNewline(r): + return stateNil, 0, nil default: - if !bytes.HasPrefix(fields[0], []byte("#")) { - // log a warning for unknown commands - slog.Warn(fmt.Sprintf("Unknown command: %s", fields[0])) - } - continue + return stateName, r, nil + } + case stateName: + switch { + case isAlpha(r): + return stateName, r, nil + case isSpace(r): + return stateValue, 0, nil + default: + return stateNil, 0, errors.New("invalid") + } + case stateValue: + switch { + case isNewline(r): + return stateNil, r, nil + case isSpace(r): + return stateNil, r, nil + default: + return stateValue, r, nil + } + case stateParameter: + switch { + case isAlpha(r), isNumber(r), r == '_': + return stateParameter, r, nil + case isSpace(r): + return stateValue, 0, nil + default: + return stateNil, 0, io.ErrUnexpectedEOF + } + case stateMessage: + switch { + case isAlpha(r): + return stateMessage, r, nil + case isSpace(r): + return stateValue, 0, nil + default: + return stateNil, 0, io.ErrUnexpectedEOF + } + case stateComment: + switch { + case isNewline(r): + return stateNil, 0, nil + default: + return stateComment, 0, nil + } + default: + return stateNil, 0, errors.New("") + } +} + +func unquote(s string) (string, bool) { + if len(s) == 0 { + return "", false + } + + // TODO: single quotes + if len(s) >= 3 && s[:3] == `"""` { + if len(s) >= 6 && s[len(s)-3:] == `"""` { + return s[3 : len(s)-3], true } - commands = append(commands, command) - command.Reset() + return "", false } - if modelCommand.Args == "" { - return nil, errors.New("no FROM line for the model was specified") - } - - return commands, scanner.Err() -} - -func scanModelfile(data []byte, atEOF bool) (advance int, token []byte, err error) { - advance, token, err = scan([]byte(`"""`), []byte(`"""`), data, atEOF) - if err != nil { - return 0, nil, err - } - - if advance > 0 && token != nil { - return advance, token, nil - } - - advance, token, err = scan([]byte(`"`), []byte(`"`), data, atEOF) - if err != nil { - return 0, nil, err - } - - if advance > 0 && token != nil { - return advance, token, nil - } - - return bufio.ScanLines(data, atEOF) -} - -func scan(openBytes, closeBytes, data []byte, atEOF bool) (advance int, token []byte, err error) { - newline := bytes.IndexByte(data, '\n') - - if start := bytes.Index(data, openBytes); start >= 0 && start < newline { - end := bytes.Index(data[start+len(openBytes):], closeBytes) - if end < 0 { - if atEOF { - return 0, nil, fmt.Errorf("unterminated %s: expecting %s", openBytes, closeBytes) - } else { - return 0, nil, nil - } + if len(s) >= 1 && s[0] == '"' { + if len(s) >= 2 && s[len(s)-1] == '"' { + return s[1 : len(s)-1], true } - n := start + len(openBytes) + end + len(closeBytes) - - newData := data[:start] - newData = append(newData, data[start+len(openBytes):n-len(closeBytes)]...) - return n, newData, nil + return "", false } - return 0, nil, nil + return s, true +} + +func isAlpha(r rune) bool { + return r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' +} + +func isNumber(r rune) bool { + return r >= '0' && r <= '9' +} + +func isSpace(r rune) bool { + return r == ' ' || r == '\t' +} + +func isNewline(r rune) bool { + return r == '\r' || r == '\n' +} + +func isValidRole(role string) bool { + return role == "system" || role == "user" || role == "assistant" } diff --git a/parser/parser_test.go b/parser/parser_test.go index 25e849b5..09ed2b92 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -1,13 +1,16 @@ package parser import ( + "bytes" + "fmt" + "io" "strings" "testing" "github.com/stretchr/testify/assert" ) -func Test_Parser(t *testing.T) { +func TestParser(t *testing.T) { input := ` FROM model1 @@ -35,7 +38,7 @@ TEMPLATE template1 assert.Equal(t, expectedCommands, commands) } -func Test_Parser_NoFromLine(t *testing.T) { +func TestParserNoFromLine(t *testing.T) { input := ` PARAMETER param1 value1 @@ -48,7 +51,7 @@ PARAMETER param2 value2 assert.ErrorContains(t, err, "no FROM line") } -func Test_Parser_MissingValue(t *testing.T) { +func TestParserParametersMissingValue(t *testing.T) { input := ` FROM foo @@ -58,41 +61,292 @@ PARAMETER param1 reader := strings.NewReader(input) _, err := Parse(reader) - assert.ErrorContains(t, err, "missing value for [param1]") - + assert.ErrorIs(t, err, io.ErrUnexpectedEOF) } -func Test_Parser_Messages(t *testing.T) { - - input := ` +func TestParserMessages(t *testing.T) { + var cases = []struct { + input string + expected []Command + err error + }{ + { + ` +FROM foo +MESSAGE system You are a Parser. Always Parse things. +`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "message", Args: "system: You are a Parser. Always Parse things."}, + }, + nil, + }, + { + ` FROM foo MESSAGE system You are a Parser. Always Parse things. MESSAGE user Hey there! MESSAGE assistant Hello, I want to parse all the things! -` - - reader := strings.NewReader(input) - commands, err := Parse(reader) - assert.Nil(t, err) - - expectedCommands := []Command{ - {Name: "model", Args: "foo"}, - {Name: "message", Args: "system: You are a Parser. Always Parse things."}, - {Name: "message", Args: "user: Hey there!"}, - {Name: "message", Args: "assistant: Hello, I want to parse all the things!"}, - } - - assert.Equal(t, expectedCommands, commands) -} - -func Test_Parser_Messages_BadRole(t *testing.T) { - - input := ` +`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "message", Args: "system: You are a Parser. Always Parse things."}, + {Name: "message", Args: "user: Hey there!"}, + {Name: "message", Args: "assistant: Hello, I want to parse all the things!"}, + }, + nil, + }, + { + ` +FROM foo +MESSAGE system """ +You are a multiline Parser. Always Parse things. +""" + `, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "message", Args: "system: \nYou are a multiline Parser. Always Parse things.\n"}, + }, + nil, + }, + { + ` FROM foo MESSAGE badguy I'm a bad guy! -` +`, + nil, + errInvalidRole, + }, + { + ` +FROM foo +MESSAGE system +`, + nil, + io.ErrUnexpectedEOF, + }, + { + ` +FROM foo +MESSAGE system`, + nil, + io.ErrUnexpectedEOF, + }, + } - reader := strings.NewReader(input) - _, err := Parse(reader) - assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"") + for _, c := range cases { + t.Run("", func(t *testing.T) { + commands, err := Parse(strings.NewReader(c.input)) + assert.ErrorIs(t, err, c.err) + assert.Equal(t, c.expected, commands) + }) + } +} + +func TestParserQuoted(t *testing.T) { + var cases = []struct { + multiline string + expected []Command + err error + }{ + { + ` +FROM foo +TEMPLATE """ +This is a +multiline template. +""" + `, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "template", Args: "\nThis is a\nmultiline template.\n"}, + }, + nil, + }, + { + ` +FROM foo +TEMPLATE """ +This is a +multiline template.""" + `, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "template", Args: "\nThis is a\nmultiline template."}, + }, + nil, + }, + { + ` +FROM foo +TEMPLATE """This is a +multiline template.""" + `, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "template", Args: "This is a\nmultiline template."}, + }, + nil, + }, + { + ` +FROM foo +TEMPLATE """This is a multiline template.""" + `, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "template", Args: "This is a multiline template."}, + }, + nil, + }, + { + ` +FROM foo +TEMPLATE """This is a multiline template."" + `, + nil, + io.ErrUnexpectedEOF, + }, + { + ` +FROM foo +TEMPLATE " + `, + nil, + io.ErrUnexpectedEOF, + }, + { + ` +FROM foo +TEMPLATE """ +This is a multiline template with "quotes". +""" +`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "template", Args: "\nThis is a multiline template with \"quotes\".\n"}, + }, + nil, + }, + { + ` +FROM foo +TEMPLATE """""" +`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "template", Args: ""}, + }, + nil, + }, + { + ` +FROM foo +TEMPLATE "" +`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "template", Args: ""}, + }, + nil, + }, + { + ` +FROM foo +TEMPLATE "'" +`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "template", Args: "'"}, + }, + nil, + }, + } + + for _, c := range cases { + t.Run("", func(t *testing.T) { + commands, err := Parse(strings.NewReader(c.multiline)) + assert.ErrorIs(t, err, c.err) + assert.Equal(t, c.expected, commands) + }) + } +} + +func TestParserParameters(t *testing.T) { + var cases = []string{ + "numa true", + "num_ctx 1", + "num_batch 1", + "num_gqa 1", + "num_gpu 1", + "main_gpu 1", + "low_vram true", + "f16_kv true", + "logits_all true", + "vocab_only true", + "use_mmap true", + "use_mlock true", + "num_thread 1", + "num_keep 1", + "seed 1", + "num_predict 1", + "top_k 1", + "top_p 1.0", + "tfs_z 1.0", + "typical_p 1.0", + "repeat_last_n 1", + "temperature 1.0", + "repeat_penalty 1.0", + "presence_penalty 1.0", + "frequency_penalty 1.0", + "mirostat 1", + "mirostat_tau 1.0", + "mirostat_eta 1.0", + "penalize_newline true", + "stop foo", + } + + for _, c := range cases { + t.Run(c, func(t *testing.T) { + var b bytes.Buffer + fmt.Fprintln(&b, "FROM foo") + fmt.Fprintln(&b, "PARAMETER", c) + t.Logf("input: %s", b.String()) + _, err := Parse(&b) + assert.Nil(t, err) + }) + } +} + +func TestParserOnlyFrom(t *testing.T) { + commands, err := Parse(strings.NewReader("FROM foo")) + assert.Nil(t, err) + + expected := []Command{{Name: "model", Args: "foo"}} + assert.Equal(t, expected, commands) +} + +func TestParserComments(t *testing.T) { + var cases = []struct { + input string + expected []Command + }{ + { + ` +# comment +FROM foo + `, + []Command{ + {Name: "model", Args: "foo"}, + }, + }, + } + + for _, c := range cases { + t.Run("", func(t *testing.T) { + commands, err := Parse(strings.NewReader(c.input)) + assert.Nil(t, err) + assert.Equal(t, c.expected, commands) + }) + } } diff --git a/server/routes_test.go b/server/routes_test.go index 4f907702..6ac98367 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -238,6 +238,5 @@ func Test_Routes(t *testing.T) { if tc.Expected != nil { tc.Expected(t, resp) } - } }