Merge pull request from ollama/mxyng/save

fix: model save
This commit is contained in:
Michael Yang 2024-07-29 09:53:19 -07:00 committed by GitHub
commit 38d9036b59
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 60 additions and 61 deletions

View file

@ -1,6 +1,7 @@
package cmd package cmd
import ( import (
"cmp"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -9,13 +10,14 @@ import (
"path/filepath" "path/filepath"
"regexp" "regexp"
"slices" "slices"
"sort"
"strings" "strings"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/exp/maps"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress" "github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline" "github.com/ollama/ollama/readline"
"github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/errtypes"
@ -376,9 +378,9 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
return err return err
} }
req := &api.ShowRequest{ req := &api.ShowRequest{
Name: opts.Model, Name: opts.Model,
System: opts.System, System: opts.System,
Options: opts.Options, Options: opts.Options,
} }
resp, err := client.Show(cmd.Context(), req) resp, err := client.Show(cmd.Context(), req)
if err != nil { if err != nil {
@ -507,31 +509,35 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
} }
func buildModelfile(opts runOptions) string { func buildModelfile(opts runOptions) string {
var mf strings.Builder var f parser.File
model := opts.ParentModel f.Commands = append(f.Commands, parser.Command{Name: "model", Args: cmp.Or(opts.ParentModel, opts.Model)})
if model == "" {
model = opts.Model
}
fmt.Fprintf(&mf, "FROM %s\n", model)
if opts.System != "" { if opts.System != "" {
fmt.Fprintf(&mf, "SYSTEM \"\"\"%s\"\"\"\n", opts.System) f.Commands = append(f.Commands, parser.Command{Name: "system", Args: opts.System})
} }
keys := make([]string, 0) keys := maps.Keys(opts.Options)
for k := range opts.Options { slices.Sort(keys)
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys { for _, k := range keys {
fmt.Fprintf(&mf, "PARAMETER %s %v\n", k, opts.Options[k]) v := opts.Options[k]
var cmds []parser.Command
switch t := v.(type) {
case []string:
for _, s := range t {
cmds = append(cmds, parser.Command{Name: k, Args: s})
}
default:
cmds = append(cmds, parser.Command{Name: k, Args: fmt.Sprintf("%v", t)})
}
f.Commands = append(f.Commands, cmds...)
} }
fmt.Fprintln(&mf)
for _, msg := range opts.Messages { for _, msg := range opts.Messages {
fmt.Fprintf(&mf, "MESSAGE %s \"\"\"%s\"\"\"\n", msg.Role, msg.Content) f.Commands = append(f.Commands, parser.Command{Name: "message", Args: fmt.Sprintf("%s: %s", msg.Role, msg.Content)})
} }
return mf.String() return f.String()
} }
func normalizeFilePath(fp string) string { func normalizeFilePath(fp string) string {

View file

@ -1,12 +1,10 @@
package cmd package cmd
import ( import (
"bytes"
"testing" "testing"
"text/template"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
@ -57,58 +55,53 @@ d:\path with\spaces\seven.svg inbetween7 c:\users\jdoe\eight.png inbetween8
func TestModelfileBuilder(t *testing.T) { func TestModelfileBuilder(t *testing.T) {
opts := runOptions{ opts := runOptions{
Model: "hork", Model: "hork",
System: "You are part horse and part shark, but all hork. Do horklike things", System: "You are part horse and part shark, but all hork. Do horklike things",
Messages: []api.Message{ Messages: []api.Message{
{Role: "user", Content: "Hey there hork!"}, {Role: "user", Content: "Hey there hork!"},
{Role: "assistant", Content: "Yes it is true, I am half horse, half shark."}, {Role: "assistant", Content: "Yes it is true, I am half horse, half shark."},
}, },
Options: map[string]interface{}{}, Options: map[string]any{
"temperature": 0.9,
"seed": 42,
"penalize_newline": false,
"stop": []string{"hi", "there"},
},
} }
opts.Options["temperature"] = 0.9 t.Run("model", func(t *testing.T) {
opts.Options["seed"] = 42 expect := `FROM hork
opts.Options["penalize_newline"] = false SYSTEM You are part horse and part shark, but all hork. Do horklike things
opts.Options["stop"] = []string{"hi", "there"}
mf := buildModelfile(opts)
expectedModelfile := `FROM {{.Model}}
SYSTEM """{{.System}}"""
PARAMETER penalize_newline false PARAMETER penalize_newline false
PARAMETER seed 42 PARAMETER seed 42
PARAMETER stop [hi there] PARAMETER stop hi
PARAMETER stop there
PARAMETER temperature 0.9 PARAMETER temperature 0.9
MESSAGE user Hey there hork!
MESSAGE user """Hey there hork!""" MESSAGE assistant Yes it is true, I am half horse, half shark.
MESSAGE assistant """Yes it is true, I am half horse, half shark."""
` `
tmpl, err := template.New("").Parse(expectedModelfile) actual := buildModelfile(opts)
require.NoError(t, err) if diff := cmp.Diff(expect, actual); diff != "" {
t.Errorf("mismatch (-want +got):\n%s", diff)
}
})
var buf bytes.Buffer t.Run("parent model", func(t *testing.T) {
err = tmpl.Execute(&buf, opts) opts.ParentModel = "horseshark"
require.NoError(t, err) expect := `FROM horseshark
assert.Equal(t, buf.String(), mf) SYSTEM You are part horse and part shark, but all hork. Do horklike things
opts.ParentModel = "horseshark"
mf = buildModelfile(opts)
expectedModelfile = `FROM {{.ParentModel}}
SYSTEM """{{.System}}"""
PARAMETER penalize_newline false PARAMETER penalize_newline false
PARAMETER seed 42 PARAMETER seed 42
PARAMETER stop [hi there] PARAMETER stop hi
PARAMETER stop there
PARAMETER temperature 0.9 PARAMETER temperature 0.9
MESSAGE user Hey there hork!
MESSAGE user """Hey there hork!""" MESSAGE assistant Yes it is true, I am half horse, half shark.
MESSAGE assistant """Yes it is true, I am half horse, half shark."""
` `
actual := buildModelfile(opts)
tmpl, err = template.New("").Parse(expectedModelfile) if diff := cmp.Diff(expect, actual); diff != "" {
require.NoError(t, err) t.Errorf("mismatch (-want +got):\n%s", diff)
}
var parentBuf bytes.Buffer })
err = tmpl.Execute(&parentBuf, opts)
require.NoError(t, err)
assert.Equal(t, parentBuf.String(), mf)
} }