refactor modelfile parser
This commit is contained in:
parent
f0c454ab57
commit
c0a00f68ae
3 changed files with 485 additions and 130 deletions
276
parser/parser.go
276
parser/parser.go
|
@ -6,8 +6,9 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Command struct {
|
||||
|
@ -15,118 +16,219 @@ type Command struct {
|
|||
Args string
|
||||
}
|
||||
|
||||
func (c *Command) Reset() {
|
||||
c.Name = ""
|
||||
c.Args = ""
|
||||
}
|
||||
type state int
|
||||
|
||||
func Parse(reader io.Reader) ([]Command, error) {
|
||||
var commands []Command
|
||||
var command, modelCommand Command
|
||||
const (
|
||||
stateNil state = iota
|
||||
stateName
|
||||
stateValue
|
||||
stateParameter
|
||||
stateMessage
|
||||
stateComment
|
||||
)
|
||||
|
||||
scanner := bufio.NewScanner(reader)
|
||||
scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), bufio.MaxScanTokenSize)
|
||||
scanner.Split(scanModelfile)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
var errInvalidRole = errors.New("role must be one of \"system\", \"user\", or \"assistant\"")
|
||||
|
||||
fields := bytes.SplitN(line, []byte(" "), 2)
|
||||
if len(fields) == 0 || len(fields[0]) == 0 {
|
||||
continue
|
||||
func Parse(r io.Reader) (cmds []Command, err error) {
|
||||
var cmd Command
|
||||
var curr state
|
||||
var b bytes.Buffer
|
||||
var role string
|
||||
|
||||
br := bufio.NewReader(r)
|
||||
for {
|
||||
r, _, err := br.ReadRune()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch string(bytes.ToUpper(fields[0])) {
|
||||
case "FROM":
|
||||
command.Name = "model"
|
||||
command.Args = string(bytes.TrimSpace(fields[1]))
|
||||
// copy command for validation
|
||||
modelCommand = command
|
||||
case "ADAPTER":
|
||||
command.Name = string(bytes.ToLower(fields[0]))
|
||||
command.Args = string(bytes.TrimSpace(fields[1]))
|
||||
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT":
|
||||
command.Name = string(bytes.ToLower(fields[0]))
|
||||
command.Args = string(fields[1])
|
||||
case "PARAMETER":
|
||||
fields = bytes.SplitN(fields[1], []byte(" "), 2)
|
||||
if len(fields) < 2 {
|
||||
return nil, fmt.Errorf("missing value for %s", fields)
|
||||
next, r, err := parseRuneForState(r, curr)
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return nil, fmt.Errorf("%w: %s", err, b.String())
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
command.Name = string(fields[0])
|
||||
command.Args = string(bytes.TrimSpace(fields[1]))
|
||||
case "EMBED":
|
||||
return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
|
||||
case "MESSAGE":
|
||||
command.Name = string(bytes.ToLower(fields[0]))
|
||||
fields = bytes.SplitN(fields[1], []byte(" "), 2)
|
||||
if len(fields) < 2 {
|
||||
return nil, fmt.Errorf("should be in the format <role> <message>")
|
||||
}
|
||||
if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) {
|
||||
return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"")
|
||||
}
|
||||
command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1]))
|
||||
if next != curr {
|
||||
switch curr {
|
||||
case stateName, stateParameter:
|
||||
switch s := strings.ToLower(b.String()); s {
|
||||
case "from":
|
||||
cmd.Name = "model"
|
||||
case "parameter":
|
||||
next = stateParameter
|
||||
case "message":
|
||||
next = stateMessage
|
||||
fallthrough
|
||||
default:
|
||||
if !bytes.HasPrefix(fields[0], []byte("#")) {
|
||||
// log a warning for unknown commands
|
||||
slog.Warn(fmt.Sprintf("Unknown command: %s", fields[0]))
|
||||
cmd.Name = s
|
||||
}
|
||||
case stateMessage:
|
||||
if !slices.Contains([]string{"system", "user", "assistant"}, b.String()) {
|
||||
return nil, errInvalidRole
|
||||
}
|
||||
|
||||
role = b.String()
|
||||
case stateComment, stateNil:
|
||||
// pass
|
||||
case stateValue:
|
||||
s := b.String()
|
||||
|
||||
s, ok := unquote(b.String())
|
||||
if !ok || isSpace(r) {
|
||||
if _, err := b.WriteRune(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
commands = append(commands, command)
|
||||
command.Reset()
|
||||
if role != "" {
|
||||
s = role + ": " + s
|
||||
role = ""
|
||||
}
|
||||
|
||||
if modelCommand.Args == "" {
|
||||
return nil, errors.New("no FROM line for the model was specified")
|
||||
cmd.Args = s
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
|
||||
return commands, scanner.Err()
|
||||
b.Reset()
|
||||
curr = next
|
||||
}
|
||||
|
||||
if strconv.IsPrint(r) {
|
||||
if _, err := b.WriteRune(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// flush the buffer
|
||||
switch curr {
|
||||
case stateComment, stateNil:
|
||||
// pass; nothing to flush
|
||||
case stateValue:
|
||||
if _, ok := unquote(b.String()); !ok {
|
||||
return nil, io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
cmd.Args = b.String()
|
||||
cmds = append(cmds, cmd)
|
||||
default:
|
||||
return nil, io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
for _, cmd := range cmds {
|
||||
if cmd.Name == "model" {
|
||||
return cmds, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("no FROM line")
|
||||
}
|
||||
|
||||
func scanModelfile(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
advance, token, err = scan([]byte(`"""`), []byte(`"""`), data, atEOF)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
func parseRuneForState(r rune, cs state) (state, rune, error) {
|
||||
switch cs {
|
||||
case stateNil:
|
||||
switch {
|
||||
case r == '#':
|
||||
return stateComment, 0, nil
|
||||
case isSpace(r), isNewline(r):
|
||||
return stateNil, 0, nil
|
||||
default:
|
||||
return stateName, r, nil
|
||||
}
|
||||
|
||||
if advance > 0 && token != nil {
|
||||
return advance, token, nil
|
||||
case stateName:
|
||||
switch {
|
||||
case isAlpha(r):
|
||||
return stateName, r, nil
|
||||
case isSpace(r):
|
||||
return stateValue, 0, nil
|
||||
default:
|
||||
return stateNil, 0, errors.New("invalid")
|
||||
}
|
||||
|
||||
advance, token, err = scan([]byte(`"`), []byte(`"`), data, atEOF)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
case stateValue:
|
||||
switch {
|
||||
case isNewline(r):
|
||||
return stateNil, r, nil
|
||||
case isSpace(r):
|
||||
return stateNil, r, nil
|
||||
default:
|
||||
return stateValue, r, nil
|
||||
}
|
||||
|
||||
if advance > 0 && token != nil {
|
||||
return advance, token, nil
|
||||
case stateParameter:
|
||||
switch {
|
||||
case isAlpha(r), isNumber(r), r == '_':
|
||||
return stateParameter, r, nil
|
||||
case isSpace(r):
|
||||
return stateValue, 0, nil
|
||||
default:
|
||||
return stateNil, 0, io.ErrUnexpectedEOF
|
||||
}
|
||||
case stateMessage:
|
||||
switch {
|
||||
case isAlpha(r):
|
||||
return stateMessage, r, nil
|
||||
case isSpace(r):
|
||||
return stateValue, 0, nil
|
||||
default:
|
||||
return stateNil, 0, io.ErrUnexpectedEOF
|
||||
}
|
||||
case stateComment:
|
||||
switch {
|
||||
case isNewline(r):
|
||||
return stateNil, 0, nil
|
||||
default:
|
||||
return stateComment, 0, nil
|
||||
}
|
||||
default:
|
||||
return stateNil, 0, errors.New("")
|
||||
}
|
||||
|
||||
return bufio.ScanLines(data, atEOF)
|
||||
}
|
||||
|
||||
func scan(openBytes, closeBytes, data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
newline := bytes.IndexByte(data, '\n')
|
||||
|
||||
if start := bytes.Index(data, openBytes); start >= 0 && start < newline {
|
||||
end := bytes.Index(data[start+len(openBytes):], closeBytes)
|
||||
if end < 0 {
|
||||
if atEOF {
|
||||
return 0, nil, fmt.Errorf("unterminated %s: expecting %s", openBytes, closeBytes)
|
||||
} else {
|
||||
return 0, nil, nil
|
||||
}
|
||||
func unquote(s string) (string, bool) {
|
||||
if len(s) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
n := start + len(openBytes) + end + len(closeBytes)
|
||||
|
||||
newData := data[:start]
|
||||
newData = append(newData, data[start+len(openBytes):n-len(closeBytes)]...)
|
||||
return n, newData, nil
|
||||
// TODO: single quotes
|
||||
if len(s) >= 3 && s[:3] == `"""` {
|
||||
if len(s) >= 6 && s[len(s)-3:] == `"""` {
|
||||
return s[3 : len(s)-3], true
|
||||
}
|
||||
|
||||
return 0, nil, nil
|
||||
return "", false
|
||||
}
|
||||
|
||||
if len(s) >= 1 && s[0] == '"' {
|
||||
if len(s) >= 2 && s[len(s)-1] == '"' {
|
||||
return s[1 : len(s)-1], true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
return s, true
|
||||
}
|
||||
|
||||
func isAlpha(r rune) bool {
|
||||
return r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z'
|
||||
}
|
||||
|
||||
func isNumber(r rune) bool {
|
||||
return r >= '0' && r <= '9'
|
||||
}
|
||||
|
||||
func isSpace(r rune) bool {
|
||||
return r == ' ' || r == '\t'
|
||||
}
|
||||
|
||||
func isNewline(r rune) bool {
|
||||
return r == '\r' || r == '\n'
|
||||
}
|
||||
|
||||
func isValidRole(role string) bool {
|
||||
return role == "system" || role == "user" || role == "assistant"
|
||||
}
|
||||
|
|
|
@ -1,13 +1,16 @@
|
|||
package parser
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_Parser(t *testing.T) {
|
||||
func TestParser(t *testing.T) {
|
||||
|
||||
input := `
|
||||
FROM model1
|
||||
|
@ -35,7 +38,7 @@ TEMPLATE template1
|
|||
assert.Equal(t, expectedCommands, commands)
|
||||
}
|
||||
|
||||
func Test_Parser_NoFromLine(t *testing.T) {
|
||||
func TestParserNoFromLine(t *testing.T) {
|
||||
|
||||
input := `
|
||||
PARAMETER param1 value1
|
||||
|
@ -48,7 +51,7 @@ PARAMETER param2 value2
|
|||
assert.ErrorContains(t, err, "no FROM line")
|
||||
}
|
||||
|
||||
func Test_Parser_MissingValue(t *testing.T) {
|
||||
func TestParserParametersMissingValue(t *testing.T) {
|
||||
|
||||
input := `
|
||||
FROM foo
|
||||
|
@ -58,41 +61,292 @@ PARAMETER param1
|
|||
reader := strings.NewReader(input)
|
||||
|
||||
_, err := Parse(reader)
|
||||
assert.ErrorContains(t, err, "missing value for [param1]")
|
||||
|
||||
assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
|
||||
}
|
||||
|
||||
func Test_Parser_Messages(t *testing.T) {
|
||||
|
||||
input := `
|
||||
func TestParserMessages(t *testing.T) {
|
||||
var cases = []struct {
|
||||
input string
|
||||
expected []Command
|
||||
err error
|
||||
}{
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
MESSAGE system You are a Parser. Always Parse things.
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "message", Args: "system: You are a Parser. Always Parse things."},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
MESSAGE system You are a Parser. Always Parse things.
|
||||
MESSAGE user Hey there!
|
||||
MESSAGE assistant Hello, I want to parse all the things!
|
||||
`
|
||||
|
||||
reader := strings.NewReader(input)
|
||||
commands, err := Parse(reader)
|
||||
assert.Nil(t, err)
|
||||
|
||||
expectedCommands := []Command{
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "message", Args: "system: You are a Parser. Always Parse things."},
|
||||
{Name: "message", Args: "user: Hey there!"},
|
||||
{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
|
||||
}
|
||||
|
||||
assert.Equal(t, expectedCommands, commands)
|
||||
}
|
||||
|
||||
func Test_Parser_Messages_BadRole(t *testing.T) {
|
||||
|
||||
input := `
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
MESSAGE system """
|
||||
You are a multiline Parser. Always Parse things.
|
||||
"""
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "message", Args: "system: \nYou are a multiline Parser. Always Parse things.\n"},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
MESSAGE badguy I'm a bad guy!
|
||||
`
|
||||
`,
|
||||
nil,
|
||||
errInvalidRole,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
MESSAGE system
|
||||
`,
|
||||
nil,
|
||||
io.ErrUnexpectedEOF,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
MESSAGE system`,
|
||||
nil,
|
||||
io.ErrUnexpectedEOF,
|
||||
},
|
||||
}
|
||||
|
||||
reader := strings.NewReader(input)
|
||||
_, err := Parse(reader)
|
||||
assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"")
|
||||
for _, c := range cases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
commands, err := Parse(strings.NewReader(c.input))
|
||||
assert.ErrorIs(t, err, c.err)
|
||||
assert.Equal(t, c.expected, commands)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParserQuoted(t *testing.T) {
|
||||
var cases = []struct {
|
||||
multiline string
|
||||
expected []Command
|
||||
err error
|
||||
}{
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
TEMPLATE """
|
||||
This is a
|
||||
multiline template.
|
||||
"""
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "template", Args: "\nThis is a\nmultiline template.\n"},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
TEMPLATE """
|
||||
This is a
|
||||
multiline template."""
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "template", Args: "\nThis is a\nmultiline template."},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
TEMPLATE """This is a
|
||||
multiline template."""
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "template", Args: "This is a\nmultiline template."},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
TEMPLATE """This is a multiline template."""
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "template", Args: "This is a multiline template."},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
TEMPLATE """This is a multiline template.""
|
||||
`,
|
||||
nil,
|
||||
io.ErrUnexpectedEOF,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
TEMPLATE "
|
||||
`,
|
||||
nil,
|
||||
io.ErrUnexpectedEOF,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
TEMPLATE """
|
||||
This is a multiline template with "quotes".
|
||||
"""
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "template", Args: "\nThis is a multiline template with \"quotes\".\n"},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
TEMPLATE """"""
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "template", Args: ""},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
TEMPLATE ""
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "template", Args: ""},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
TEMPLATE "'"
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "template", Args: "'"},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
commands, err := Parse(strings.NewReader(c.multiline))
|
||||
assert.ErrorIs(t, err, c.err)
|
||||
assert.Equal(t, c.expected, commands)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParserParameters(t *testing.T) {
|
||||
var cases = []string{
|
||||
"numa true",
|
||||
"num_ctx 1",
|
||||
"num_batch 1",
|
||||
"num_gqa 1",
|
||||
"num_gpu 1",
|
||||
"main_gpu 1",
|
||||
"low_vram true",
|
||||
"f16_kv true",
|
||||
"logits_all true",
|
||||
"vocab_only true",
|
||||
"use_mmap true",
|
||||
"use_mlock true",
|
||||
"num_thread 1",
|
||||
"num_keep 1",
|
||||
"seed 1",
|
||||
"num_predict 1",
|
||||
"top_k 1",
|
||||
"top_p 1.0",
|
||||
"tfs_z 1.0",
|
||||
"typical_p 1.0",
|
||||
"repeat_last_n 1",
|
||||
"temperature 1.0",
|
||||
"repeat_penalty 1.0",
|
||||
"presence_penalty 1.0",
|
||||
"frequency_penalty 1.0",
|
||||
"mirostat 1",
|
||||
"mirostat_tau 1.0",
|
||||
"mirostat_eta 1.0",
|
||||
"penalize_newline true",
|
||||
"stop foo",
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c, func(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
fmt.Fprintln(&b, "FROM foo")
|
||||
fmt.Fprintln(&b, "PARAMETER", c)
|
||||
t.Logf("input: %s", b.String())
|
||||
_, err := Parse(&b)
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParserOnlyFrom(t *testing.T) {
|
||||
commands, err := Parse(strings.NewReader("FROM foo"))
|
||||
assert.Nil(t, err)
|
||||
|
||||
expected := []Command{{Name: "model", Args: "foo"}}
|
||||
assert.Equal(t, expected, commands)
|
||||
}
|
||||
|
||||
func TestParserComments(t *testing.T) {
|
||||
var cases = []struct {
|
||||
input string
|
||||
expected []Command
|
||||
}{
|
||||
{
|
||||
`
|
||||
# comment
|
||||
FROM foo
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
commands, err := Parse(strings.NewReader(c.input))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, c.expected, commands)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -238,6 +238,5 @@ func Test_Routes(t *testing.T) {
|
|||
if tc.Expected != nil {
|
||||
tc.Expected(t, resp)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue