From 119589fcb3b9568d3536d9011ca607286bf66e81 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 30 Apr 2024 10:55:19 -0700 Subject: [PATCH 1/2] rename parser to model/file --- cmd/cmd.go | 19 ++--- server/images.go | 57 +++++++++---- server/routes.go | 37 +++----- server/routes_test.go | 8 +- parser/parser.go => types/model/file.go | 74 +++++++++------- .../model/file_test.go | 84 ++++++++++--------- 6 files changed, 155 insertions(+), 124 deletions(-) rename parser/parser.go => types/model/file.go (86%) rename parser/parser_test.go => types/model/file_test.go (80%) diff --git a/cmd/cmd.go b/cmd/cmd.go index fa3172ca..19198c02 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -34,7 +34,6 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/auth" "github.com/ollama/ollama/format" - "github.com/ollama/ollama/parser" "github.com/ollama/ollama/progress" "github.com/ollama/ollama/server" "github.com/ollama/ollama/types/errtypes" @@ -57,13 +56,13 @@ func CreateHandler(cmd *cobra.Command, args []string) error { p := progress.NewProgress(os.Stderr) defer p.Stop() - modelfile, err := os.Open(filename) + f, err := os.Open(filename) if err != nil { return err } - defer modelfile.Close() + defer f.Close() - commands, err := parser.Parse(modelfile) + modelfile, err := model.ParseFile(f) if err != nil { return err } @@ -77,10 +76,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error { spinner := progress.NewSpinner(status) p.Add(status, spinner) - for i := range commands { - switch commands[i].Name { + for i := range modelfile.Commands { + switch modelfile.Commands[i].Name { case "model", "adapter": - path := commands[i].Args + path := modelfile.Commands[i].Args if path == "~" { path = home } else if strings.HasPrefix(path, "~/") { @@ -92,7 +91,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { } fi, err := os.Stat(path) - if errors.Is(err, os.ErrNotExist) && commands[i].Name == "model" { + if errors.Is(err, os.ErrNotExist) && modelfile.Commands[i].Name == "model" { continue } else if err != nil { return err @@ -115,7 +114,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { return err } - commands[i].Args = "@"+digest + modelfile.Commands[i].Args = "@" + digest } } @@ -145,7 +144,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { quantization, _ := cmd.Flags().GetString("quantization") - request := api.CreateRequest{Name: args[0], Modelfile: parser.Format(commands), Quantization: quantization} + request := api.CreateRequest{Name: args[0], Modelfile: modelfile.String(), Quantization: quantization} if err := client.Create(cmd.Context(), &request, fn); err != nil { return err } diff --git a/server/images.go b/server/images.go index 68840c1a..75a41d4a 100644 --- a/server/images.go +++ b/server/images.go @@ -29,7 +29,6 @@ import ( "github.com/ollama/ollama/convert" "github.com/ollama/ollama/format" "github.com/ollama/ollama/llm" - "github.com/ollama/ollama/parser" "github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" @@ -63,46 +62,74 @@ func (m *Model) IsEmbedding() bool { return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert") } -func (m *Model) Commands() (cmds []parser.Command) { - cmds = append(cmds, parser.Command{Name: "model", Args: m.ModelPath}) +func (m *Model) String() string { + var modelfile model.File + + modelfile.Commands = append(modelfile.Commands, model.Command{ + Name: "model", + Args: m.ModelPath, + }) if m.Template != "" { - cmds = append(cmds, parser.Command{Name: "template", Args: m.Template}) + modelfile.Commands = append(modelfile.Commands, model.Command{ + Name: "template", + Args: m.Template, + }) } if m.System != "" { - cmds = append(cmds, parser.Command{Name: "system", Args: m.System}) + modelfile.Commands = append(modelfile.Commands, model.Command{ + Name: "system", + Args: m.System, + }) } for _, adapter := range m.AdapterPaths { - cmds = append(cmds, parser.Command{Name: "adapter", Args: adapter}) + modelfile.Commands = append(modelfile.Commands, model.Command{ + Name: "adapter", + Args: adapter, + }) } for _, projector := range m.ProjectorPaths { - cmds = append(cmds, parser.Command{Name: "projector", Args: projector}) + modelfile.Commands = append(modelfile.Commands, model.Command{ + Name: "projector", + Args: projector, + }) } for k, v := range m.Options { switch v := v.(type) { case []any: for _, s := range v { - cmds = append(cmds, parser.Command{Name: k, Args: fmt.Sprintf("%v", s)}) + modelfile.Commands = append(modelfile.Commands, model.Command{ + Name: k, + Args: fmt.Sprintf("%v", s), + }) } default: - cmds = append(cmds, parser.Command{Name: k, Args: fmt.Sprintf("%v", v)}) + modelfile.Commands = append(modelfile.Commands, model.Command{ + Name: k, + Args: fmt.Sprintf("%v", v), + }) } } for _, license := range m.License { - cmds = append(cmds, parser.Command{Name: "license", Args: license}) + modelfile.Commands = append(modelfile.Commands, model.Command{ + Name: "license", + Args: license, + }) } for _, msg := range m.Messages { - cmds = append(cmds, parser.Command{Name: "message", Args: fmt.Sprintf("%s %s", msg.Role, msg.Content)}) + modelfile.Commands = append(modelfile.Commands, model.Command{ + Name: "message", + Args: fmt.Sprintf("%s %s", msg.Role, msg.Content), + }) } - return cmds - + return modelfile.String() } type Message struct { @@ -329,7 +356,7 @@ func realpath(mfDir, from string) string { return abspath } -func CreateModel(ctx context.Context, name, modelFileDir, quantization string, commands []parser.Command, fn func(resp api.ProgressResponse)) error { +func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) error { deleteMap := make(map[string]struct{}) if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil { for _, layer := range append(manifest.Layers, manifest.Config) { @@ -351,7 +378,7 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c params := make(map[string][]string) fromParams := make(map[string]any) - for _, c := range commands { + for _, c := range modelfile.Commands { mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) switch c.Name { diff --git a/server/routes.go b/server/routes.go index 35b20f56..f4b08272 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1,6 +1,7 @@ package server import ( + "cmp" "context" "encoding/json" "errors" @@ -28,7 +29,6 @@ import ( "github.com/ollama/ollama/gpu" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/openai" - "github.com/ollama/ollama/parser" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" ) @@ -522,28 +522,17 @@ func (s *Server) PushModelHandler(c *gin.Context) { func (s *Server) CreateModelHandler(c *gin.Context) { var req api.CreateRequest - err := c.ShouldBindJSON(&req) - switch { - case errors.Is(err, io.EOF): + if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) return - case err != nil: + } else if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - var model string - if req.Model != "" { - model = req.Model - } else if req.Name != "" { - model = req.Name - } else { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"}) - return - } - - if err := ParseModelPath(model).Validate(); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + name := model.ParseName(cmp.Or(req.Model, req.Name)) + if !name.IsValid() { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"}) return } @@ -552,19 +541,19 @@ func (s *Server) CreateModelHandler(c *gin.Context) { return } - var modelfile io.Reader = strings.NewReader(req.Modelfile) + var r io.Reader = strings.NewReader(req.Modelfile) if req.Path != "" && req.Modelfile == "" { - mf, err := os.Open(req.Path) + f, err := os.Open(req.Path) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)}) return } - defer mf.Close() + defer f.Close() - modelfile = mf + r = f } - commands, err := parser.Parse(modelfile) + modelfile, err := model.ParseFile(r) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return @@ -580,7 +569,7 @@ func (s *Server) CreateModelHandler(c *gin.Context) { ctx, cancel := context.WithCancel(c.Request.Context()) defer cancel() - if err := CreateModel(ctx, model, filepath.Dir(req.Path), req.Quantization, commands, fn); err != nil { + if err := CreateModel(ctx, name.String(), filepath.Dir(req.Path), req.Quantization, modelfile, fn); err != nil { ch <- gin.H{"error": err.Error()} } }() @@ -732,7 +721,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { fmt.Fprintln(&sb, "# Modelfile generate by \"ollama show\"") fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:") fmt.Fprintf(&sb, "# FROM %s\n\n", model.ShortName) - fmt.Fprint(&sb, parser.Format(model.Commands())) + fmt.Fprint(&sb, model.String()) resp.Modelfile = sb.String() return resp, nil diff --git a/server/routes_test.go b/server/routes_test.go index 6ac98367..27e53cbd 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -17,7 +17,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/ollama/ollama/api" - "github.com/ollama/ollama/parser" + "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" ) @@ -55,13 +55,13 @@ func Test_Routes(t *testing.T) { createTestModel := func(t *testing.T, name string) { fname := createTestFile(t, "ollama-model") - modelfile := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname)) - commands, err := parser.Parse(modelfile) + r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname)) + modelfile, err := model.ParseFile(r) assert.Nil(t, err) fn := func(resp api.ProgressResponse) { t.Logf("Status: %s", resp.Status) } - err = CreateModel(context.TODO(), name, "", "", commands, fn) + err = CreateModel(context.TODO(), name, "", "", modelfile, fn) assert.Nil(t, err) } diff --git a/parser/parser.go b/types/model/file.go similarity index 86% rename from parser/parser.go rename to types/model/file.go index 9d1f3388..b4b7578f 100644 --- a/parser/parser.go +++ b/types/model/file.go @@ -1,4 +1,4 @@ -package parser +package model import ( "bufio" @@ -10,11 +10,45 @@ import ( "strings" ) +type File struct { + Commands []Command +} + +func (f File) String() string { + var sb strings.Builder + for _, cmd := range f.Commands { + fmt.Fprintln(&sb, cmd.String()) + } + + return sb.String() +} + type Command struct { Name string Args string } +func (c Command) String() string { + name := c.Name + args := c.Args + + switch c.Name { + case "model": + name = "from" + args = c.Args + case "license", "template", "system", "adapter": + args = quote(args) + case "message": + role, message, _ := strings.Cut(c.Args, ": ") + args = role + " " + quote(message) + default: + name = "parameter" + args = c.Name + " " + quote(c.Args) + } + + return fmt.Sprintf("%s %s", strings.ToUpper(name), args) +} + type state int const ( @@ -32,38 +66,14 @@ var ( errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"") ) -func Format(cmds []Command) string { - var sb strings.Builder - for _, cmd := range cmds { - name := cmd.Name - args := cmd.Args - - switch cmd.Name { - case "model": - name = "from" - args = cmd.Args - case "license", "template", "system", "adapter": - args = quote(args) - case "message": - role, message, _ := strings.Cut(cmd.Args, ": ") - args = role + " " + quote(message) - default: - name = "parameter" - args = cmd.Name + " " + quote(cmd.Args) - } - - fmt.Fprintln(&sb, strings.ToUpper(name), args) - } - - return sb.String() -} - -func Parse(r io.Reader) (cmds []Command, err error) { +func ParseFile(r io.Reader) (*File, error) { var cmd Command var curr state var b bytes.Buffer var role string + var f File + br := bufio.NewReader(r) for { r, _, err := br.ReadRune() @@ -128,7 +138,7 @@ func Parse(r io.Reader) (cmds []Command, err error) { } cmd.Args = s - cmds = append(cmds, cmd) + f.Commands = append(f.Commands, cmd) } b.Reset() @@ -157,14 +167,14 @@ func Parse(r io.Reader) (cmds []Command, err error) { } cmd.Args = s - cmds = append(cmds, cmd) + f.Commands = append(f.Commands, cmd) default: return nil, io.ErrUnexpectedEOF } - for _, cmd := range cmds { + for _, cmd := range f.Commands { if cmd.Name == "model" { - return cmds, nil + return &f, nil } } diff --git a/parser/parser_test.go b/types/model/file_test.go similarity index 80% rename from parser/parser_test.go rename to types/model/file_test.go index a28205aa..d51c8d70 100644 --- a/parser/parser_test.go +++ b/types/model/file_test.go @@ -1,4 +1,4 @@ -package parser +package model import ( "bytes" @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestParser(t *testing.T) { +func TestParseFileFile(t *testing.T) { input := ` FROM model1 ADAPTER adapter1 @@ -22,8 +22,8 @@ TEMPLATE template1 reader := strings.NewReader(input) - commands, err := Parse(reader) - assert.Nil(t, err) + modelfile, err := ParseFile(reader) + assert.NoError(t, err) expectedCommands := []Command{ {Name: "model", Args: "model1"}, @@ -34,10 +34,10 @@ TEMPLATE template1 {Name: "template", Args: "template1"}, } - assert.Equal(t, expectedCommands, commands) + assert.Equal(t, expectedCommands, modelfile.Commands) } -func TestParserFrom(t *testing.T) { +func TestParseFileFrom(t *testing.T) { var cases = []struct { input string expected []Command @@ -85,14 +85,16 @@ func TestParserFrom(t *testing.T) { for _, c := range cases { t.Run("", func(t *testing.T) { - commands, err := Parse(strings.NewReader(c.input)) + modelfile, err := ParseFile(strings.NewReader(c.input)) assert.ErrorIs(t, err, c.err) - assert.Equal(t, c.expected, commands) + if modelfile != nil { + assert.Equal(t, c.expected, modelfile.Commands) + } }) } } -func TestParserParametersMissingValue(t *testing.T) { +func TestParseFileParametersMissingValue(t *testing.T) { input := ` FROM foo PARAMETER param1 @@ -100,21 +102,21 @@ PARAMETER param1 reader := strings.NewReader(input) - _, err := Parse(reader) + _, err := ParseFile(reader) assert.ErrorIs(t, err, io.ErrUnexpectedEOF) } -func TestParserBadCommand(t *testing.T) { +func TestParseFileBadCommand(t *testing.T) { input := ` FROM foo BADCOMMAND param1 value1 ` - _, err := Parse(strings.NewReader(input)) + _, err := ParseFile(strings.NewReader(input)) assert.ErrorIs(t, err, errInvalidCommand) } -func TestParserMessages(t *testing.T) { +func TestParseFileMessages(t *testing.T) { var cases = []struct { input string expected []Command @@ -123,34 +125,34 @@ func TestParserMessages(t *testing.T) { { ` FROM foo -MESSAGE system You are a Parser. Always Parse things. +MESSAGE system You are a file parser. Always parse things. `, []Command{ {Name: "model", Args: "foo"}, - {Name: "message", Args: "system: You are a Parser. Always Parse things."}, + {Name: "message", Args: "system: You are a file parser. Always parse things."}, }, nil, }, { ` FROM foo -MESSAGE system You are a Parser. Always Parse things.`, +MESSAGE system You are a file parser. Always parse things.`, []Command{ {Name: "model", Args: "foo"}, - {Name: "message", Args: "system: You are a Parser. Always Parse things."}, + {Name: "message", Args: "system: You are a file parser. Always parse things."}, }, nil, }, { ` FROM foo -MESSAGE system You are a Parser. Always Parse things. +MESSAGE system You are a file parser. Always parse things. MESSAGE user Hey there! MESSAGE assistant Hello, I want to parse all the things! `, []Command{ {Name: "model", Args: "foo"}, - {Name: "message", Args: "system: You are a Parser. Always Parse things."}, + {Name: "message", Args: "system: You are a file parser. Always parse things."}, {Name: "message", Args: "user: Hey there!"}, {Name: "message", Args: "assistant: Hello, I want to parse all the things!"}, }, @@ -160,12 +162,12 @@ MESSAGE assistant Hello, I want to parse all the things! ` FROM foo MESSAGE system """ -You are a multiline Parser. Always Parse things. +You are a multiline file parser. Always parse things. """ `, []Command{ {Name: "model", Args: "foo"}, - {Name: "message", Args: "system: \nYou are a multiline Parser. Always Parse things.\n"}, + {Name: "message", Args: "system: \nYou are a multiline file parser. Always parse things.\n"}, }, nil, }, @@ -196,14 +198,16 @@ MESSAGE system`, for _, c := range cases { t.Run("", func(t *testing.T) { - commands, err := Parse(strings.NewReader(c.input)) + modelfile, err := ParseFile(strings.NewReader(c.input)) assert.ErrorIs(t, err, c.err) - assert.Equal(t, c.expected, commands) + if modelfile != nil { + assert.Equal(t, c.expected, modelfile.Commands) + } }) } } -func TestParserQuoted(t *testing.T) { +func TestParseFileQuoted(t *testing.T) { var cases = []struct { multiline string expected []Command @@ -348,14 +352,16 @@ TEMPLATE """ for _, c := range cases { t.Run("", func(t *testing.T) { - commands, err := Parse(strings.NewReader(c.multiline)) + modelfile, err := ParseFile(strings.NewReader(c.multiline)) assert.ErrorIs(t, err, c.err) - assert.Equal(t, c.expected, commands) + if modelfile != nil { + assert.Equal(t, c.expected, modelfile.Commands) + } }) } } -func TestParserParameters(t *testing.T) { +func TestParseFileParameters(t *testing.T) { var cases = map[string]struct { name, value string }{ @@ -404,18 +410,18 @@ func TestParserParameters(t *testing.T) { var b bytes.Buffer fmt.Fprintln(&b, "FROM foo") fmt.Fprintln(&b, "PARAMETER", k) - commands, err := Parse(&b) - assert.Nil(t, err) + modelfile, err := ParseFile(&b) + assert.NoError(t, err) assert.Equal(t, []Command{ {Name: "model", Args: "foo"}, {Name: v.name, Args: v.value}, - }, commands) + }, modelfile.Commands) }) } } -func TestParserComments(t *testing.T) { +func TestParseFileComments(t *testing.T) { var cases = []struct { input string expected []Command @@ -433,14 +439,14 @@ FROM foo for _, c := range cases { t.Run("", func(t *testing.T) { - commands, err := Parse(strings.NewReader(c.input)) - assert.Nil(t, err) - assert.Equal(t, c.expected, commands) + modelfile, err := ParseFile(strings.NewReader(c.input)) + assert.NoError(t, err) + assert.Equal(t, c.expected, modelfile.Commands) }) } } -func TestParseFormatParse(t *testing.T) { +func TestParseFileFormatParseFile(t *testing.T) { var cases = []string{ ` FROM foo @@ -449,7 +455,7 @@ LICENSE MIT PARAMETER param1 value1 PARAMETER param2 value2 TEMPLATE template1 -MESSAGE system You are a Parser. Always Parse things. +MESSAGE system You are a file parser. Always parse things. MESSAGE user Hey there! MESSAGE assistant Hello, I want to parse all the things! `, @@ -488,13 +494,13 @@ MESSAGE assistant Hello, I want to parse all the things! for _, c := range cases { t.Run("", func(t *testing.T) { - commands, err := Parse(strings.NewReader(c)) + modelfile, err := ParseFile(strings.NewReader(c)) assert.NoError(t, err) - commands2, err := Parse(strings.NewReader(Format(commands))) + modelfile2, err := ParseFile(strings.NewReader(modelfile.String())) assert.NoError(t, err) - assert.Equal(t, commands, commands2) + assert.Equal(t, modelfile, modelfile2) }) } From 8acb233668dd14a42bd8c6dc9ee3e85544d29bca Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 1 May 2024 10:01:09 -0700 Subject: [PATCH 2/2] use strings.Builder --- types/model/file.go | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/types/model/file.go b/types/model/file.go index b4b7578f..c614fd32 100644 --- a/types/model/file.go +++ b/types/model/file.go @@ -29,24 +29,20 @@ type Command struct { } func (c Command) String() string { - name := c.Name - args := c.Args - + var sb strings.Builder switch c.Name { case "model": - name = "from" - args = c.Args + fmt.Fprintf(&sb, "FROM %s", c.Args) case "license", "template", "system", "adapter": - args = quote(args) + fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args)) case "message": role, message, _ := strings.Cut(c.Args, ": ") - args = role + " " + quote(message) + fmt.Fprintf(&sb, "MESSAGE %s %s", role, quote(message)) default: - name = "parameter" - args = c.Name + " " + quote(c.Args) + fmt.Fprintf(&sb, "PARAMETER %s %s", c.Name, quote(c.Args)) } - return fmt.Sprintf("%s %s", strings.ToUpper(name), args) + return sb.String() } type state int