refactor modelfile parser
This commit is contained in:
parent
f0c454ab57
commit
c0a00f68ae
3 changed files with 485 additions and 130 deletions
290
parser/parser.go
290
parser/parser.go
|
@ -6,8 +6,9 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
|
||||||
"slices"
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Command struct {
|
type Command struct {
|
||||||
|
@ -15,118 +16,219 @@ type Command struct {
|
||||||
Args string
|
Args string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Command) Reset() {
|
type state int
|
||||||
c.Name = ""
|
|
||||||
c.Args = ""
|
const (
|
||||||
|
stateNil state = iota
|
||||||
|
stateName
|
||||||
|
stateValue
|
||||||
|
stateParameter
|
||||||
|
stateMessage
|
||||||
|
stateComment
|
||||||
|
)
|
||||||
|
|
||||||
|
var errInvalidRole = errors.New("role must be one of \"system\", \"user\", or \"assistant\"")
|
||||||
|
|
||||||
|
func Parse(r io.Reader) (cmds []Command, err error) {
|
||||||
|
var cmd Command
|
||||||
|
var curr state
|
||||||
|
var b bytes.Buffer
|
||||||
|
var role string
|
||||||
|
|
||||||
|
br := bufio.NewReader(r)
|
||||||
|
for {
|
||||||
|
r, _, err := br.ReadRune()
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
break
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func Parse(reader io.Reader) ([]Command, error) {
|
next, r, err := parseRuneForState(r, curr)
|
||||||
var commands []Command
|
if errors.Is(err, io.ErrUnexpectedEOF) {
|
||||||
var command, modelCommand Command
|
return nil, fmt.Errorf("%w: %s", err, b.String())
|
||||||
|
} else if err != nil {
|
||||||
scanner := bufio.NewScanner(reader)
|
return nil, err
|
||||||
scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), bufio.MaxScanTokenSize)
|
|
||||||
scanner.Split(scanModelfile)
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := scanner.Bytes()
|
|
||||||
|
|
||||||
fields := bytes.SplitN(line, []byte(" "), 2)
|
|
||||||
if len(fields) == 0 || len(fields[0]) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch string(bytes.ToUpper(fields[0])) {
|
if next != curr {
|
||||||
case "FROM":
|
switch curr {
|
||||||
command.Name = "model"
|
case stateName, stateParameter:
|
||||||
command.Args = string(bytes.TrimSpace(fields[1]))
|
switch s := strings.ToLower(b.String()); s {
|
||||||
// copy command for validation
|
case "from":
|
||||||
modelCommand = command
|
cmd.Name = "model"
|
||||||
case "ADAPTER":
|
case "parameter":
|
||||||
command.Name = string(bytes.ToLower(fields[0]))
|
next = stateParameter
|
||||||
command.Args = string(bytes.TrimSpace(fields[1]))
|
case "message":
|
||||||
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT":
|
next = stateMessage
|
||||||
command.Name = string(bytes.ToLower(fields[0]))
|
fallthrough
|
||||||
command.Args = string(fields[1])
|
|
||||||
case "PARAMETER":
|
|
||||||
fields = bytes.SplitN(fields[1], []byte(" "), 2)
|
|
||||||
if len(fields) < 2 {
|
|
||||||
return nil, fmt.Errorf("missing value for %s", fields)
|
|
||||||
}
|
|
||||||
|
|
||||||
command.Name = string(fields[0])
|
|
||||||
command.Args = string(bytes.TrimSpace(fields[1]))
|
|
||||||
case "EMBED":
|
|
||||||
return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
|
|
||||||
case "MESSAGE":
|
|
||||||
command.Name = string(bytes.ToLower(fields[0]))
|
|
||||||
fields = bytes.SplitN(fields[1], []byte(" "), 2)
|
|
||||||
if len(fields) < 2 {
|
|
||||||
return nil, fmt.Errorf("should be in the format <role> <message>")
|
|
||||||
}
|
|
||||||
if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) {
|
|
||||||
return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"")
|
|
||||||
}
|
|
||||||
command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1]))
|
|
||||||
default:
|
default:
|
||||||
if !bytes.HasPrefix(fields[0], []byte("#")) {
|
cmd.Name = s
|
||||||
// log a warning for unknown commands
|
|
||||||
slog.Warn(fmt.Sprintf("Unknown command: %s", fields[0]))
|
|
||||||
}
|
}
|
||||||
|
case stateMessage:
|
||||||
|
if !slices.Contains([]string{"system", "user", "assistant"}, b.String()) {
|
||||||
|
return nil, errInvalidRole
|
||||||
|
}
|
||||||
|
|
||||||
|
role = b.String()
|
||||||
|
case stateComment, stateNil:
|
||||||
|
// pass
|
||||||
|
case stateValue:
|
||||||
|
s := b.String()
|
||||||
|
|
||||||
|
s, ok := unquote(b.String())
|
||||||
|
if !ok || isSpace(r) {
|
||||||
|
if _, err := b.WriteRune(r); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
commands = append(commands, command)
|
if role != "" {
|
||||||
command.Reset()
|
s = role + ": " + s
|
||||||
|
role = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
if modelCommand.Args == "" {
|
cmd.Args = s
|
||||||
return nil, errors.New("no FROM line for the model was specified")
|
cmds = append(cmds, cmd)
|
||||||
}
|
}
|
||||||
|
|
||||||
return commands, scanner.Err()
|
b.Reset()
|
||||||
|
curr = next
|
||||||
}
|
}
|
||||||
|
|
||||||
func scanModelfile(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
if strconv.IsPrint(r) {
|
||||||
advance, token, err = scan([]byte(`"""`), []byte(`"""`), data, atEOF)
|
if _, err := b.WriteRune(r); err != nil {
|
||||||
if err != nil {
|
return nil, err
|
||||||
return 0, nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if advance > 0 && token != nil {
|
|
||||||
return advance, token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
advance, token, err = scan([]byte(`"`), []byte(`"`), data, atEOF)
|
|
||||||
if err != nil {
|
|
||||||
return 0, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if advance > 0 && token != nil {
|
|
||||||
return advance, token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return bufio.ScanLines(data, atEOF)
|
|
||||||
}
|
|
||||||
|
|
||||||
func scan(openBytes, closeBytes, data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
||||||
newline := bytes.IndexByte(data, '\n')
|
|
||||||
|
|
||||||
if start := bytes.Index(data, openBytes); start >= 0 && start < newline {
|
|
||||||
end := bytes.Index(data[start+len(openBytes):], closeBytes)
|
|
||||||
if end < 0 {
|
|
||||||
if atEOF {
|
|
||||||
return 0, nil, fmt.Errorf("unterminated %s: expecting %s", openBytes, closeBytes)
|
|
||||||
} else {
|
|
||||||
return 0, nil, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
n := start + len(openBytes) + end + len(closeBytes)
|
// flush the buffer
|
||||||
|
switch curr {
|
||||||
newData := data[:start]
|
case stateComment, stateNil:
|
||||||
newData = append(newData, data[start+len(openBytes):n-len(closeBytes)]...)
|
// pass; nothing to flush
|
||||||
return n, newData, nil
|
case stateValue:
|
||||||
|
if _, ok := unquote(b.String()); !ok {
|
||||||
|
return nil, io.ErrUnexpectedEOF
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0, nil, nil
|
cmd.Args = b.String()
|
||||||
|
cmds = append(cmds, cmd)
|
||||||
|
default:
|
||||||
|
return nil, io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, cmd := range cmds {
|
||||||
|
if cmd.Name == "model" {
|
||||||
|
return cmds, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errors.New("no FROM line")
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseRuneForState(r rune, cs state) (state, rune, error) {
|
||||||
|
switch cs {
|
||||||
|
case stateNil:
|
||||||
|
switch {
|
||||||
|
case r == '#':
|
||||||
|
return stateComment, 0, nil
|
||||||
|
case isSpace(r), isNewline(r):
|
||||||
|
return stateNil, 0, nil
|
||||||
|
default:
|
||||||
|
return stateName, r, nil
|
||||||
|
}
|
||||||
|
case stateName:
|
||||||
|
switch {
|
||||||
|
case isAlpha(r):
|
||||||
|
return stateName, r, nil
|
||||||
|
case isSpace(r):
|
||||||
|
return stateValue, 0, nil
|
||||||
|
default:
|
||||||
|
return stateNil, 0, errors.New("invalid")
|
||||||
|
}
|
||||||
|
case stateValue:
|
||||||
|
switch {
|
||||||
|
case isNewline(r):
|
||||||
|
return stateNil, r, nil
|
||||||
|
case isSpace(r):
|
||||||
|
return stateNil, r, nil
|
||||||
|
default:
|
||||||
|
return stateValue, r, nil
|
||||||
|
}
|
||||||
|
case stateParameter:
|
||||||
|
switch {
|
||||||
|
case isAlpha(r), isNumber(r), r == '_':
|
||||||
|
return stateParameter, r, nil
|
||||||
|
case isSpace(r):
|
||||||
|
return stateValue, 0, nil
|
||||||
|
default:
|
||||||
|
return stateNil, 0, io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
case stateMessage:
|
||||||
|
switch {
|
||||||
|
case isAlpha(r):
|
||||||
|
return stateMessage, r, nil
|
||||||
|
case isSpace(r):
|
||||||
|
return stateValue, 0, nil
|
||||||
|
default:
|
||||||
|
return stateNil, 0, io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
case stateComment:
|
||||||
|
switch {
|
||||||
|
case isNewline(r):
|
||||||
|
return stateNil, 0, nil
|
||||||
|
default:
|
||||||
|
return stateComment, 0, nil
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return stateNil, 0, errors.New("")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func unquote(s string) (string, bool) {
|
||||||
|
if len(s) == 0 {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: single quotes
|
||||||
|
if len(s) >= 3 && s[:3] == `"""` {
|
||||||
|
if len(s) >= 6 && s[len(s)-3:] == `"""` {
|
||||||
|
return s[3 : len(s)-3], true
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(s) >= 1 && s[0] == '"' {
|
||||||
|
if len(s) >= 2 && s[len(s)-1] == '"' {
|
||||||
|
return s[1 : len(s)-1], true
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
return s, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func isAlpha(r rune) bool {
|
||||||
|
return r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z'
|
||||||
|
}
|
||||||
|
|
||||||
|
func isNumber(r rune) bool {
|
||||||
|
return r >= '0' && r <= '9'
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSpace(r rune) bool {
|
||||||
|
return r == ' ' || r == '\t'
|
||||||
|
}
|
||||||
|
|
||||||
|
func isNewline(r rune) bool {
|
||||||
|
return r == '\r' || r == '\n'
|
||||||
|
}
|
||||||
|
|
||||||
|
func isValidRole(role string) bool {
|
||||||
|
return role == "system" || role == "user" || role == "assistant"
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,13 +1,16 @@
|
||||||
package parser
|
package parser
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_Parser(t *testing.T) {
|
func TestParser(t *testing.T) {
|
||||||
|
|
||||||
input := `
|
input := `
|
||||||
FROM model1
|
FROM model1
|
||||||
|
@ -35,7 +38,7 @@ TEMPLATE template1
|
||||||
assert.Equal(t, expectedCommands, commands)
|
assert.Equal(t, expectedCommands, commands)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_Parser_NoFromLine(t *testing.T) {
|
func TestParserNoFromLine(t *testing.T) {
|
||||||
|
|
||||||
input := `
|
input := `
|
||||||
PARAMETER param1 value1
|
PARAMETER param1 value1
|
||||||
|
@ -48,7 +51,7 @@ PARAMETER param2 value2
|
||||||
assert.ErrorContains(t, err, "no FROM line")
|
assert.ErrorContains(t, err, "no FROM line")
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_Parser_MissingValue(t *testing.T) {
|
func TestParserParametersMissingValue(t *testing.T) {
|
||||||
|
|
||||||
input := `
|
input := `
|
||||||
FROM foo
|
FROM foo
|
||||||
|
@ -58,41 +61,292 @@ PARAMETER param1
|
||||||
reader := strings.NewReader(input)
|
reader := strings.NewReader(input)
|
||||||
|
|
||||||
_, err := Parse(reader)
|
_, err := Parse(reader)
|
||||||
assert.ErrorContains(t, err, "missing value for [param1]")
|
assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_Parser_Messages(t *testing.T) {
|
func TestParserMessages(t *testing.T) {
|
||||||
|
var cases = []struct {
|
||||||
input := `
|
input string
|
||||||
|
expected []Command
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
`
|
||||||
|
FROM foo
|
||||||
|
MESSAGE system You are a Parser. Always Parse things.
|
||||||
|
`,
|
||||||
|
[]Command{
|
||||||
|
{Name: "model", Args: "foo"},
|
||||||
|
{Name: "message", Args: "system: You are a Parser. Always Parse things."},
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
`
|
||||||
FROM foo
|
FROM foo
|
||||||
MESSAGE system You are a Parser. Always Parse things.
|
MESSAGE system You are a Parser. Always Parse things.
|
||||||
MESSAGE user Hey there!
|
MESSAGE user Hey there!
|
||||||
MESSAGE assistant Hello, I want to parse all the things!
|
MESSAGE assistant Hello, I want to parse all the things!
|
||||||
`
|
`,
|
||||||
|
[]Command{
|
||||||
reader := strings.NewReader(input)
|
|
||||||
commands, err := Parse(reader)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
expectedCommands := []Command{
|
|
||||||
{Name: "model", Args: "foo"},
|
{Name: "model", Args: "foo"},
|
||||||
{Name: "message", Args: "system: You are a Parser. Always Parse things."},
|
{Name: "message", Args: "system: You are a Parser. Always Parse things."},
|
||||||
{Name: "message", Args: "user: Hey there!"},
|
{Name: "message", Args: "user: Hey there!"},
|
||||||
{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
|
{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
|
||||||
}
|
},
|
||||||
|
nil,
|
||||||
assert.Equal(t, expectedCommands, commands)
|
},
|
||||||
}
|
{
|
||||||
|
`
|
||||||
func Test_Parser_Messages_BadRole(t *testing.T) {
|
FROM foo
|
||||||
|
MESSAGE system """
|
||||||
input := `
|
You are a multiline Parser. Always Parse things.
|
||||||
|
"""
|
||||||
|
`,
|
||||||
|
[]Command{
|
||||||
|
{Name: "model", Args: "foo"},
|
||||||
|
{Name: "message", Args: "system: \nYou are a multiline Parser. Always Parse things.\n"},
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
`
|
||||||
FROM foo
|
FROM foo
|
||||||
MESSAGE badguy I'm a bad guy!
|
MESSAGE badguy I'm a bad guy!
|
||||||
|
`,
|
||||||
|
nil,
|
||||||
|
errInvalidRole,
|
||||||
|
},
|
||||||
|
{
|
||||||
`
|
`
|
||||||
|
FROM foo
|
||||||
reader := strings.NewReader(input)
|
MESSAGE system
|
||||||
_, err := Parse(reader)
|
`,
|
||||||
assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"")
|
nil,
|
||||||
|
io.ErrUnexpectedEOF,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
`
|
||||||
|
FROM foo
|
||||||
|
MESSAGE system`,
|
||||||
|
nil,
|
||||||
|
io.ErrUnexpectedEOF,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run("", func(t *testing.T) {
|
||||||
|
commands, err := Parse(strings.NewReader(c.input))
|
||||||
|
assert.ErrorIs(t, err, c.err)
|
||||||
|
assert.Equal(t, c.expected, commands)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParserQuoted(t *testing.T) {
|
||||||
|
var cases = []struct {
|
||||||
|
multiline string
|
||||||
|
expected []Command
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
`
|
||||||
|
FROM foo
|
||||||
|
TEMPLATE """
|
||||||
|
This is a
|
||||||
|
multiline template.
|
||||||
|
"""
|
||||||
|
`,
|
||||||
|
[]Command{
|
||||||
|
{Name: "model", Args: "foo"},
|
||||||
|
{Name: "template", Args: "\nThis is a\nmultiline template.\n"},
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
`
|
||||||
|
FROM foo
|
||||||
|
TEMPLATE """
|
||||||
|
This is a
|
||||||
|
multiline template."""
|
||||||
|
`,
|
||||||
|
[]Command{
|
||||||
|
{Name: "model", Args: "foo"},
|
||||||
|
{Name: "template", Args: "\nThis is a\nmultiline template."},
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
`
|
||||||
|
FROM foo
|
||||||
|
TEMPLATE """This is a
|
||||||
|
multiline template."""
|
||||||
|
`,
|
||||||
|
[]Command{
|
||||||
|
{Name: "model", Args: "foo"},
|
||||||
|
{Name: "template", Args: "This is a\nmultiline template."},
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
`
|
||||||
|
FROM foo
|
||||||
|
TEMPLATE """This is a multiline template."""
|
||||||
|
`,
|
||||||
|
[]Command{
|
||||||
|
{Name: "model", Args: "foo"},
|
||||||
|
{Name: "template", Args: "This is a multiline template."},
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
`
|
||||||
|
FROM foo
|
||||||
|
TEMPLATE """This is a multiline template.""
|
||||||
|
`,
|
||||||
|
nil,
|
||||||
|
io.ErrUnexpectedEOF,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
`
|
||||||
|
FROM foo
|
||||||
|
TEMPLATE "
|
||||||
|
`,
|
||||||
|
nil,
|
||||||
|
io.ErrUnexpectedEOF,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
`
|
||||||
|
FROM foo
|
||||||
|
TEMPLATE """
|
||||||
|
This is a multiline template with "quotes".
|
||||||
|
"""
|
||||||
|
`,
|
||||||
|
[]Command{
|
||||||
|
{Name: "model", Args: "foo"},
|
||||||
|
{Name: "template", Args: "\nThis is a multiline template with \"quotes\".\n"},
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
`
|
||||||
|
FROM foo
|
||||||
|
TEMPLATE """"""
|
||||||
|
`,
|
||||||
|
[]Command{
|
||||||
|
{Name: "model", Args: "foo"},
|
||||||
|
{Name: "template", Args: ""},
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
`
|
||||||
|
FROM foo
|
||||||
|
TEMPLATE ""
|
||||||
|
`,
|
||||||
|
[]Command{
|
||||||
|
{Name: "model", Args: "foo"},
|
||||||
|
{Name: "template", Args: ""},
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
`
|
||||||
|
FROM foo
|
||||||
|
TEMPLATE "'"
|
||||||
|
`,
|
||||||
|
[]Command{
|
||||||
|
{Name: "model", Args: "foo"},
|
||||||
|
{Name: "template", Args: "'"},
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run("", func(t *testing.T) {
|
||||||
|
commands, err := Parse(strings.NewReader(c.multiline))
|
||||||
|
assert.ErrorIs(t, err, c.err)
|
||||||
|
assert.Equal(t, c.expected, commands)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParserParameters(t *testing.T) {
|
||||||
|
var cases = []string{
|
||||||
|
"numa true",
|
||||||
|
"num_ctx 1",
|
||||||
|
"num_batch 1",
|
||||||
|
"num_gqa 1",
|
||||||
|
"num_gpu 1",
|
||||||
|
"main_gpu 1",
|
||||||
|
"low_vram true",
|
||||||
|
"f16_kv true",
|
||||||
|
"logits_all true",
|
||||||
|
"vocab_only true",
|
||||||
|
"use_mmap true",
|
||||||
|
"use_mlock true",
|
||||||
|
"num_thread 1",
|
||||||
|
"num_keep 1",
|
||||||
|
"seed 1",
|
||||||
|
"num_predict 1",
|
||||||
|
"top_k 1",
|
||||||
|
"top_p 1.0",
|
||||||
|
"tfs_z 1.0",
|
||||||
|
"typical_p 1.0",
|
||||||
|
"repeat_last_n 1",
|
||||||
|
"temperature 1.0",
|
||||||
|
"repeat_penalty 1.0",
|
||||||
|
"presence_penalty 1.0",
|
||||||
|
"frequency_penalty 1.0",
|
||||||
|
"mirostat 1",
|
||||||
|
"mirostat_tau 1.0",
|
||||||
|
"mirostat_eta 1.0",
|
||||||
|
"penalize_newline true",
|
||||||
|
"stop foo",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c, func(t *testing.T) {
|
||||||
|
var b bytes.Buffer
|
||||||
|
fmt.Fprintln(&b, "FROM foo")
|
||||||
|
fmt.Fprintln(&b, "PARAMETER", c)
|
||||||
|
t.Logf("input: %s", b.String())
|
||||||
|
_, err := Parse(&b)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParserOnlyFrom(t *testing.T) {
|
||||||
|
commands, err := Parse(strings.NewReader("FROM foo"))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
expected := []Command{{Name: "model", Args: "foo"}}
|
||||||
|
assert.Equal(t, expected, commands)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParserComments(t *testing.T) {
|
||||||
|
var cases = []struct {
|
||||||
|
input string
|
||||||
|
expected []Command
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
`
|
||||||
|
# comment
|
||||||
|
FROM foo
|
||||||
|
`,
|
||||||
|
[]Command{
|
||||||
|
{Name: "model", Args: "foo"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run("", func(t *testing.T) {
|
||||||
|
commands, err := Parse(strings.NewReader(c.input))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, c.expected, commands)
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -238,6 +238,5 @@ func Test_Routes(t *testing.T) {
|
||||||
if tc.Expected != nil {
|
if tc.Expected != nil {
|
||||||
tc.Expected(t, resp)
|
tc.Expected(t, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue