Merge pull request #3892 from ollama/mxyng/parser

refactor modelfile parser
This commit is contained in:
Michael Yang 2024-05-02 17:04:47 -07:00 committed by GitHub
commit e9ae607ece
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 765 additions and 223 deletions

View file

@ -57,12 +57,13 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
p := progress.NewProgress(os.Stderr) p := progress.NewProgress(os.Stderr)
defer p.Stop() defer p.Stop()
modelfile, err := os.ReadFile(filename) modelfile, err := os.Open(filename)
if err != nil { if err != nil {
return err return err
} }
defer modelfile.Close()
commands, err := parser.Parse(bytes.NewReader(modelfile)) commands, err := parser.Parse(modelfile)
if err != nil { if err != nil {
return err return err
} }
@ -76,10 +77,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
spinner := progress.NewSpinner(status) spinner := progress.NewSpinner(status)
p.Add(status, spinner) p.Add(status, spinner)
for _, c := range commands { for i := range commands {
switch c.Name { switch commands[i].Name {
case "model", "adapter": case "model", "adapter":
path := c.Args path := commands[i].Args
if path == "~" { if path == "~" {
path = home path = home
} else if strings.HasPrefix(path, "~/") { } else if strings.HasPrefix(path, "~/") {
@ -91,7 +92,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
} }
fi, err := os.Stat(path) fi, err := os.Stat(path)
if errors.Is(err, os.ErrNotExist) && c.Name == "model" { if errors.Is(err, os.ErrNotExist) && commands[i].Name == "model" {
continue continue
} else if err != nil { } else if err != nil {
return err return err
@ -114,13 +115,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err return err
} }
name := c.Name commands[i].Args = "@"+digest
if c.Name == "model" {
name = "from"
}
re := regexp.MustCompile(fmt.Sprintf(`(?im)^(%s)\s+%s\s*$`, name, c.Args))
modelfile = re.ReplaceAll(modelfile, []byte("$1 @"+digest))
} }
} }
@ -150,7 +145,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
quantization, _ := cmd.Flags().GetString("quantization") quantization, _ := cmd.Flags().GetString("quantization")
request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile), Quantization: quantization} request := api.CreateRequest{Name: args[0], Modelfile: parser.Format(commands), Quantization: quantization}
if err := client.Create(cmd.Context(), &request, fn); err != nil { if err := client.Create(cmd.Context(), &request, fn); err != nil {
return err return err
} }

View file

@ -6,8 +6,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log/slog" "strconv"
"slices" "strings"
) )
type Command struct { type Command struct {
@ -15,118 +15,283 @@ type Command struct {
Args string Args string
} }
func (c *Command) Reset() { type state int
c.Name = ""
c.Args = ""
}
func Parse(reader io.Reader) ([]Command, error) { const (
var commands []Command stateNil state = iota
var command, modelCommand Command stateName
stateValue
stateParameter
stateMessage
stateComment
)
scanner := bufio.NewScanner(reader) var (
scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), bufio.MaxScanTokenSize) errMissingFrom = errors.New("no FROM line")
scanner.Split(scanModelfile) errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"")
for scanner.Scan() { errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"")
line := scanner.Bytes() )
fields := bytes.SplitN(line, []byte(" "), 2) func Format(cmds []Command) string {
if len(fields) == 0 || len(fields[0]) == 0 { var sb strings.Builder
continue for _, cmd := range cmds {
} name := cmd.Name
args := cmd.Args
switch string(bytes.ToUpper(fields[0])) { switch cmd.Name {
case "FROM": case "model":
command.Name = "model" name = "from"
command.Args = string(bytes.TrimSpace(fields[1])) args = cmd.Args
// copy command for validation case "license", "template", "system", "adapter":
modelCommand = command args = quote(args)
case "ADAPTER": case "message":
command.Name = string(bytes.ToLower(fields[0])) role, message, _ := strings.Cut(cmd.Args, ": ")
command.Args = string(bytes.TrimSpace(fields[1])) args = role + " " + quote(message)
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT":
command.Name = string(bytes.ToLower(fields[0]))
command.Args = string(fields[1])
case "PARAMETER":
fields = bytes.SplitN(fields[1], []byte(" "), 2)
if len(fields) < 2 {
return nil, fmt.Errorf("missing value for %s", fields)
}
command.Name = string(fields[0])
command.Args = string(bytes.TrimSpace(fields[1]))
case "EMBED":
return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
case "MESSAGE":
command.Name = string(bytes.ToLower(fields[0]))
fields = bytes.SplitN(fields[1], []byte(" "), 2)
if len(fields) < 2 {
return nil, fmt.Errorf("should be in the format <role> <message>")
}
if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) {
return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"")
}
command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1]))
default: default:
if !bytes.HasPrefix(fields[0], []byte("#")) { name = "parameter"
// log a warning for unknown commands args = cmd.Name + " " + quote(cmd.Args)
slog.Warn(fmt.Sprintf("Unknown command: %s", fields[0]))
}
continue
} }
commands = append(commands, command) fmt.Fprintln(&sb, strings.ToUpper(name), args)
command.Reset()
} }
if modelCommand.Args == "" { return sb.String()
return nil, errors.New("no FROM line for the model was specified")
}
return commands, scanner.Err()
} }
func scanModelfile(data []byte, atEOF bool) (advance int, token []byte, err error) { func Parse(r io.Reader) (cmds []Command, err error) {
advance, token, err = scan([]byte(`"""`), []byte(`"""`), data, atEOF) var cmd Command
if err != nil { var curr state
return 0, nil, err var b bytes.Buffer
} var role string
if advance > 0 && token != nil { br := bufio.NewReader(r)
return advance, token, nil for {
} r, _, err := br.ReadRune()
if errors.Is(err, io.EOF) {
advance, token, err = scan([]byte(`"`), []byte(`"`), data, atEOF) break
if err != nil { } else if err != nil {
return 0, nil, err return nil, err
}
if advance > 0 && token != nil {
return advance, token, nil
}
return bufio.ScanLines(data, atEOF)
}
func scan(openBytes, closeBytes, data []byte, atEOF bool) (advance int, token []byte, err error) {
newline := bytes.IndexByte(data, '\n')
if start := bytes.Index(data, openBytes); start >= 0 && start < newline {
end := bytes.Index(data[start+len(openBytes):], closeBytes)
if end < 0 {
if atEOF {
return 0, nil, fmt.Errorf("unterminated %s: expecting %s", openBytes, closeBytes)
} else {
return 0, nil, nil
}
} }
n := start + len(openBytes) + end + len(closeBytes) next, r, err := parseRuneForState(r, curr)
if errors.Is(err, io.ErrUnexpectedEOF) {
return nil, fmt.Errorf("%w: %s", err, b.String())
} else if err != nil {
return nil, err
}
newData := data[:start] // process the state transition, some transitions need to be intercepted and redirected
newData = append(newData, data[start+len(openBytes):n-len(closeBytes)]...) if next != curr {
return n, newData, nil switch curr {
case stateName:
if !isValidCommand(b.String()) {
return nil, errInvalidCommand
}
// next state sometimes depends on the current buffer value
switch s := strings.ToLower(b.String()); s {
case "from":
cmd.Name = "model"
case "parameter":
// transition to stateParameter which sets command name
next = stateParameter
case "message":
// transition to stateMessage which validates the message role
next = stateMessage
fallthrough
default:
cmd.Name = s
}
case stateParameter:
cmd.Name = b.String()
case stateMessage:
if !isValidMessageRole(b.String()) {
return nil, errInvalidMessageRole
}
role = b.String()
case stateComment, stateNil:
// pass
case stateValue:
s, ok := unquote(b.String())
if !ok || isSpace(r) {
if _, err := b.WriteRune(r); err != nil {
return nil, err
}
continue
}
if role != "" {
s = role + ": " + s
role = ""
}
cmd.Args = s
cmds = append(cmds, cmd)
}
b.Reset()
curr = next
}
if strconv.IsPrint(r) {
if _, err := b.WriteRune(r); err != nil {
return nil, err
}
}
} }
return 0, nil, nil // flush the buffer
switch curr {
case stateComment, stateNil:
// pass; nothing to flush
case stateValue:
s, ok := unquote(b.String())
if !ok {
return nil, io.ErrUnexpectedEOF
}
if role != "" {
s = role + ": " + s
}
cmd.Args = s
cmds = append(cmds, cmd)
default:
return nil, io.ErrUnexpectedEOF
}
for _, cmd := range cmds {
if cmd.Name == "model" {
return cmds, nil
}
}
return nil, errMissingFrom
}
func parseRuneForState(r rune, cs state) (state, rune, error) {
switch cs {
case stateNil:
switch {
case r == '#':
return stateComment, 0, nil
case isSpace(r), isNewline(r):
return stateNil, 0, nil
default:
return stateName, r, nil
}
case stateName:
switch {
case isAlpha(r):
return stateName, r, nil
case isSpace(r):
return stateValue, 0, nil
default:
return stateNil, 0, errInvalidCommand
}
case stateValue:
switch {
case isNewline(r):
return stateNil, r, nil
case isSpace(r):
return stateNil, r, nil
default:
return stateValue, r, nil
}
case stateParameter:
switch {
case isAlpha(r), isNumber(r), r == '_':
return stateParameter, r, nil
case isSpace(r):
return stateValue, 0, nil
default:
return stateNil, 0, io.ErrUnexpectedEOF
}
case stateMessage:
switch {
case isAlpha(r):
return stateMessage, r, nil
case isSpace(r):
return stateValue, 0, nil
default:
return stateNil, 0, io.ErrUnexpectedEOF
}
case stateComment:
switch {
case isNewline(r):
return stateNil, 0, nil
default:
return stateComment, 0, nil
}
default:
return stateNil, 0, errors.New("")
}
}
func quote(s string) string {
if strings.Contains(s, "\n") || strings.HasPrefix(s, " ") || strings.HasSuffix(s, " ") {
if strings.Contains(s, "\"") {
return `"""` + s + `"""`
}
return `"` + s + `"`
}
return s
}
func unquote(s string) (string, bool) {
if len(s) == 0 {
return "", false
}
// TODO: single quotes
if len(s) >= 3 && s[:3] == `"""` {
if len(s) >= 6 && s[len(s)-3:] == `"""` {
return s[3 : len(s)-3], true
}
return "", false
}
if len(s) >= 1 && s[0] == '"' {
if len(s) >= 2 && s[len(s)-1] == '"' {
return s[1 : len(s)-1], true
}
return "", false
}
return s, true
}
func isAlpha(r rune) bool {
return r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z'
}
func isNumber(r rune) bool {
return r >= '0' && r <= '9'
}
func isSpace(r rune) bool {
return r == ' ' || r == '\t'
}
func isNewline(r rune) bool {
return r == '\r' || r == '\n'
}
func isValidMessageRole(role string) bool {
return role == "system" || role == "user" || role == "assistant"
}
func isValidCommand(cmd string) bool {
switch strings.ToLower(cmd) {
case "from", "license", "template", "system", "adapter", "parameter", "message":
return true
default:
return false
}
} }

View file

@ -1,14 +1,16 @@
package parser package parser
import ( import (
"bytes"
"fmt"
"io"
"strings" "strings"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func Test_Parser(t *testing.T) { func TestParser(t *testing.T) {
input := ` input := `
FROM model1 FROM model1
ADAPTER adapter1 ADAPTER adapter1
@ -35,21 +37,62 @@ TEMPLATE template1
assert.Equal(t, expectedCommands, commands) assert.Equal(t, expectedCommands, commands)
} }
func Test_Parser_NoFromLine(t *testing.T) { func TestParserFrom(t *testing.T) {
var cases = []struct {
input string
expected []Command
err error
}{
{
"FROM foo",
[]Command{{Name: "model", Args: "foo"}},
nil,
},
{
"FROM /path/to/model",
[]Command{{Name: "model", Args: "/path/to/model"}},
nil,
},
{
"FROM /path/to/model/fp16.bin",
[]Command{{Name: "model", Args: "/path/to/model/fp16.bin"}},
nil,
},
{
"FROM llama3:latest",
[]Command{{Name: "model", Args: "llama3:latest"}},
nil,
},
{
"FROM llama3:7b-instruct-q4_K_M",
[]Command{{Name: "model", Args: "llama3:7b-instruct-q4_K_M"}},
nil,
},
{
"", nil, errMissingFrom,
},
{
"PARAMETER param1 value1",
nil,
errMissingFrom,
},
{
"PARAMETER param1 value1\nFROM foo",
[]Command{{Name: "param1", Args: "value1"}, {Name: "model", Args: "foo"}},
nil,
},
}
input := ` for _, c := range cases {
PARAMETER param1 value1 t.Run("", func(t *testing.T) {
PARAMETER param2 value2 commands, err := Parse(strings.NewReader(c.input))
` assert.ErrorIs(t, err, c.err)
assert.Equal(t, c.expected, commands)
reader := strings.NewReader(input) })
}
_, err := Parse(reader)
assert.ErrorContains(t, err, "no FROM line")
} }
func Test_Parser_MissingValue(t *testing.T) { func TestParserParametersMissingValue(t *testing.T) {
input := ` input := `
FROM foo FROM foo
PARAMETER param1 PARAMETER param1
@ -58,41 +101,401 @@ PARAMETER param1
reader := strings.NewReader(input) reader := strings.NewReader(input)
_, err := Parse(reader) _, err := Parse(reader)
assert.ErrorContains(t, err, "missing value for [param1]") assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
}
func TestParserBadCommand(t *testing.T) {
input := `
FROM foo
BADCOMMAND param1 value1
`
_, err := Parse(strings.NewReader(input))
assert.ErrorIs(t, err, errInvalidCommand)
} }
func Test_Parser_Messages(t *testing.T) { func TestParserMessages(t *testing.T) {
var cases = []struct {
input := ` input string
expected []Command
err error
}{
{
`
FROM foo
MESSAGE system You are a Parser. Always Parse things.
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "message", Args: "system: You are a Parser. Always Parse things."},
},
nil,
},
{
`
FROM foo
MESSAGE system You are a Parser. Always Parse things.`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "message", Args: "system: You are a Parser. Always Parse things."},
},
nil,
},
{
`
FROM foo FROM foo
MESSAGE system You are a Parser. Always Parse things. MESSAGE system You are a Parser. Always Parse things.
MESSAGE user Hey there! MESSAGE user Hey there!
MESSAGE assistant Hello, I want to parse all the things! MESSAGE assistant Hello, I want to parse all the things!
` `,
[]Command{
reader := strings.NewReader(input) {Name: "model", Args: "foo"},
commands, err := Parse(reader) {Name: "message", Args: "system: You are a Parser. Always Parse things."},
assert.Nil(t, err) {Name: "message", Args: "user: Hey there!"},
{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
expectedCommands := []Command{ },
{Name: "model", Args: "foo"}, nil,
{Name: "message", Args: "system: You are a Parser. Always Parse things."}, },
{Name: "message", Args: "user: Hey there!"}, {
{Name: "message", Args: "assistant: Hello, I want to parse all the things!"}, `
} FROM foo
MESSAGE system """
assert.Equal(t, expectedCommands, commands) You are a multiline Parser. Always Parse things.
} """
`,
func Test_Parser_Messages_BadRole(t *testing.T) { []Command{
{Name: "model", Args: "foo"},
input := ` {Name: "message", Args: "system: \nYou are a multiline Parser. Always Parse things.\n"},
},
nil,
},
{
`
FROM foo FROM foo
MESSAGE badguy I'm a bad guy! MESSAGE badguy I'm a bad guy!
` `,
nil,
errInvalidMessageRole,
},
{
`
FROM foo
MESSAGE system
`,
nil,
io.ErrUnexpectedEOF,
},
{
`
FROM foo
MESSAGE system`,
nil,
io.ErrUnexpectedEOF,
},
}
for _, c := range cases {
t.Run("", func(t *testing.T) {
commands, err := Parse(strings.NewReader(c.input))
assert.ErrorIs(t, err, c.err)
assert.Equal(t, c.expected, commands)
})
}
}
func TestParserQuoted(t *testing.T) {
var cases = []struct {
multiline string
expected []Command
err error
}{
{
`
FROM foo
SYSTEM """
This is a
multiline system.
"""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "system", Args: "\nThis is a\nmultiline system.\n"},
},
nil,
},
{
`
FROM foo
SYSTEM """
This is a
multiline system."""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "system", Args: "\nThis is a\nmultiline system."},
},
nil,
},
{
`
FROM foo
SYSTEM """This is a
multiline system."""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "system", Args: "This is a\nmultiline system."},
},
nil,
},
{
`
FROM foo
SYSTEM """This is a multiline system."""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "system", Args: "This is a multiline system."},
},
nil,
},
{
`
FROM foo
SYSTEM """This is a multiline system.""
`,
nil,
io.ErrUnexpectedEOF,
},
{
`
FROM foo
SYSTEM "
`,
nil,
io.ErrUnexpectedEOF,
},
{
`
FROM foo
SYSTEM """
This is a multiline system with "quotes".
"""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "system", Args: "\nThis is a multiline system with \"quotes\".\n"},
},
nil,
},
{
`
FROM foo
SYSTEM """"""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "system", Args: ""},
},
nil,
},
{
`
FROM foo
SYSTEM ""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "system", Args: ""},
},
nil,
},
{
`
FROM foo
SYSTEM "'"
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "system", Args: "'"},
},
nil,
},
{
`
FROM foo
SYSTEM """''"'""'""'"'''''""'""'"""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "system", Args: `''"'""'""'"'''''""'""'`},
},
nil,
},
{
`
FROM foo
TEMPLATE """
{{ .Prompt }}
"""`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "template", Args: "\n{{ .Prompt }}\n"},
},
nil,
},
}
for _, c := range cases {
t.Run("", func(t *testing.T) {
commands, err := Parse(strings.NewReader(c.multiline))
assert.ErrorIs(t, err, c.err)
assert.Equal(t, c.expected, commands)
})
}
}
func TestParserParameters(t *testing.T) {
var cases = map[string]struct {
name, value string
}{
"numa true": {"numa", "true"},
"num_ctx 1": {"num_ctx", "1"},
"num_batch 1": {"num_batch", "1"},
"num_gqa 1": {"num_gqa", "1"},
"num_gpu 1": {"num_gpu", "1"},
"main_gpu 1": {"main_gpu", "1"},
"low_vram true": {"low_vram", "true"},
"f16_kv true": {"f16_kv", "true"},
"logits_all true": {"logits_all", "true"},
"vocab_only true": {"vocab_only", "true"},
"use_mmap true": {"use_mmap", "true"},
"use_mlock true": {"use_mlock", "true"},
"num_thread 1": {"num_thread", "1"},
"num_keep 1": {"num_keep", "1"},
"seed 1": {"seed", "1"},
"num_predict 1": {"num_predict", "1"},
"top_k 1": {"top_k", "1"},
"top_p 1.0": {"top_p", "1.0"},
"tfs_z 1.0": {"tfs_z", "1.0"},
"typical_p 1.0": {"typical_p", "1.0"},
"repeat_last_n 1": {"repeat_last_n", "1"},
"temperature 1.0": {"temperature", "1.0"},
"repeat_penalty 1.0": {"repeat_penalty", "1.0"},
"presence_penalty 1.0": {"presence_penalty", "1.0"},
"frequency_penalty 1.0": {"frequency_penalty", "1.0"},
"mirostat 1": {"mirostat", "1"},
"mirostat_tau 1.0": {"mirostat_tau", "1.0"},
"mirostat_eta 1.0": {"mirostat_eta", "1.0"},
"penalize_newline true": {"penalize_newline", "true"},
"stop ### User:": {"stop", "### User:"},
"stop ### User: ": {"stop", "### User: "},
"stop \"### User:\"": {"stop", "### User:"},
"stop \"### User: \"": {"stop", "### User: "},
"stop \"\"\"### User:\"\"\"": {"stop", "### User:"},
"stop \"\"\"### User:\n\"\"\"": {"stop", "### User:\n"},
"stop <|endoftext|>": {"stop", "<|endoftext|>"},
"stop <|eot_id|>": {"stop", "<|eot_id|>"},
"stop </s>": {"stop", "</s>"},
}
for k, v := range cases {
t.Run(k, func(t *testing.T) {
var b bytes.Buffer
fmt.Fprintln(&b, "FROM foo")
fmt.Fprintln(&b, "PARAMETER", k)
commands, err := Parse(&b)
assert.Nil(t, err)
assert.Equal(t, []Command{
{Name: "model", Args: "foo"},
{Name: v.name, Args: v.value},
}, commands)
})
}
}
func TestParserComments(t *testing.T) {
var cases = []struct {
input string
expected []Command
}{
{
`
# comment
FROM foo
`,
[]Command{
{Name: "model", Args: "foo"},
},
},
}
for _, c := range cases {
t.Run("", func(t *testing.T) {
commands, err := Parse(strings.NewReader(c.input))
assert.Nil(t, err)
assert.Equal(t, c.expected, commands)
})
}
}
func TestParseFormatParse(t *testing.T) {
var cases = []string{
`
FROM foo
ADAPTER adapter1
LICENSE MIT
PARAMETER param1 value1
PARAMETER param2 value2
TEMPLATE template1
MESSAGE system You are a Parser. Always Parse things.
MESSAGE user Hey there!
MESSAGE assistant Hello, I want to parse all the things!
`,
`
FROM foo
ADAPTER adapter1
LICENSE MIT
PARAMETER param1 value1
PARAMETER param2 value2
TEMPLATE template1
MESSAGE system """
You are a store greeter. Always responsed with "Hello!".
"""
MESSAGE user Hey there!
MESSAGE assistant Hello, I want to parse all the things!
`,
`
FROM foo
ADAPTER adapter1
LICENSE """
Very long and boring legal text.
Blah blah blah.
"Oh look, a quote!"
"""
PARAMETER param1 value1
PARAMETER param2 value2
TEMPLATE template1
MESSAGE system """
You are a store greeter. Always responsed with "Hello!".
"""
MESSAGE user Hey there!
MESSAGE assistant Hello, I want to parse all the things!
`,
}
for _, c := range cases {
t.Run("", func(t *testing.T) {
commands, err := Parse(strings.NewReader(c))
assert.NoError(t, err)
commands2, err := Parse(strings.NewReader(Format(commands)))
assert.NoError(t, err)
assert.Equal(t, commands, commands2)
})
}
reader := strings.NewReader(input)
_, err := Parse(reader)
assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"")
} }

View file

@ -21,7 +21,6 @@ import (
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
"text/template"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
@ -64,6 +63,48 @@ func (m *Model) IsEmbedding() bool {
return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert") return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert")
} }
func (m *Model) Commands() (cmds []parser.Command) {
cmds = append(cmds, parser.Command{Name: "model", Args: m.ModelPath})
if m.Template != "" {
cmds = append(cmds, parser.Command{Name: "template", Args: m.Template})
}
if m.System != "" {
cmds = append(cmds, parser.Command{Name: "system", Args: m.System})
}
for _, adapter := range m.AdapterPaths {
cmds = append(cmds, parser.Command{Name: "adapter", Args: adapter})
}
for _, projector := range m.ProjectorPaths {
cmds = append(cmds, parser.Command{Name: "projector", Args: projector})
}
for k, v := range m.Options {
switch v := v.(type) {
case []any:
for _, s := range v {
cmds = append(cmds, parser.Command{Name: k, Args: fmt.Sprintf("%v", s)})
}
default:
cmds = append(cmds, parser.Command{Name: k, Args: fmt.Sprintf("%v", v)})
}
}
for _, license := range m.License {
cmds = append(cmds, parser.Command{Name: "license", Args: license})
}
for _, msg := range m.Messages {
cmds = append(cmds, parser.Command{Name: "message", Args: fmt.Sprintf("%s %s", msg.Role, msg.Content)})
}
return cmds
}
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` Content string `json:"content"`
@ -901,67 +942,6 @@ func DeleteModel(name string) error {
return nil return nil
} }
func ShowModelfile(model *Model) (string, error) {
var mt struct {
*Model
From string
Parameters map[string][]any
}
mt.Parameters = make(map[string][]any)
for k, v := range model.Options {
if s, ok := v.([]any); ok {
mt.Parameters[k] = s
continue
}
mt.Parameters[k] = []any{v}
}
mt.Model = model
mt.From = model.ModelPath
if model.ParentModel != "" {
mt.From = model.ParentModel
}
modelFile := `# Modelfile generated by "ollama show"
# To build a new Modelfile based on this one, replace the FROM line with:
# FROM {{ .ShortName }}
FROM {{ .From }}
TEMPLATE """{{ .Template }}"""
{{- if .System }}
SYSTEM """{{ .System }}"""
{{- end }}
{{- range $adapter := .AdapterPaths }}
ADAPTER {{ $adapter }}
{{- end }}
{{- range $k, $v := .Parameters }}
{{- range $parameter := $v }}
PARAMETER {{ $k }} {{ printf "%#v" $parameter }}
{{- end }}
{{- end }}`
tmpl, err := template.New("").Parse(modelFile)
if err != nil {
slog.Info(fmt.Sprintf("error parsing template: %q", err))
return "", err
}
var buf bytes.Buffer
if err = tmpl.Execute(&buf, mt); err != nil {
slog.Info(fmt.Sprintf("error executing template: %q", err))
return "", err
}
return buf.String(), nil
}
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error { func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name) mp := ParseModelPath(name)
fn(api.ProgressResponse{Status: "retrieving manifest"}) fn(api.ProgressResponse{Status: "retrieving manifest"})

View file

@ -728,12 +728,12 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
} }
} }
mf, err := ShowModelfile(model) var sb strings.Builder
if err != nil { fmt.Fprintln(&sb, "# Modelfile generate by \"ollama show\"")
return nil, err fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:")
} fmt.Fprintf(&sb, "# FROM %s\n\n", model.ShortName)
fmt.Fprint(&sb, parser.Format(model.Commands()))
resp.Modelfile = mf resp.Modelfile = sb.String()
return resp, nil return resp, nil
} }

View file

@ -238,6 +238,5 @@ func Test_Routes(t *testing.T) {
if tc.Expected != nil { if tc.Expected != nil {
tc.Expected(t, resp) tc.Expected(t, resp)
} }
} }
} }