diff --git a/cmd/cmd.go b/cmd/cmd.go index eb6ae76f..afae9d90 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -57,12 +57,13 @@ func CreateHandler(cmd *cobra.Command, args []string) error { p := progress.NewProgress(os.Stderr) defer p.Stop() - modelfile, err := os.ReadFile(filename) + modelfile, err := os.Open(filename) if err != nil { return err } + defer modelfile.Close() - commands, err := parser.Parse(bytes.NewReader(modelfile)) + commands, err := parser.Parse(modelfile) if err != nil { return err } @@ -76,10 +77,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error { spinner := progress.NewSpinner(status) p.Add(status, spinner) - for _, c := range commands { - switch c.Name { + for i := range commands { + switch commands[i].Name { case "model", "adapter": - path := c.Args + path := commands[i].Args if path == "~" { path = home } else if strings.HasPrefix(path, "~/") { @@ -91,7 +92,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { } fi, err := os.Stat(path) - if errors.Is(err, os.ErrNotExist) && c.Name == "model" { + if errors.Is(err, os.ErrNotExist) && commands[i].Name == "model" { continue } else if err != nil { return err @@ -114,13 +115,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { return err } - name := c.Name - if c.Name == "model" { - name = "from" - } - - re := regexp.MustCompile(fmt.Sprintf(`(?im)^(%s)\s+%s\s*$`, name, c.Args)) - modelfile = re.ReplaceAll(modelfile, []byte("$1 @"+digest)) + commands[i].Args = "@"+digest } } @@ -150,7 +145,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { quantization, _ := cmd.Flags().GetString("quantization") - request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile), Quantization: quantization} + request := api.CreateRequest{Name: args[0], Modelfile: parser.Format(commands), Quantization: quantization} if err := client.Create(cmd.Context(), &request, fn); err != nil { return err } diff --git a/parser/parser.go b/parser/parser.go index 947848b2..9d1f3388 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -6,8 +6,8 @@ import ( "errors" "fmt" "io" - "log/slog" - "slices" + "strconv" + "strings" ) type Command struct { @@ -15,118 +15,283 @@ type Command struct { Args string } -func (c *Command) Reset() { - c.Name = "" - c.Args = "" -} +type state int -func Parse(reader io.Reader) ([]Command, error) { - var commands []Command - var command, modelCommand Command +const ( + stateNil state = iota + stateName + stateValue + stateParameter + stateMessage + stateComment +) - scanner := bufio.NewScanner(reader) - scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), bufio.MaxScanTokenSize) - scanner.Split(scanModelfile) - for scanner.Scan() { - line := scanner.Bytes() +var ( + errMissingFrom = errors.New("no FROM line") + errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"") + errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"") +) - fields := bytes.SplitN(line, []byte(" "), 2) - if len(fields) == 0 || len(fields[0]) == 0 { - continue - } +func Format(cmds []Command) string { + var sb strings.Builder + for _, cmd := range cmds { + name := cmd.Name + args := cmd.Args - switch string(bytes.ToUpper(fields[0])) { - case "FROM": - command.Name = "model" - command.Args = string(bytes.TrimSpace(fields[1])) - // copy command for validation - modelCommand = command - case "ADAPTER": - command.Name = string(bytes.ToLower(fields[0])) - command.Args = string(bytes.TrimSpace(fields[1])) - case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT": - command.Name = string(bytes.ToLower(fields[0])) - command.Args = string(fields[1]) - case "PARAMETER": - fields = bytes.SplitN(fields[1], []byte(" "), 2) - if len(fields) < 2 { - return nil, fmt.Errorf("missing value for %s", fields) - } - - command.Name = string(fields[0]) - command.Args = string(bytes.TrimSpace(fields[1])) - case "EMBED": - return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead") - case "MESSAGE": - command.Name = string(bytes.ToLower(fields[0])) - fields = bytes.SplitN(fields[1], []byte(" "), 2) - if len(fields) < 2 { - return nil, fmt.Errorf("should be in the format ") - } - if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) { - return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"") - } - command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1])) + switch cmd.Name { + case "model": + name = "from" + args = cmd.Args + case "license", "template", "system", "adapter": + args = quote(args) + case "message": + role, message, _ := strings.Cut(cmd.Args, ": ") + args = role + " " + quote(message) default: - if !bytes.HasPrefix(fields[0], []byte("#")) { - // log a warning for unknown commands - slog.Warn(fmt.Sprintf("Unknown command: %s", fields[0])) - } - continue + name = "parameter" + args = cmd.Name + " " + quote(cmd.Args) } - commands = append(commands, command) - command.Reset() + fmt.Fprintln(&sb, strings.ToUpper(name), args) } - if modelCommand.Args == "" { - return nil, errors.New("no FROM line for the model was specified") - } - - return commands, scanner.Err() + return sb.String() } -func scanModelfile(data []byte, atEOF bool) (advance int, token []byte, err error) { - advance, token, err = scan([]byte(`"""`), []byte(`"""`), data, atEOF) - if err != nil { - return 0, nil, err - } +func Parse(r io.Reader) (cmds []Command, err error) { + var cmd Command + var curr state + var b bytes.Buffer + var role string - if advance > 0 && token != nil { - return advance, token, nil - } - - advance, token, err = scan([]byte(`"`), []byte(`"`), data, atEOF) - if err != nil { - return 0, nil, err - } - - if advance > 0 && token != nil { - return advance, token, nil - } - - return bufio.ScanLines(data, atEOF) -} - -func scan(openBytes, closeBytes, data []byte, atEOF bool) (advance int, token []byte, err error) { - newline := bytes.IndexByte(data, '\n') - - if start := bytes.Index(data, openBytes); start >= 0 && start < newline { - end := bytes.Index(data[start+len(openBytes):], closeBytes) - if end < 0 { - if atEOF { - return 0, nil, fmt.Errorf("unterminated %s: expecting %s", openBytes, closeBytes) - } else { - return 0, nil, nil - } + br := bufio.NewReader(r) + for { + r, _, err := br.ReadRune() + if errors.Is(err, io.EOF) { + break + } else if err != nil { + return nil, err } - n := start + len(openBytes) + end + len(closeBytes) + next, r, err := parseRuneForState(r, curr) + if errors.Is(err, io.ErrUnexpectedEOF) { + return nil, fmt.Errorf("%w: %s", err, b.String()) + } else if err != nil { + return nil, err + } - newData := data[:start] - newData = append(newData, data[start+len(openBytes):n-len(closeBytes)]...) - return n, newData, nil + // process the state transition, some transitions need to be intercepted and redirected + if next != curr { + switch curr { + case stateName: + if !isValidCommand(b.String()) { + return nil, errInvalidCommand + } + + // next state sometimes depends on the current buffer value + switch s := strings.ToLower(b.String()); s { + case "from": + cmd.Name = "model" + case "parameter": + // transition to stateParameter which sets command name + next = stateParameter + case "message": + // transition to stateMessage which validates the message role + next = stateMessage + fallthrough + default: + cmd.Name = s + } + case stateParameter: + cmd.Name = b.String() + case stateMessage: + if !isValidMessageRole(b.String()) { + return nil, errInvalidMessageRole + } + + role = b.String() + case stateComment, stateNil: + // pass + case stateValue: + s, ok := unquote(b.String()) + if !ok || isSpace(r) { + if _, err := b.WriteRune(r); err != nil { + return nil, err + } + + continue + } + + if role != "" { + s = role + ": " + s + role = "" + } + + cmd.Args = s + cmds = append(cmds, cmd) + } + + b.Reset() + curr = next + } + + if strconv.IsPrint(r) { + if _, err := b.WriteRune(r); err != nil { + return nil, err + } + } } - return 0, nil, nil + // flush the buffer + switch curr { + case stateComment, stateNil: + // pass; nothing to flush + case stateValue: + s, ok := unquote(b.String()) + if !ok { + return nil, io.ErrUnexpectedEOF + } + + if role != "" { + s = role + ": " + s + } + + cmd.Args = s + cmds = append(cmds, cmd) + default: + return nil, io.ErrUnexpectedEOF + } + + for _, cmd := range cmds { + if cmd.Name == "model" { + return cmds, nil + } + } + + return nil, errMissingFrom +} + +func parseRuneForState(r rune, cs state) (state, rune, error) { + switch cs { + case stateNil: + switch { + case r == '#': + return stateComment, 0, nil + case isSpace(r), isNewline(r): + return stateNil, 0, nil + default: + return stateName, r, nil + } + case stateName: + switch { + case isAlpha(r): + return stateName, r, nil + case isSpace(r): + return stateValue, 0, nil + default: + return stateNil, 0, errInvalidCommand + } + case stateValue: + switch { + case isNewline(r): + return stateNil, r, nil + case isSpace(r): + return stateNil, r, nil + default: + return stateValue, r, nil + } + case stateParameter: + switch { + case isAlpha(r), isNumber(r), r == '_': + return stateParameter, r, nil + case isSpace(r): + return stateValue, 0, nil + default: + return stateNil, 0, io.ErrUnexpectedEOF + } + case stateMessage: + switch { + case isAlpha(r): + return stateMessage, r, nil + case isSpace(r): + return stateValue, 0, nil + default: + return stateNil, 0, io.ErrUnexpectedEOF + } + case stateComment: + switch { + case isNewline(r): + return stateNil, 0, nil + default: + return stateComment, 0, nil + } + default: + return stateNil, 0, errors.New("") + } +} + +func quote(s string) string { + if strings.Contains(s, "\n") || strings.HasPrefix(s, " ") || strings.HasSuffix(s, " ") { + if strings.Contains(s, "\"") { + return `"""` + s + `"""` + } + + return `"` + s + `"` + } + + return s +} + +func unquote(s string) (string, bool) { + if len(s) == 0 { + return "", false + } + + // TODO: single quotes + if len(s) >= 3 && s[:3] == `"""` { + if len(s) >= 6 && s[len(s)-3:] == `"""` { + return s[3 : len(s)-3], true + } + + return "", false + } + + if len(s) >= 1 && s[0] == '"' { + if len(s) >= 2 && s[len(s)-1] == '"' { + return s[1 : len(s)-1], true + } + + return "", false + } + + return s, true +} + +func isAlpha(r rune) bool { + return r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' +} + +func isNumber(r rune) bool { + return r >= '0' && r <= '9' +} + +func isSpace(r rune) bool { + return r == ' ' || r == '\t' +} + +func isNewline(r rune) bool { + return r == '\r' || r == '\n' +} + +func isValidMessageRole(role string) bool { + return role == "system" || role == "user" || role == "assistant" +} + +func isValidCommand(cmd string) bool { + switch strings.ToLower(cmd) { + case "from", "license", "template", "system", "adapter", "parameter", "message": + return true + default: + return false + } } diff --git a/parser/parser_test.go b/parser/parser_test.go index 25e849b5..a28205aa 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -1,14 +1,16 @@ package parser import ( + "bytes" + "fmt" + "io" "strings" "testing" "github.com/stretchr/testify/assert" ) -func Test_Parser(t *testing.T) { - +func TestParser(t *testing.T) { input := ` FROM model1 ADAPTER adapter1 @@ -35,21 +37,62 @@ TEMPLATE template1 assert.Equal(t, expectedCommands, commands) } -func Test_Parser_NoFromLine(t *testing.T) { +func TestParserFrom(t *testing.T) { + var cases = []struct { + input string + expected []Command + err error + }{ + { + "FROM foo", + []Command{{Name: "model", Args: "foo"}}, + nil, + }, + { + "FROM /path/to/model", + []Command{{Name: "model", Args: "/path/to/model"}}, + nil, + }, + { + "FROM /path/to/model/fp16.bin", + []Command{{Name: "model", Args: "/path/to/model/fp16.bin"}}, + nil, + }, + { + "FROM llama3:latest", + []Command{{Name: "model", Args: "llama3:latest"}}, + nil, + }, + { + "FROM llama3:7b-instruct-q4_K_M", + []Command{{Name: "model", Args: "llama3:7b-instruct-q4_K_M"}}, + nil, + }, + { + "", nil, errMissingFrom, + }, + { + "PARAMETER param1 value1", + nil, + errMissingFrom, + }, + { + "PARAMETER param1 value1\nFROM foo", + []Command{{Name: "param1", Args: "value1"}, {Name: "model", Args: "foo"}}, + nil, + }, + } - input := ` -PARAMETER param1 value1 -PARAMETER param2 value2 -` - - reader := strings.NewReader(input) - - _, err := Parse(reader) - assert.ErrorContains(t, err, "no FROM line") + for _, c := range cases { + t.Run("", func(t *testing.T) { + commands, err := Parse(strings.NewReader(c.input)) + assert.ErrorIs(t, err, c.err) + assert.Equal(t, c.expected, commands) + }) + } } -func Test_Parser_MissingValue(t *testing.T) { - +func TestParserParametersMissingValue(t *testing.T) { input := ` FROM foo PARAMETER param1 @@ -58,41 +101,401 @@ PARAMETER param1 reader := strings.NewReader(input) _, err := Parse(reader) - assert.ErrorContains(t, err, "missing value for [param1]") + assert.ErrorIs(t, err, io.ErrUnexpectedEOF) +} + +func TestParserBadCommand(t *testing.T) { + input := ` +FROM foo +BADCOMMAND param1 value1 +` + _, err := Parse(strings.NewReader(input)) + assert.ErrorIs(t, err, errInvalidCommand) } -func Test_Parser_Messages(t *testing.T) { - - input := ` +func TestParserMessages(t *testing.T) { + var cases = []struct { + input string + expected []Command + err error + }{ + { + ` +FROM foo +MESSAGE system You are a Parser. Always Parse things. +`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "message", Args: "system: You are a Parser. Always Parse things."}, + }, + nil, + }, + { + ` +FROM foo +MESSAGE system You are a Parser. Always Parse things.`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "message", Args: "system: You are a Parser. Always Parse things."}, + }, + nil, + }, + { + ` FROM foo MESSAGE system You are a Parser. Always Parse things. MESSAGE user Hey there! MESSAGE assistant Hello, I want to parse all the things! -` - - reader := strings.NewReader(input) - commands, err := Parse(reader) - assert.Nil(t, err) - - expectedCommands := []Command{ - {Name: "model", Args: "foo"}, - {Name: "message", Args: "system: You are a Parser. Always Parse things."}, - {Name: "message", Args: "user: Hey there!"}, - {Name: "message", Args: "assistant: Hello, I want to parse all the things!"}, - } - - assert.Equal(t, expectedCommands, commands) -} - -func Test_Parser_Messages_BadRole(t *testing.T) { - - input := ` +`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "message", Args: "system: You are a Parser. Always Parse things."}, + {Name: "message", Args: "user: Hey there!"}, + {Name: "message", Args: "assistant: Hello, I want to parse all the things!"}, + }, + nil, + }, + { + ` +FROM foo +MESSAGE system """ +You are a multiline Parser. Always Parse things. +""" + `, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "message", Args: "system: \nYou are a multiline Parser. Always Parse things.\n"}, + }, + nil, + }, + { + ` FROM foo MESSAGE badguy I'm a bad guy! -` +`, + nil, + errInvalidMessageRole, + }, + { + ` +FROM foo +MESSAGE system +`, + nil, + io.ErrUnexpectedEOF, + }, + { + ` +FROM foo +MESSAGE system`, + nil, + io.ErrUnexpectedEOF, + }, + } + + for _, c := range cases { + t.Run("", func(t *testing.T) { + commands, err := Parse(strings.NewReader(c.input)) + assert.ErrorIs(t, err, c.err) + assert.Equal(t, c.expected, commands) + }) + } +} + +func TestParserQuoted(t *testing.T) { + var cases = []struct { + multiline string + expected []Command + err error + }{ + { + ` +FROM foo +SYSTEM """ +This is a +multiline system. +""" + `, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "system", Args: "\nThis is a\nmultiline system.\n"}, + }, + nil, + }, + { + ` +FROM foo +SYSTEM """ +This is a +multiline system.""" + `, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "system", Args: "\nThis is a\nmultiline system."}, + }, + nil, + }, + { + ` +FROM foo +SYSTEM """This is a +multiline system.""" + `, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "system", Args: "This is a\nmultiline system."}, + }, + nil, + }, + { + ` +FROM foo +SYSTEM """This is a multiline system.""" + `, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "system", Args: "This is a multiline system."}, + }, + nil, + }, + { + ` +FROM foo +SYSTEM """This is a multiline system."" + `, + nil, + io.ErrUnexpectedEOF, + }, + { + ` +FROM foo +SYSTEM " + `, + nil, + io.ErrUnexpectedEOF, + }, + { + ` +FROM foo +SYSTEM """ +This is a multiline system with "quotes". +""" +`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "system", Args: "\nThis is a multiline system with \"quotes\".\n"}, + }, + nil, + }, + { + ` +FROM foo +SYSTEM """""" +`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "system", Args: ""}, + }, + nil, + }, + { + ` +FROM foo +SYSTEM "" +`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "system", Args: ""}, + }, + nil, + }, + { + ` +FROM foo +SYSTEM "'" +`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "system", Args: "'"}, + }, + nil, + }, + { + ` +FROM foo +SYSTEM """''"'""'""'"'''''""'""'""" +`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "system", Args: `''"'""'""'"'''''""'""'`}, + }, + nil, + }, + { + ` +FROM foo +TEMPLATE """ +{{ .Prompt }} +"""`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "template", Args: "\n{{ .Prompt }}\n"}, + }, + nil, + }, + } + + for _, c := range cases { + t.Run("", func(t *testing.T) { + commands, err := Parse(strings.NewReader(c.multiline)) + assert.ErrorIs(t, err, c.err) + assert.Equal(t, c.expected, commands) + }) + } +} + +func TestParserParameters(t *testing.T) { + var cases = map[string]struct { + name, value string + }{ + "numa true": {"numa", "true"}, + "num_ctx 1": {"num_ctx", "1"}, + "num_batch 1": {"num_batch", "1"}, + "num_gqa 1": {"num_gqa", "1"}, + "num_gpu 1": {"num_gpu", "1"}, + "main_gpu 1": {"main_gpu", "1"}, + "low_vram true": {"low_vram", "true"}, + "f16_kv true": {"f16_kv", "true"}, + "logits_all true": {"logits_all", "true"}, + "vocab_only true": {"vocab_only", "true"}, + "use_mmap true": {"use_mmap", "true"}, + "use_mlock true": {"use_mlock", "true"}, + "num_thread 1": {"num_thread", "1"}, + "num_keep 1": {"num_keep", "1"}, + "seed 1": {"seed", "1"}, + "num_predict 1": {"num_predict", "1"}, + "top_k 1": {"top_k", "1"}, + "top_p 1.0": {"top_p", "1.0"}, + "tfs_z 1.0": {"tfs_z", "1.0"}, + "typical_p 1.0": {"typical_p", "1.0"}, + "repeat_last_n 1": {"repeat_last_n", "1"}, + "temperature 1.0": {"temperature", "1.0"}, + "repeat_penalty 1.0": {"repeat_penalty", "1.0"}, + "presence_penalty 1.0": {"presence_penalty", "1.0"}, + "frequency_penalty 1.0": {"frequency_penalty", "1.0"}, + "mirostat 1": {"mirostat", "1"}, + "mirostat_tau 1.0": {"mirostat_tau", "1.0"}, + "mirostat_eta 1.0": {"mirostat_eta", "1.0"}, + "penalize_newline true": {"penalize_newline", "true"}, + "stop ### User:": {"stop", "### User:"}, + "stop ### User: ": {"stop", "### User: "}, + "stop \"### User:\"": {"stop", "### User:"}, + "stop \"### User: \"": {"stop", "### User: "}, + "stop \"\"\"### User:\"\"\"": {"stop", "### User:"}, + "stop \"\"\"### User:\n\"\"\"": {"stop", "### User:\n"}, + "stop <|endoftext|>": {"stop", "<|endoftext|>"}, + "stop <|eot_id|>": {"stop", "<|eot_id|>"}, + "stop ": {"stop", ""}, + } + + for k, v := range cases { + t.Run(k, func(t *testing.T) { + var b bytes.Buffer + fmt.Fprintln(&b, "FROM foo") + fmt.Fprintln(&b, "PARAMETER", k) + commands, err := Parse(&b) + assert.Nil(t, err) + + assert.Equal(t, []Command{ + {Name: "model", Args: "foo"}, + {Name: v.name, Args: v.value}, + }, commands) + }) + } +} + +func TestParserComments(t *testing.T) { + var cases = []struct { + input string + expected []Command + }{ + { + ` +# comment +FROM foo + `, + []Command{ + {Name: "model", Args: "foo"}, + }, + }, + } + + for _, c := range cases { + t.Run("", func(t *testing.T) { + commands, err := Parse(strings.NewReader(c.input)) + assert.Nil(t, err) + assert.Equal(t, c.expected, commands) + }) + } +} + +func TestParseFormatParse(t *testing.T) { + var cases = []string{ + ` +FROM foo +ADAPTER adapter1 +LICENSE MIT +PARAMETER param1 value1 +PARAMETER param2 value2 +TEMPLATE template1 +MESSAGE system You are a Parser. Always Parse things. +MESSAGE user Hey there! +MESSAGE assistant Hello, I want to parse all the things! +`, + ` +FROM foo +ADAPTER adapter1 +LICENSE MIT +PARAMETER param1 value1 +PARAMETER param2 value2 +TEMPLATE template1 +MESSAGE system """ +You are a store greeter. Always responsed with "Hello!". +""" +MESSAGE user Hey there! +MESSAGE assistant Hello, I want to parse all the things! +`, + ` +FROM foo +ADAPTER adapter1 +LICENSE """ +Very long and boring legal text. +Blah blah blah. +"Oh look, a quote!" +""" + +PARAMETER param1 value1 +PARAMETER param2 value2 +TEMPLATE template1 +MESSAGE system """ +You are a store greeter. Always responsed with "Hello!". +""" +MESSAGE user Hey there! +MESSAGE assistant Hello, I want to parse all the things! +`, + } + + for _, c := range cases { + t.Run("", func(t *testing.T) { + commands, err := Parse(strings.NewReader(c)) + assert.NoError(t, err) + + commands2, err := Parse(strings.NewReader(Format(commands))) + assert.NoError(t, err) + + assert.Equal(t, commands, commands2) + }) + } - reader := strings.NewReader(input) - _, err := Parse(reader) - assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"") } diff --git a/server/images.go b/server/images.go index 4e4107f7..68840c1a 100644 --- a/server/images.go +++ b/server/images.go @@ -21,7 +21,6 @@ import ( "runtime" "strconv" "strings" - "text/template" "golang.org/x/exp/slices" @@ -64,6 +63,48 @@ func (m *Model) IsEmbedding() bool { return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert") } +func (m *Model) Commands() (cmds []parser.Command) { + cmds = append(cmds, parser.Command{Name: "model", Args: m.ModelPath}) + + if m.Template != "" { + cmds = append(cmds, parser.Command{Name: "template", Args: m.Template}) + } + + if m.System != "" { + cmds = append(cmds, parser.Command{Name: "system", Args: m.System}) + } + + for _, adapter := range m.AdapterPaths { + cmds = append(cmds, parser.Command{Name: "adapter", Args: adapter}) + } + + for _, projector := range m.ProjectorPaths { + cmds = append(cmds, parser.Command{Name: "projector", Args: projector}) + } + + for k, v := range m.Options { + switch v := v.(type) { + case []any: + for _, s := range v { + cmds = append(cmds, parser.Command{Name: k, Args: fmt.Sprintf("%v", s)}) + } + default: + cmds = append(cmds, parser.Command{Name: k, Args: fmt.Sprintf("%v", v)}) + } + } + + for _, license := range m.License { + cmds = append(cmds, parser.Command{Name: "license", Args: license}) + } + + for _, msg := range m.Messages { + cmds = append(cmds, parser.Command{Name: "message", Args: fmt.Sprintf("%s %s", msg.Role, msg.Content)}) + } + + return cmds + +} + type Message struct { Role string `json:"role"` Content string `json:"content"` @@ -901,67 +942,6 @@ func DeleteModel(name string) error { return nil } -func ShowModelfile(model *Model) (string, error) { - var mt struct { - *Model - From string - Parameters map[string][]any - } - - mt.Parameters = make(map[string][]any) - for k, v := range model.Options { - if s, ok := v.([]any); ok { - mt.Parameters[k] = s - continue - } - - mt.Parameters[k] = []any{v} - } - - mt.Model = model - mt.From = model.ModelPath - - if model.ParentModel != "" { - mt.From = model.ParentModel - } - - modelFile := `# Modelfile generated by "ollama show" -# To build a new Modelfile based on this one, replace the FROM line with: -# FROM {{ .ShortName }} - -FROM {{ .From }} -TEMPLATE """{{ .Template }}""" - -{{- if .System }} -SYSTEM """{{ .System }}""" -{{- end }} - -{{- range $adapter := .AdapterPaths }} -ADAPTER {{ $adapter }} -{{- end }} - -{{- range $k, $v := .Parameters }} -{{- range $parameter := $v }} -PARAMETER {{ $k }} {{ printf "%#v" $parameter }} -{{- end }} -{{- end }}` - - tmpl, err := template.New("").Parse(modelFile) - if err != nil { - slog.Info(fmt.Sprintf("error parsing template: %q", err)) - return "", err - } - - var buf bytes.Buffer - - if err = tmpl.Execute(&buf, mt); err != nil { - slog.Info(fmt.Sprintf("error executing template: %q", err)) - return "", err - } - - return buf.String(), nil -} - func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error { mp := ParseModelPath(name) fn(api.ProgressResponse{Status: "retrieving manifest"}) diff --git a/server/routes.go b/server/routes.go index 917bb2ef..480527f2 100644 --- a/server/routes.go +++ b/server/routes.go @@ -728,12 +728,12 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { } } - mf, err := ShowModelfile(model) - if err != nil { - return nil, err - } - - resp.Modelfile = mf + var sb strings.Builder + fmt.Fprintln(&sb, "# Modelfile generate by \"ollama show\"") + fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:") + fmt.Fprintf(&sb, "# FROM %s\n\n", model.ShortName) + fmt.Fprint(&sb, parser.Format(model.Commands())) + resp.Modelfile = sb.String() return resp, nil } diff --git a/server/routes_test.go b/server/routes_test.go index 4f907702..6ac98367 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -238,6 +238,5 @@ func Test_Routes(t *testing.T) { if tc.Expected != nil { tc.Expected(t, resp) } - } }