refactor modelfile parser

This commit is contained in:
Michael Yang 2024-04-22 15:37:14 -07:00
parent f0c454ab57
commit c0a00f68ae
3 changed files with 485 additions and 130 deletions

View file

@ -6,8 +6,9 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log/slog"
"slices" "slices"
"strconv"
"strings"
) )
type Command struct { type Command struct {
@ -15,118 +16,219 @@ type Command struct {
Args string Args string
} }
func (c *Command) Reset() { type state int
c.Name = ""
c.Args = "" const (
stateNil state = iota
stateName
stateValue
stateParameter
stateMessage
stateComment
)
var errInvalidRole = errors.New("role must be one of \"system\", \"user\", or \"assistant\"")
func Parse(r io.Reader) (cmds []Command, err error) {
var cmd Command
var curr state
var b bytes.Buffer
var role string
br := bufio.NewReader(r)
for {
r, _, err := br.ReadRune()
if errors.Is(err, io.EOF) {
break
} else if err != nil {
return nil, err
} }
func Parse(reader io.Reader) ([]Command, error) { next, r, err := parseRuneForState(r, curr)
var commands []Command if errors.Is(err, io.ErrUnexpectedEOF) {
var command, modelCommand Command return nil, fmt.Errorf("%w: %s", err, b.String())
} else if err != nil {
scanner := bufio.NewScanner(reader) return nil, err
scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), bufio.MaxScanTokenSize)
scanner.Split(scanModelfile)
for scanner.Scan() {
line := scanner.Bytes()
fields := bytes.SplitN(line, []byte(" "), 2)
if len(fields) == 0 || len(fields[0]) == 0 {
continue
} }
switch string(bytes.ToUpper(fields[0])) { if next != curr {
case "FROM": switch curr {
command.Name = "model" case stateName, stateParameter:
command.Args = string(bytes.TrimSpace(fields[1])) switch s := strings.ToLower(b.String()); s {
// copy command for validation case "from":
modelCommand = command cmd.Name = "model"
case "ADAPTER": case "parameter":
command.Name = string(bytes.ToLower(fields[0])) next = stateParameter
command.Args = string(bytes.TrimSpace(fields[1])) case "message":
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT": next = stateMessage
command.Name = string(bytes.ToLower(fields[0])) fallthrough
command.Args = string(fields[1])
case "PARAMETER":
fields = bytes.SplitN(fields[1], []byte(" "), 2)
if len(fields) < 2 {
return nil, fmt.Errorf("missing value for %s", fields)
}
command.Name = string(fields[0])
command.Args = string(bytes.TrimSpace(fields[1]))
case "EMBED":
return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
case "MESSAGE":
command.Name = string(bytes.ToLower(fields[0]))
fields = bytes.SplitN(fields[1], []byte(" "), 2)
if len(fields) < 2 {
return nil, fmt.Errorf("should be in the format <role> <message>")
}
if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) {
return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"")
}
command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1]))
default: default:
if !bytes.HasPrefix(fields[0], []byte("#")) { cmd.Name = s
// log a warning for unknown commands
slog.Warn(fmt.Sprintf("Unknown command: %s", fields[0]))
} }
case stateMessage:
if !slices.Contains([]string{"system", "user", "assistant"}, b.String()) {
return nil, errInvalidRole
}
role = b.String()
case stateComment, stateNil:
// pass
case stateValue:
s := b.String()
s, ok := unquote(b.String())
if !ok || isSpace(r) {
if _, err := b.WriteRune(r); err != nil {
return nil, err
}
continue continue
} }
commands = append(commands, command) if role != "" {
command.Reset() s = role + ": " + s
role = ""
} }
if modelCommand.Args == "" { cmd.Args = s
return nil, errors.New("no FROM line for the model was specified") cmds = append(cmds, cmd)
} }
return commands, scanner.Err() b.Reset()
curr = next
} }
func scanModelfile(data []byte, atEOF bool) (advance int, token []byte, err error) { if strconv.IsPrint(r) {
advance, token, err = scan([]byte(`"""`), []byte(`"""`), data, atEOF) if _, err := b.WriteRune(r); err != nil {
if err != nil { return nil, err
return 0, nil, err
} }
if advance > 0 && token != nil {
return advance, token, nil
}
advance, token, err = scan([]byte(`"`), []byte(`"`), data, atEOF)
if err != nil {
return 0, nil, err
}
if advance > 0 && token != nil {
return advance, token, nil
}
return bufio.ScanLines(data, atEOF)
}
func scan(openBytes, closeBytes, data []byte, atEOF bool) (advance int, token []byte, err error) {
newline := bytes.IndexByte(data, '\n')
if start := bytes.Index(data, openBytes); start >= 0 && start < newline {
end := bytes.Index(data[start+len(openBytes):], closeBytes)
if end < 0 {
if atEOF {
return 0, nil, fmt.Errorf("unterminated %s: expecting %s", openBytes, closeBytes)
} else {
return 0, nil, nil
} }
} }
n := start + len(openBytes) + end + len(closeBytes) // flush the buffer
switch curr {
newData := data[:start] case stateComment, stateNil:
newData = append(newData, data[start+len(openBytes):n-len(closeBytes)]...) // pass; nothing to flush
return n, newData, nil case stateValue:
if _, ok := unquote(b.String()); !ok {
return nil, io.ErrUnexpectedEOF
} }
return 0, nil, nil cmd.Args = b.String()
cmds = append(cmds, cmd)
default:
return nil, io.ErrUnexpectedEOF
}
for _, cmd := range cmds {
if cmd.Name == "model" {
return cmds, nil
}
}
return nil, errors.New("no FROM line")
}
func parseRuneForState(r rune, cs state) (state, rune, error) {
switch cs {
case stateNil:
switch {
case r == '#':
return stateComment, 0, nil
case isSpace(r), isNewline(r):
return stateNil, 0, nil
default:
return stateName, r, nil
}
case stateName:
switch {
case isAlpha(r):
return stateName, r, nil
case isSpace(r):
return stateValue, 0, nil
default:
return stateNil, 0, errors.New("invalid")
}
case stateValue:
switch {
case isNewline(r):
return stateNil, r, nil
case isSpace(r):
return stateNil, r, nil
default:
return stateValue, r, nil
}
case stateParameter:
switch {
case isAlpha(r), isNumber(r), r == '_':
return stateParameter, r, nil
case isSpace(r):
return stateValue, 0, nil
default:
return stateNil, 0, io.ErrUnexpectedEOF
}
case stateMessage:
switch {
case isAlpha(r):
return stateMessage, r, nil
case isSpace(r):
return stateValue, 0, nil
default:
return stateNil, 0, io.ErrUnexpectedEOF
}
case stateComment:
switch {
case isNewline(r):
return stateNil, 0, nil
default:
return stateComment, 0, nil
}
default:
return stateNil, 0, errors.New("")
}
}
func unquote(s string) (string, bool) {
if len(s) == 0 {
return "", false
}
// TODO: single quotes
if len(s) >= 3 && s[:3] == `"""` {
if len(s) >= 6 && s[len(s)-3:] == `"""` {
return s[3 : len(s)-3], true
}
return "", false
}
if len(s) >= 1 && s[0] == '"' {
if len(s) >= 2 && s[len(s)-1] == '"' {
return s[1 : len(s)-1], true
}
return "", false
}
return s, true
}
func isAlpha(r rune) bool {
return r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z'
}
func isNumber(r rune) bool {
return r >= '0' && r <= '9'
}
func isSpace(r rune) bool {
return r == ' ' || r == '\t'
}
func isNewline(r rune) bool {
return r == '\r' || r == '\n'
}
func isValidRole(role string) bool {
return role == "system" || role == "user" || role == "assistant"
} }

View file

@ -1,13 +1,16 @@
package parser package parser
import ( import (
"bytes"
"fmt"
"io"
"strings" "strings"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func Test_Parser(t *testing.T) { func TestParser(t *testing.T) {
input := ` input := `
FROM model1 FROM model1
@ -35,7 +38,7 @@ TEMPLATE template1
assert.Equal(t, expectedCommands, commands) assert.Equal(t, expectedCommands, commands)
} }
func Test_Parser_NoFromLine(t *testing.T) { func TestParserNoFromLine(t *testing.T) {
input := ` input := `
PARAMETER param1 value1 PARAMETER param1 value1
@ -48,7 +51,7 @@ PARAMETER param2 value2
assert.ErrorContains(t, err, "no FROM line") assert.ErrorContains(t, err, "no FROM line")
} }
func Test_Parser_MissingValue(t *testing.T) { func TestParserParametersMissingValue(t *testing.T) {
input := ` input := `
FROM foo FROM foo
@ -58,41 +61,292 @@ PARAMETER param1
reader := strings.NewReader(input) reader := strings.NewReader(input)
_, err := Parse(reader) _, err := Parse(reader)
assert.ErrorContains(t, err, "missing value for [param1]") assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
} }
func Test_Parser_Messages(t *testing.T) { func TestParserMessages(t *testing.T) {
var cases = []struct {
input := ` input string
expected []Command
err error
}{
{
`
FROM foo
MESSAGE system You are a Parser. Always Parse things.
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "message", Args: "system: You are a Parser. Always Parse things."},
},
nil,
},
{
`
FROM foo FROM foo
MESSAGE system You are a Parser. Always Parse things. MESSAGE system You are a Parser. Always Parse things.
MESSAGE user Hey there! MESSAGE user Hey there!
MESSAGE assistant Hello, I want to parse all the things! MESSAGE assistant Hello, I want to parse all the things!
` `,
[]Command{
reader := strings.NewReader(input)
commands, err := Parse(reader)
assert.Nil(t, err)
expectedCommands := []Command{
{Name: "model", Args: "foo"}, {Name: "model", Args: "foo"},
{Name: "message", Args: "system: You are a Parser. Always Parse things."}, {Name: "message", Args: "system: You are a Parser. Always Parse things."},
{Name: "message", Args: "user: Hey there!"}, {Name: "message", Args: "user: Hey there!"},
{Name: "message", Args: "assistant: Hello, I want to parse all the things!"}, {Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
} },
nil,
assert.Equal(t, expectedCommands, commands) },
} {
`
func Test_Parser_Messages_BadRole(t *testing.T) { FROM foo
MESSAGE system """
input := ` You are a multiline Parser. Always Parse things.
"""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "message", Args: "system: \nYou are a multiline Parser. Always Parse things.\n"},
},
nil,
},
{
`
FROM foo FROM foo
MESSAGE badguy I'm a bad guy! MESSAGE badguy I'm a bad guy!
`,
nil,
errInvalidRole,
},
{
` `
FROM foo
reader := strings.NewReader(input) MESSAGE system
_, err := Parse(reader) `,
assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"") nil,
io.ErrUnexpectedEOF,
},
{
`
FROM foo
MESSAGE system`,
nil,
io.ErrUnexpectedEOF,
},
}
for _, c := range cases {
t.Run("", func(t *testing.T) {
commands, err := Parse(strings.NewReader(c.input))
assert.ErrorIs(t, err, c.err)
assert.Equal(t, c.expected, commands)
})
}
}
func TestParserQuoted(t *testing.T) {
var cases = []struct {
multiline string
expected []Command
err error
}{
{
`
FROM foo
TEMPLATE """
This is a
multiline template.
"""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "template", Args: "\nThis is a\nmultiline template.\n"},
},
nil,
},
{
`
FROM foo
TEMPLATE """
This is a
multiline template."""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "template", Args: "\nThis is a\nmultiline template."},
},
nil,
},
{
`
FROM foo
TEMPLATE """This is a
multiline template."""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "template", Args: "This is a\nmultiline template."},
},
nil,
},
{
`
FROM foo
TEMPLATE """This is a multiline template."""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "template", Args: "This is a multiline template."},
},
nil,
},
{
`
FROM foo
TEMPLATE """This is a multiline template.""
`,
nil,
io.ErrUnexpectedEOF,
},
{
`
FROM foo
TEMPLATE "
`,
nil,
io.ErrUnexpectedEOF,
},
{
`
FROM foo
TEMPLATE """
This is a multiline template with "quotes".
"""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "template", Args: "\nThis is a multiline template with \"quotes\".\n"},
},
nil,
},
{
`
FROM foo
TEMPLATE """"""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "template", Args: ""},
},
nil,
},
{
`
FROM foo
TEMPLATE ""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "template", Args: ""},
},
nil,
},
{
`
FROM foo
TEMPLATE "'"
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "template", Args: "'"},
},
nil,
},
}
for _, c := range cases {
t.Run("", func(t *testing.T) {
commands, err := Parse(strings.NewReader(c.multiline))
assert.ErrorIs(t, err, c.err)
assert.Equal(t, c.expected, commands)
})
}
}
func TestParserParameters(t *testing.T) {
var cases = []string{
"numa true",
"num_ctx 1",
"num_batch 1",
"num_gqa 1",
"num_gpu 1",
"main_gpu 1",
"low_vram true",
"f16_kv true",
"logits_all true",
"vocab_only true",
"use_mmap true",
"use_mlock true",
"num_thread 1",
"num_keep 1",
"seed 1",
"num_predict 1",
"top_k 1",
"top_p 1.0",
"tfs_z 1.0",
"typical_p 1.0",
"repeat_last_n 1",
"temperature 1.0",
"repeat_penalty 1.0",
"presence_penalty 1.0",
"frequency_penalty 1.0",
"mirostat 1",
"mirostat_tau 1.0",
"mirostat_eta 1.0",
"penalize_newline true",
"stop foo",
}
for _, c := range cases {
t.Run(c, func(t *testing.T) {
var b bytes.Buffer
fmt.Fprintln(&b, "FROM foo")
fmt.Fprintln(&b, "PARAMETER", c)
t.Logf("input: %s", b.String())
_, err := Parse(&b)
assert.Nil(t, err)
})
}
}
func TestParserOnlyFrom(t *testing.T) {
commands, err := Parse(strings.NewReader("FROM foo"))
assert.Nil(t, err)
expected := []Command{{Name: "model", Args: "foo"}}
assert.Equal(t, expected, commands)
}
func TestParserComments(t *testing.T) {
var cases = []struct {
input string
expected []Command
}{
{
`
# comment
FROM foo
`,
[]Command{
{Name: "model", Args: "foo"},
},
},
}
for _, c := range cases {
t.Run("", func(t *testing.T) {
commands, err := Parse(strings.NewReader(c.input))
assert.Nil(t, err)
assert.Equal(t, c.expected, commands)
})
}
} }

View file

@ -238,6 +238,5 @@ func Test_Routes(t *testing.T) {
if tc.Expected != nil { if tc.Expected != nil {
tc.Expected(t, resp) tc.Expected(t, resp)
} }
} }
} }