Merge pull request #4059 from ollama/mxyng/parser-2

rename parser to model/file
This commit is contained in:
Michael Yang 2024-05-03 13:01:22 -07:00 committed by GitHub
commit b7a87a22b6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 151 additions and 124 deletions

View file

@ -34,7 +34,6 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth" "github.com/ollama/ollama/auth"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress" "github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server" "github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/errtypes"
@ -57,13 +56,13 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
p := progress.NewProgress(os.Stderr) p := progress.NewProgress(os.Stderr)
defer p.Stop() defer p.Stop()
modelfile, err := os.Open(filename) f, err := os.Open(filename)
if err != nil { if err != nil {
return err return err
} }
defer modelfile.Close() defer f.Close()
commands, err := parser.Parse(modelfile) modelfile, err := model.ParseFile(f)
if err != nil { if err != nil {
return err return err
} }
@ -77,10 +76,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
spinner := progress.NewSpinner(status) spinner := progress.NewSpinner(status)
p.Add(status, spinner) p.Add(status, spinner)
for i := range commands { for i := range modelfile.Commands {
switch commands[i].Name { switch modelfile.Commands[i].Name {
case "model", "adapter": case "model", "adapter":
path := commands[i].Args path := modelfile.Commands[i].Args
if path == "~" { if path == "~" {
path = home path = home
} else if strings.HasPrefix(path, "~/") { } else if strings.HasPrefix(path, "~/") {
@ -92,7 +91,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
} }
fi, err := os.Stat(path) fi, err := os.Stat(path)
if errors.Is(err, os.ErrNotExist) && commands[i].Name == "model" { if errors.Is(err, os.ErrNotExist) && modelfile.Commands[i].Name == "model" {
continue continue
} else if err != nil { } else if err != nil {
return err return err
@ -115,7 +114,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err return err
} }
commands[i].Args = "@"+digest modelfile.Commands[i].Args = "@" + digest
} }
} }
@ -145,7 +144,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
quantization, _ := cmd.Flags().GetString("quantization") quantization, _ := cmd.Flags().GetString("quantization")
request := api.CreateRequest{Name: args[0], Modelfile: parser.Format(commands), Quantization: quantization} request := api.CreateRequest{Name: args[0], Modelfile: modelfile.String(), Quantization: quantization}
if err := client.Create(cmd.Context(), &request, fn); err != nil { if err := client.Create(cmd.Context(), &request, fn); err != nil {
return err return err
} }

View file

@ -29,7 +29,6 @@ import (
"github.com/ollama/ollama/convert" "github.com/ollama/ollama/convert"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
@ -63,46 +62,74 @@ func (m *Model) IsEmbedding() bool {
return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert") return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert")
} }
func (m *Model) Commands() (cmds []parser.Command) { func (m *Model) String() string {
cmds = append(cmds, parser.Command{Name: "model", Args: m.ModelPath}) var modelfile model.File
modelfile.Commands = append(modelfile.Commands, model.Command{
Name: "model",
Args: m.ModelPath,
})
if m.Template != "" { if m.Template != "" {
cmds = append(cmds, parser.Command{Name: "template", Args: m.Template}) modelfile.Commands = append(modelfile.Commands, model.Command{
Name: "template",
Args: m.Template,
})
} }
if m.System != "" { if m.System != "" {
cmds = append(cmds, parser.Command{Name: "system", Args: m.System}) modelfile.Commands = append(modelfile.Commands, model.Command{
Name: "system",
Args: m.System,
})
} }
for _, adapter := range m.AdapterPaths { for _, adapter := range m.AdapterPaths {
cmds = append(cmds, parser.Command{Name: "adapter", Args: adapter}) modelfile.Commands = append(modelfile.Commands, model.Command{
Name: "adapter",
Args: adapter,
})
} }
for _, projector := range m.ProjectorPaths { for _, projector := range m.ProjectorPaths {
cmds = append(cmds, parser.Command{Name: "projector", Args: projector}) modelfile.Commands = append(modelfile.Commands, model.Command{
Name: "projector",
Args: projector,
})
} }
for k, v := range m.Options { for k, v := range m.Options {
switch v := v.(type) { switch v := v.(type) {
case []any: case []any:
for _, s := range v { for _, s := range v {
cmds = append(cmds, parser.Command{Name: k, Args: fmt.Sprintf("%v", s)}) modelfile.Commands = append(modelfile.Commands, model.Command{
Name: k,
Args: fmt.Sprintf("%v", s),
})
} }
default: default:
cmds = append(cmds, parser.Command{Name: k, Args: fmt.Sprintf("%v", v)}) modelfile.Commands = append(modelfile.Commands, model.Command{
Name: k,
Args: fmt.Sprintf("%v", v),
})
} }
} }
for _, license := range m.License { for _, license := range m.License {
cmds = append(cmds, parser.Command{Name: "license", Args: license}) modelfile.Commands = append(modelfile.Commands, model.Command{
Name: "license",
Args: license,
})
} }
for _, msg := range m.Messages { for _, msg := range m.Messages {
cmds = append(cmds, parser.Command{Name: "message", Args: fmt.Sprintf("%s %s", msg.Role, msg.Content)}) modelfile.Commands = append(modelfile.Commands, model.Command{
Name: "message",
Args: fmt.Sprintf("%s %s", msg.Role, msg.Content),
})
} }
return cmds return modelfile.String()
} }
type Message struct { type Message struct {
@ -329,7 +356,7 @@ func realpath(mfDir, from string) string {
return abspath return abspath
} }
func CreateModel(ctx context.Context, name, modelFileDir, quantization string, commands []parser.Command, fn func(resp api.ProgressResponse)) error { func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) error {
deleteMap := make(map[string]struct{}) deleteMap := make(map[string]struct{})
if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil { if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
for _, layer := range append(manifest.Layers, manifest.Config) { for _, layer := range append(manifest.Layers, manifest.Config) {
@ -351,7 +378,7 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c
params := make(map[string][]string) params := make(map[string][]string)
fromParams := make(map[string]any) fromParams := make(map[string]any)
for _, c := range commands { for _, c := range modelfile.Commands {
mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
switch c.Name { switch c.Name {

View file

@ -1,6 +1,7 @@
package server package server
import ( import (
"cmp"
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
@ -28,7 +29,6 @@ import (
"github.com/ollama/ollama/gpu" "github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/openai" "github.com/ollama/ollama/openai"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
@ -522,28 +522,17 @@ func (s *Server) PushModelHandler(c *gin.Context) {
func (s *Server) CreateModelHandler(c *gin.Context) { func (s *Server) CreateModelHandler(c *gin.Context) {
var req api.CreateRequest var req api.CreateRequest
err := c.ShouldBindJSON(&req) if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return return
case err != nil: } else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
var model string name := model.ParseName(cmp.Or(req.Model, req.Name))
if req.Model != "" { if !name.IsValid() {
model = req.Model c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
if err := ParseModelPath(model).Validate(); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
@ -552,19 +541,19 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
return return
} }
var modelfile io.Reader = strings.NewReader(req.Modelfile) var r io.Reader = strings.NewReader(req.Modelfile)
if req.Path != "" && req.Modelfile == "" { if req.Path != "" && req.Modelfile == "" {
mf, err := os.Open(req.Path) f, err := os.Open(req.Path)
if err != nil { if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
return return
} }
defer mf.Close() defer f.Close()
modelfile = mf r = f
} }
commands, err := parser.Parse(modelfile) modelfile, err := model.ParseFile(r)
if err != nil { if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
@ -580,7 +569,7 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context()) ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel() defer cancel()
if err := CreateModel(ctx, model, filepath.Dir(req.Path), req.Quantization, commands, fn); err != nil { if err := CreateModel(ctx, name.String(), filepath.Dir(req.Path), req.Quantization, modelfile, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()
@ -732,7 +721,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
fmt.Fprintln(&sb, "# Modelfile generate by \"ollama show\"") fmt.Fprintln(&sb, "# Modelfile generate by \"ollama show\"")
fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:") fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:")
fmt.Fprintf(&sb, "# FROM %s\n\n", model.ShortName) fmt.Fprintf(&sb, "# FROM %s\n\n", model.ShortName)
fmt.Fprint(&sb, parser.Format(model.Commands())) fmt.Fprint(&sb, model.String())
resp.Modelfile = sb.String() resp.Modelfile = sb.String()
return resp, nil return resp, nil

View file

@ -17,7 +17,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
@ -55,13 +55,13 @@ func Test_Routes(t *testing.T) {
createTestModel := func(t *testing.T, name string) { createTestModel := func(t *testing.T, name string) {
fname := createTestFile(t, "ollama-model") fname := createTestFile(t, "ollama-model")
modelfile := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname)) r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
commands, err := parser.Parse(modelfile) modelfile, err := model.ParseFile(r)
assert.Nil(t, err) assert.Nil(t, err)
fn := func(resp api.ProgressResponse) { fn := func(resp api.ProgressResponse) {
t.Logf("Status: %s", resp.Status) t.Logf("Status: %s", resp.Status)
} }
err = CreateModel(context.TODO(), name, "", "", commands, fn) err = CreateModel(context.TODO(), name, "", "", modelfile, fn)
assert.Nil(t, err) assert.Nil(t, err)
} }

View file

@ -1,4 +1,4 @@
package parser package model
import ( import (
"bufio" "bufio"
@ -10,11 +10,41 @@ import (
"strings" "strings"
) )
type File struct {
Commands []Command
}
func (f File) String() string {
var sb strings.Builder
for _, cmd := range f.Commands {
fmt.Fprintln(&sb, cmd.String())
}
return sb.String()
}
type Command struct { type Command struct {
Name string Name string
Args string Args string
} }
func (c Command) String() string {
var sb strings.Builder
switch c.Name {
case "model":
fmt.Fprintf(&sb, "FROM %s", c.Args)
case "license", "template", "system", "adapter":
fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args))
case "message":
role, message, _ := strings.Cut(c.Args, ": ")
fmt.Fprintf(&sb, "MESSAGE %s %s", role, quote(message))
default:
fmt.Fprintf(&sb, "PARAMETER %s %s", c.Name, quote(c.Args))
}
return sb.String()
}
type state int type state int
const ( const (
@ -32,38 +62,14 @@ var (
errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"") errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"")
) )
func Format(cmds []Command) string { func ParseFile(r io.Reader) (*File, error) {
var sb strings.Builder
for _, cmd := range cmds {
name := cmd.Name
args := cmd.Args
switch cmd.Name {
case "model":
name = "from"
args = cmd.Args
case "license", "template", "system", "adapter":
args = quote(args)
case "message":
role, message, _ := strings.Cut(cmd.Args, ": ")
args = role + " " + quote(message)
default:
name = "parameter"
args = cmd.Name + " " + quote(cmd.Args)
}
fmt.Fprintln(&sb, strings.ToUpper(name), args)
}
return sb.String()
}
func Parse(r io.Reader) (cmds []Command, err error) {
var cmd Command var cmd Command
var curr state var curr state
var b bytes.Buffer var b bytes.Buffer
var role string var role string
var f File
br := bufio.NewReader(r) br := bufio.NewReader(r)
for { for {
r, _, err := br.ReadRune() r, _, err := br.ReadRune()
@ -128,7 +134,7 @@ func Parse(r io.Reader) (cmds []Command, err error) {
} }
cmd.Args = s cmd.Args = s
cmds = append(cmds, cmd) f.Commands = append(f.Commands, cmd)
} }
b.Reset() b.Reset()
@ -157,14 +163,14 @@ func Parse(r io.Reader) (cmds []Command, err error) {
} }
cmd.Args = s cmd.Args = s
cmds = append(cmds, cmd) f.Commands = append(f.Commands, cmd)
default: default:
return nil, io.ErrUnexpectedEOF return nil, io.ErrUnexpectedEOF
} }
for _, cmd := range cmds { for _, cmd := range f.Commands {
if cmd.Name == "model" { if cmd.Name == "model" {
return cmds, nil return &f, nil
} }
} }

View file

@ -1,4 +1,4 @@
package parser package model
import ( import (
"bytes" "bytes"
@ -10,7 +10,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestParser(t *testing.T) { func TestParseFileFile(t *testing.T) {
input := ` input := `
FROM model1 FROM model1
ADAPTER adapter1 ADAPTER adapter1
@ -22,8 +22,8 @@ TEMPLATE template1
reader := strings.NewReader(input) reader := strings.NewReader(input)
commands, err := Parse(reader) modelfile, err := ParseFile(reader)
assert.Nil(t, err) assert.NoError(t, err)
expectedCommands := []Command{ expectedCommands := []Command{
{Name: "model", Args: "model1"}, {Name: "model", Args: "model1"},
@ -34,10 +34,10 @@ TEMPLATE template1
{Name: "template", Args: "template1"}, {Name: "template", Args: "template1"},
} }
assert.Equal(t, expectedCommands, commands) assert.Equal(t, expectedCommands, modelfile.Commands)
} }
func TestParserFrom(t *testing.T) { func TestParseFileFrom(t *testing.T) {
var cases = []struct { var cases = []struct {
input string input string
expected []Command expected []Command
@ -85,14 +85,16 @@ func TestParserFrom(t *testing.T) {
for _, c := range cases { for _, c := range cases {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
commands, err := Parse(strings.NewReader(c.input)) modelfile, err := ParseFile(strings.NewReader(c.input))
assert.ErrorIs(t, err, c.err) assert.ErrorIs(t, err, c.err)
assert.Equal(t, c.expected, commands) if modelfile != nil {
assert.Equal(t, c.expected, modelfile.Commands)
}
}) })
} }
} }
func TestParserParametersMissingValue(t *testing.T) { func TestParseFileParametersMissingValue(t *testing.T) {
input := ` input := `
FROM foo FROM foo
PARAMETER param1 PARAMETER param1
@ -100,21 +102,21 @@ PARAMETER param1
reader := strings.NewReader(input) reader := strings.NewReader(input)
_, err := Parse(reader) _, err := ParseFile(reader)
assert.ErrorIs(t, err, io.ErrUnexpectedEOF) assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
} }
func TestParserBadCommand(t *testing.T) { func TestParseFileBadCommand(t *testing.T) {
input := ` input := `
FROM foo FROM foo
BADCOMMAND param1 value1 BADCOMMAND param1 value1
` `
_, err := Parse(strings.NewReader(input)) _, err := ParseFile(strings.NewReader(input))
assert.ErrorIs(t, err, errInvalidCommand) assert.ErrorIs(t, err, errInvalidCommand)
} }
func TestParserMessages(t *testing.T) { func TestParseFileMessages(t *testing.T) {
var cases = []struct { var cases = []struct {
input string input string
expected []Command expected []Command
@ -123,34 +125,34 @@ func TestParserMessages(t *testing.T) {
{ {
` `
FROM foo FROM foo
MESSAGE system You are a Parser. Always Parse things. MESSAGE system You are a file parser. Always parse things.
`, `,
[]Command{ []Command{
{Name: "model", Args: "foo"}, {Name: "model", Args: "foo"},
{Name: "message", Args: "system: You are a Parser. Always Parse things."}, {Name: "message", Args: "system: You are a file parser. Always parse things."},
}, },
nil, nil,
}, },
{ {
` `
FROM foo FROM foo
MESSAGE system You are a Parser. Always Parse things.`, MESSAGE system You are a file parser. Always parse things.`,
[]Command{ []Command{
{Name: "model", Args: "foo"}, {Name: "model", Args: "foo"},
{Name: "message", Args: "system: You are a Parser. Always Parse things."}, {Name: "message", Args: "system: You are a file parser. Always parse things."},
}, },
nil, nil,
}, },
{ {
` `
FROM foo FROM foo
MESSAGE system You are a Parser. Always Parse things. MESSAGE system You are a file parser. Always parse things.
MESSAGE user Hey there! MESSAGE user Hey there!
MESSAGE assistant Hello, I want to parse all the things! MESSAGE assistant Hello, I want to parse all the things!
`, `,
[]Command{ []Command{
{Name: "model", Args: "foo"}, {Name: "model", Args: "foo"},
{Name: "message", Args: "system: You are a Parser. Always Parse things."}, {Name: "message", Args: "system: You are a file parser. Always parse things."},
{Name: "message", Args: "user: Hey there!"}, {Name: "message", Args: "user: Hey there!"},
{Name: "message", Args: "assistant: Hello, I want to parse all the things!"}, {Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
}, },
@ -160,12 +162,12 @@ MESSAGE assistant Hello, I want to parse all the things!
` `
FROM foo FROM foo
MESSAGE system """ MESSAGE system """
You are a multiline Parser. Always Parse things. You are a multiline file parser. Always parse things.
""" """
`, `,
[]Command{ []Command{
{Name: "model", Args: "foo"}, {Name: "model", Args: "foo"},
{Name: "message", Args: "system: \nYou are a multiline Parser. Always Parse things.\n"}, {Name: "message", Args: "system: \nYou are a multiline file parser. Always parse things.\n"},
}, },
nil, nil,
}, },
@ -196,14 +198,16 @@ MESSAGE system`,
for _, c := range cases { for _, c := range cases {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
commands, err := Parse(strings.NewReader(c.input)) modelfile, err := ParseFile(strings.NewReader(c.input))
assert.ErrorIs(t, err, c.err) assert.ErrorIs(t, err, c.err)
assert.Equal(t, c.expected, commands) if modelfile != nil {
assert.Equal(t, c.expected, modelfile.Commands)
}
}) })
} }
} }
func TestParserQuoted(t *testing.T) { func TestParseFileQuoted(t *testing.T) {
var cases = []struct { var cases = []struct {
multiline string multiline string
expected []Command expected []Command
@ -348,14 +352,16 @@ TEMPLATE """
for _, c := range cases { for _, c := range cases {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
commands, err := Parse(strings.NewReader(c.multiline)) modelfile, err := ParseFile(strings.NewReader(c.multiline))
assert.ErrorIs(t, err, c.err) assert.ErrorIs(t, err, c.err)
assert.Equal(t, c.expected, commands) if modelfile != nil {
assert.Equal(t, c.expected, modelfile.Commands)
}
}) })
} }
} }
func TestParserParameters(t *testing.T) { func TestParseFileParameters(t *testing.T) {
var cases = map[string]struct { var cases = map[string]struct {
name, value string name, value string
}{ }{
@ -404,18 +410,18 @@ func TestParserParameters(t *testing.T) {
var b bytes.Buffer var b bytes.Buffer
fmt.Fprintln(&b, "FROM foo") fmt.Fprintln(&b, "FROM foo")
fmt.Fprintln(&b, "PARAMETER", k) fmt.Fprintln(&b, "PARAMETER", k)
commands, err := Parse(&b) modelfile, err := ParseFile(&b)
assert.Nil(t, err) assert.NoError(t, err)
assert.Equal(t, []Command{ assert.Equal(t, []Command{
{Name: "model", Args: "foo"}, {Name: "model", Args: "foo"},
{Name: v.name, Args: v.value}, {Name: v.name, Args: v.value},
}, commands) }, modelfile.Commands)
}) })
} }
} }
func TestParserComments(t *testing.T) { func TestParseFileComments(t *testing.T) {
var cases = []struct { var cases = []struct {
input string input string
expected []Command expected []Command
@ -433,14 +439,14 @@ FROM foo
for _, c := range cases { for _, c := range cases {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
commands, err := Parse(strings.NewReader(c.input)) modelfile, err := ParseFile(strings.NewReader(c.input))
assert.Nil(t, err) assert.NoError(t, err)
assert.Equal(t, c.expected, commands) assert.Equal(t, c.expected, modelfile.Commands)
}) })
} }
} }
func TestParseFormatParse(t *testing.T) { func TestParseFileFormatParseFile(t *testing.T) {
var cases = []string{ var cases = []string{
` `
FROM foo FROM foo
@ -449,7 +455,7 @@ LICENSE MIT
PARAMETER param1 value1 PARAMETER param1 value1
PARAMETER param2 value2 PARAMETER param2 value2
TEMPLATE template1 TEMPLATE template1
MESSAGE system You are a Parser. Always Parse things. MESSAGE system You are a file parser. Always parse things.
MESSAGE user Hey there! MESSAGE user Hey there!
MESSAGE assistant Hello, I want to parse all the things! MESSAGE assistant Hello, I want to parse all the things!
`, `,
@ -488,13 +494,13 @@ MESSAGE assistant Hello, I want to parse all the things!
for _, c := range cases { for _, c := range cases {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
commands, err := Parse(strings.NewReader(c)) modelfile, err := ParseFile(strings.NewReader(c))
assert.NoError(t, err) assert.NoError(t, err)
commands2, err := Parse(strings.NewReader(Format(commands))) modelfile2, err := ParseFile(strings.NewReader(modelfile.String()))
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, commands, commands2) assert.Equal(t, modelfile, modelfile2)
}) })
} }