parser: add commands format

This commit is contained in:
Michael Yang 2024-04-24 18:49:14 -07:00
parent 4d08363580
commit 176ad3aa6e
3 changed files with 108 additions and 15 deletions

View file

@ -17,7 +17,6 @@ import (
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"regexp"
"runtime" "runtime"
"strings" "strings"
"syscall" "syscall"
@ -57,12 +56,13 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
p := progress.NewProgress(os.Stderr) p := progress.NewProgress(os.Stderr)
defer p.Stop() defer p.Stop()
modelfile, err := os.ReadFile(filename) modelfile, err := os.Open(filename)
if err != nil { if err != nil {
return err return err
} }
defer modelfile.Close()
commands, err := parser.Parse(bytes.NewReader(modelfile)) commands, err := parser.Parse(modelfile)
if err != nil { if err != nil {
return err return err
} }
@ -76,10 +76,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
spinner := progress.NewSpinner(status) spinner := progress.NewSpinner(status)
p.Add(status, spinner) p.Add(status, spinner)
for _, c := range commands { for i := range commands {
switch c.Name { switch commands[i].Name {
case "model", "adapter": case "model", "adapter":
path := c.Args path := commands[i].Args
if path == "~" { if path == "~" {
path = home path = home
} else if strings.HasPrefix(path, "~/") { } else if strings.HasPrefix(path, "~/") {
@ -91,7 +91,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
} }
fi, err := os.Stat(path) fi, err := os.Stat(path)
if errors.Is(err, os.ErrNotExist) && c.Name == "model" { if errors.Is(err, os.ErrNotExist) && commands[i].Name == "model" {
continue continue
} else if err != nil { } else if err != nil {
return err return err
@ -114,13 +114,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err return err
} }
name := c.Name commands[i].Args = "@"+digest
if c.Name == "model" {
name = "from"
}
re := regexp.MustCompile(fmt.Sprintf(`(?im)^(%s)\s+%s\s*$`, name, c.Args))
modelfile = re.ReplaceAll(modelfile, []byte("$1 @"+digest))
} }
} }
@ -150,7 +144,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
quantization, _ := cmd.Flags().GetString("quantization") quantization, _ := cmd.Flags().GetString("quantization")
request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile), Quantization: quantization} request := api.CreateRequest{Name: args[0], Modelfile: parser.Format(commands), Quantization: quantization}
if err := client.Create(cmd.Context(), &request, fn); err != nil { if err := client.Create(cmd.Context(), &request, fn); err != nil {
return err return err
} }

View file

@ -31,6 +31,33 @@ var (
errInvalidRole = errors.New("role must be one of \"system\", \"user\", or \"assistant\"") errInvalidRole = errors.New("role must be one of \"system\", \"user\", or \"assistant\"")
) )
func Format(cmds []Command) string {
var b bytes.Buffer
for _, cmd := range cmds {
name := cmd.Name
args := cmd.Args
switch cmd.Name {
case "model":
name = "from"
args = cmd.Args
case "license", "template", "system", "adapter":
args = quote(args)
// pass
case "message":
role, message, _ := strings.Cut(cmd.Args, ": ")
args = role + " " + quote(message)
default:
name = "parameter"
args = cmd.Name + " " + cmd.Args
}
fmt.Fprintln(&b, strings.ToUpper(name), args)
}
return b.String()
}
func Parse(r io.Reader) (cmds []Command, err error) { func Parse(r io.Reader) (cmds []Command, err error) {
var cmd Command var cmd Command
var curr state var curr state
@ -197,6 +224,18 @@ func parseRuneForState(r rune, cs state) (state, rune, error) {
} }
} }
func quote(s string) string {
if strings.Contains(s, "\n") || strings.HasSuffix(s, " ") {
if strings.Contains(s, "\"") {
return `"""` + s + `"""`
}
return strconv.Quote(s)
}
return s
}
func unquote(s string) (string, bool) { func unquote(s string) (string, bool) {
if len(s) == 0 { if len(s) == 0 {
return "", false return "", false

View file

@ -429,3 +429,63 @@ FROM foo
}) })
} }
} }
func TestParseFormatParse(t *testing.T) {
var cases = []string{
`
FROM foo
ADAPTER adapter1
LICENSE MIT
PARAMETER param1 value1
PARAMETER param2 value2
TEMPLATE template1
MESSAGE system You are a Parser. Always Parse things.
MESSAGE user Hey there!
MESSAGE assistant Hello, I want to parse all the things!
`,
`
FROM foo
ADAPTER adapter1
LICENSE MIT
PARAMETER param1 value1
PARAMETER param2 value2
TEMPLATE template1
MESSAGE system """
You are a store greeter. Always responsed with "Hello!".
"""
MESSAGE user Hey there!
MESSAGE assistant Hello, I want to parse all the things!
`,
`
FROM foo
ADAPTER adapter1
LICENSE """
Very long and boring legal text.
Blah blah blah.
"Oh look, a quote!"
"""
PARAMETER param1 value1
PARAMETER param2 value2
TEMPLATE template1
MESSAGE system """
You are a store greeter. Always responsed with "Hello!".
"""
MESSAGE user Hey there!
MESSAGE assistant Hello, I want to parse all the things!
`,
}
for _, c := range cases {
t.Run("", func(t *testing.T) {
commands, err := Parse(strings.NewReader(c))
assert.NoError(t, err)
commands2, err := Parse(strings.NewReader(Format(commands)))
assert.NoError(t, err)
assert.Equal(t, commands, commands2)
})
}
}