Save and load sessions (#2063)

This commit is contained in:
Patrick Devine 2024-01-25 12:12:36 -08:00 committed by GitHub
parent e64b5b07a2
commit 7c40a67841
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 312 additions and 39 deletions

View file

@ -171,6 +171,7 @@ type ShowResponse struct {
Template string `json:"template,omitempty"` Template string `json:"template,omitempty"`
System string `json:"system,omitempty"` System string `json:"system,omitempty"`
Details ModelDetails `json:"details,omitempty"` Details ModelDetails `json:"details,omitempty"`
Messages []Message `json:"messages,omitempty"`
} }
type CopyRequest struct { type CopyRequest struct {
@ -236,6 +237,7 @@ type GenerateResponse struct {
} }
type ModelDetails struct { type ModelDetails struct {
ParentModel string `json:"parent_model"`
Format string `json:"format"` Format string `json:"format"`
Family string `json:"family"` Family string `json:"family"`
Families []string `json:"families"` Families []string `json:"families"`

View file

@ -459,6 +459,7 @@ type generateContextKey string
type runOptions struct { type runOptions struct {
Model string Model string
ParentModel string
Prompt string Prompt string
Messages []api.Message Messages []api.Message
WordWrap bool WordWrap bool
@ -467,6 +468,7 @@ type runOptions struct {
Template string Template string
Images []api.ImageData Images []api.ImageData
Options map[string]interface{} Options map[string]interface{}
MultiModal bool
} }
type displayResponseState struct { type displayResponseState struct {

View file

@ -7,12 +7,14 @@ import (
"net/http" "net/http"
"os" "os"
"regexp" "regexp"
"sort"
"strings" "strings"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/progress"
"github.com/jmorganca/ollama/readline" "github.com/jmorganca/ollama/readline"
) )
@ -25,33 +27,63 @@ const (
MultilineTemplate MultilineTemplate
) )
func modelIsMultiModal(cmd *cobra.Command, name string) bool { func loadModel(cmd *cobra.Command, opts *runOptions) error {
// get model details
client, err := api.ClientFromEnvironment() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
fmt.Println("error: couldn't connect to ollama server") return err
return false
} }
req := api.ShowRequest{Name: name} p := progress.NewProgress(os.Stderr)
resp, err := client.Show(cmd.Context(), &req) defer p.StopAndClear()
spinner := progress.NewSpinner("")
p.Add("", spinner)
showReq := api.ShowRequest{Name: opts.Model}
showResp, err := client.Show(cmd.Context(), &showReq)
if err != nil { if err != nil {
return false return err
}
opts.MultiModal = slices.Contains(showResp.Details.Families, "clip")
opts.ParentModel = showResp.Details.ParentModel
if len(showResp.Messages) > 0 {
opts.Messages = append(opts.Messages, showResp.Messages...)
} }
return slices.Contains(resp.Details.Families, "clip") chatReq := &api.ChatRequest{
Model: opts.Model,
Messages: []api.Message{},
}
err = client.Chat(cmd.Context(), chatReq, func(resp api.ChatResponse) error {
p.StopAndClear()
if len(opts.Messages) > 0 {
for _, msg := range opts.Messages {
switch msg.Role {
case "user":
fmt.Printf(">>> %s\n", msg.Content)
case "assistant":
state := &displayResponseState{}
displayResponse(msg.Content, opts.WordWrap, state)
fmt.Println()
fmt.Println()
}
}
}
return nil
})
if err != nil {
return err
}
return nil
} }
func generateInteractive(cmd *cobra.Command, opts runOptions) error { func generateInteractive(cmd *cobra.Command, opts runOptions) error {
multiModal := modelIsMultiModal(cmd, opts.Model) opts.Messages = make([]api.Message, 0)
// load the model err := loadModel(cmd, &opts)
loadOpts := runOptions{ if err != nil {
Model: opts.Model,
Prompt: "",
Messages: []api.Message{},
}
if _, err := chat(cmd, loadOpts); err != nil {
return err return err
} }
@ -59,6 +91,8 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, "Available Commands:") fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set Set session variables") fmt.Fprintln(os.Stderr, " /set Set session variables")
fmt.Fprintln(os.Stderr, " /show Show model information") fmt.Fprintln(os.Stderr, " /show Show model information")
fmt.Fprintln(os.Stderr, " /load <model> Load a session or model")
fmt.Fprintln(os.Stderr, " /save <model> Save your current session")
fmt.Fprintln(os.Stderr, " /bye Exit") fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr, " /?, /help Help for a command") fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts") fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
@ -140,7 +174,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
var sb strings.Builder var sb strings.Builder
var multiline MultilineState var multiline MultilineState
opts.Messages = make([]api.Message, 0)
for { for {
line, err := scanner.Readline() line, err := scanner.Readline()
@ -203,6 +236,44 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
if err := ListHandler(cmd, args[1:]); err != nil { if err := ListHandler(cmd, args[1:]); err != nil {
return err return err
} }
case strings.HasPrefix(line, "/load"):
args := strings.Fields(line)
if len(args) != 2 {
fmt.Println("Usage:\n /load <modelname>")
continue
}
opts.Model = args[1]
opts.Messages = []api.Message{}
fmt.Printf("Loading model '%s'\n", opts.Model)
if err := loadModel(cmd, &opts); err != nil {
return err
}
continue
case strings.HasPrefix(line, "/save"):
args := strings.Fields(line)
if len(args) != 2 {
fmt.Println("Usage:\n /save <modelname>")
continue
}
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
return err
}
req := &api.CreateRequest{
Name: args[1],
Modelfile: buildModelfile(opts),
}
fn := func(resp api.ProgressResponse) error { return nil }
err = client.Create(cmd.Context(), req, fn)
if err != nil {
fmt.Println("error: couldn't save model")
return err
}
fmt.Printf("Created new model '%s'\n", args[1])
continue
case strings.HasPrefix(line, "/set"): case strings.HasPrefix(line, "/set"):
args := strings.Fields(line) args := strings.Fields(line)
if len(args) > 1 { if len(args) > 1 {
@ -389,7 +460,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
args := strings.Fields(line) args := strings.Fields(line)
isFile := false isFile := false
if multiModal { if opts.MultiModal {
for _, f := range extractFileNames(line) { for _, f := range extractFileNames(line) {
if strings.HasPrefix(f, args[0]) { if strings.HasPrefix(f, args[0]) {
isFile = true isFile = true
@ -411,7 +482,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
if sb.Len() > 0 && multiline == MultilineNone { if sb.Len() > 0 && multiline == MultilineNone {
newMessage := api.Message{Role: "user", Content: sb.String()} newMessage := api.Message{Role: "user", Content: sb.String()}
if multiModal { if opts.MultiModal {
msg, images, err := extractFileData(sb.String()) msg, images, err := extractFileData(sb.String())
if err != nil { if err != nil {
return err return err
@ -454,6 +525,38 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
} }
} }
func buildModelfile(opts runOptions) string {
var mf strings.Builder
model := opts.ParentModel
if model == "" {
model = opts.Model
}
fmt.Fprintf(&mf, "FROM %s\n", model)
if opts.System != "" {
fmt.Fprintf(&mf, "SYSTEM \"\"\"%s\"\"\"\n", opts.System)
}
if opts.Template != "" {
fmt.Fprintf(&mf, "TEMPLATE \"\"\"%s\"\"\"\n", opts.Template)
}
keys := make([]string, 0)
for k := range opts.Options {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
fmt.Fprintf(&mf, "PARAMETER %s %v\n", k, opts.Options[k])
}
fmt.Fprintln(&mf)
for _, msg := range opts.Messages {
fmt.Fprintf(&mf, "MESSAGE %s \"\"\"%s\"\"\"\n", msg.Role, msg.Content)
}
return mf.String()
}
func normalizeFilePath(fp string) string { func normalizeFilePath(fp string) string {
// Define a map of escaped characters and their replacements // Define a map of escaped characters and their replacements
replacements := map[string]string{ replacements := map[string]string{

View file

@ -1,9 +1,13 @@
package cmd package cmd
import ( import (
"bytes"
"testing" "testing"
"text/template"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/jmorganca/ollama/api"
) )
func TestExtractFilenames(t *testing.T) { func TestExtractFilenames(t *testing.T) {
@ -49,3 +53,64 @@ d:\path with\spaces\seven.svg inbetween7 c:\users\jdoe\eight.png inbetween8
assert.Contains(t, res[9], "ten.svg") assert.Contains(t, res[9], "ten.svg")
assert.Contains(t, res[9], "E:") assert.Contains(t, res[9], "E:")
} }
func TestModelfileBuilder(t *testing.T) {
opts := runOptions{
Model: "hork",
System: "You are part horse and part shark, but all hork. Do horklike things",
Template: "This is a template.",
Messages: []api.Message{
{Role: "user", Content: "Hey there hork!"},
{Role: "assistant", Content: "Yes it is true, I am half horse, half shark."},
},
Options: map[string]interface{}{},
}
opts.Options["temperature"] = 0.9
opts.Options["seed"] = 42
opts.Options["penalize_newline"] = false
opts.Options["stop"] = []string{"hi", "there"}
mf := buildModelfile(opts)
expectedModelfile := `FROM {{.Model}}
SYSTEM """{{.System}}"""
TEMPLATE """{{.Template}}"""
PARAMETER penalize_newline false
PARAMETER seed 42
PARAMETER stop [hi there]
PARAMETER temperature 0.9
MESSAGE user """Hey there hork!"""
MESSAGE assistant """Yes it is true, I am half horse, half shark."""
`
tmpl, err := template.New("").Parse(expectedModelfile)
assert.Nil(t, err)
var buf bytes.Buffer
err = tmpl.Execute(&buf, opts)
assert.Nil(t, err)
assert.Equal(t, buf.String(), mf)
opts.ParentModel = "horseshark"
mf = buildModelfile(opts)
expectedModelfile = `FROM {{.ParentModel}}
SYSTEM """{{.System}}"""
TEMPLATE """{{.Template}}"""
PARAMETER penalize_newline false
PARAMETER seed 42
PARAMETER stop [hi there]
PARAMETER temperature 0.9
MESSAGE user """Hey there hork!"""
MESSAGE assistant """Yes it is true, I am half horse, half shark."""
`
tmpl, err = template.New("").Parse(expectedModelfile)
assert.Nil(t, err)
var parentBuf bytes.Buffer
err = tmpl.Execute(&parentBuf, opts)
assert.Nil(t, err)
assert.Equal(t, parentBuf.String(), mf)
}

View file

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"slices"
) )
type Command struct { type Command struct {
@ -56,6 +57,16 @@ func Parse(reader io.Reader) ([]Command, error) {
command.Args = string(bytes.TrimSpace(fields[1])) command.Args = string(bytes.TrimSpace(fields[1]))
case "EMBED": case "EMBED":
return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead") return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
case "MESSAGE":
command.Name = string(bytes.ToLower(fields[0]))
fields = bytes.SplitN(fields[1], []byte(" "), 2)
if len(fields) < 2 {
return nil, fmt.Errorf("should be in the format <role> <message>")
}
if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) {
return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"")
}
command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1]))
default: default:
if !bytes.HasPrefix(fields[0], []byte("#")) { if !bytes.HasPrefix(fields[0], []byte("#")) {
// log a warning for unknown commands // log a warning for unknown commands

View file

@ -61,3 +61,38 @@ PARAMETER param1
assert.ErrorContains(t, err, "missing value for [param1]") assert.ErrorContains(t, err, "missing value for [param1]")
} }
func Test_Parser_Messages(t *testing.T) {
input := `
FROM foo
MESSAGE system You are a Parser. Always Parse things.
MESSAGE user Hey there!
MESSAGE assistant Hello, I want to parse all the things!
`
reader := strings.NewReader(input)
commands, err := Parse(reader)
assert.Nil(t, err)
expectedCommands := []Command{
{Name: "model", Args: "foo"},
{Name: "message", Args: "system: You are a Parser. Always Parse things."},
{Name: "message", Args: "user: Hey there!"},
{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
}
assert.Equal(t, expectedCommands, commands)
}
func Test_Parser_Messages_BadRole(t *testing.T) {
input := `
FROM foo
MESSAGE badguy I'm a bad guy!
`
reader := strings.NewReader(input)
_, err := Parse(reader)
assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"")
}

View file

@ -41,7 +41,7 @@ type Model struct {
Config ConfigV2 Config ConfigV2
ShortName string ShortName string
ModelPath string ModelPath string
OriginalModel string ParentModel string
AdapterPaths []string AdapterPaths []string
ProjectorPaths []string ProjectorPaths []string
Template string Template string
@ -50,6 +50,12 @@ type Model struct {
Digest string Digest string
Size int64 Size int64
Options map[string]interface{} Options map[string]interface{}
Messages []Message
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
} }
type PromptVars struct { type PromptVars struct {
@ -333,7 +339,7 @@ func GetModel(name string) (*Model, error) {
switch layer.MediaType { switch layer.MediaType {
case "application/vnd.ollama.image.model": case "application/vnd.ollama.image.model":
model.ModelPath = filename model.ModelPath = filename
model.OriginalModel = layer.From model.ParentModel = layer.From
case "application/vnd.ollama.image.embed": case "application/vnd.ollama.image.embed":
// Deprecated in versions > 0.1.2 // Deprecated in versions > 0.1.2
// TODO: remove this warning in a future version // TODO: remove this warning in a future version
@ -374,6 +380,16 @@ func GetModel(name string) (*Model, error) {
if err = json.NewDecoder(params).Decode(&model.Options); err != nil { if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
return nil, err return nil, err
} }
case "application/vnd.ollama.image.messages":
msgs, err := os.Open(filename)
if err != nil {
return nil, err
}
defer msgs.Close()
if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil {
return nil, err
}
case "application/vnd.ollama.image.license": case "application/vnd.ollama.image.license":
bts, err := os.ReadFile(filename) bts, err := os.ReadFile(filename)
if err != nil { if err != nil {
@ -428,12 +444,12 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
} }
var layers Layers var layers Layers
messages := []string{}
params := make(map[string][]string) params := make(map[string][]string)
fromParams := make(map[string]any) fromParams := make(map[string]any)
for _, c := range commands { for _, c := range commands {
slog.Info(fmt.Sprintf("[%s] - %s", c.Name, c.Args))
mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
switch c.Name { switch c.Name {
@ -607,11 +623,37 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
} }
layers.Replace(layer) layers.Replace(layer)
case "message":
messages = append(messages, c.Args)
default: default:
params[c.Name] = append(params[c.Name], c.Args) params[c.Name] = append(params[c.Name], c.Args)
} }
} }
if len(messages) > 0 {
fn(api.ProgressResponse{Status: "creating parameters layer"})
msgs := make([]api.Message, 0)
for _, m := range messages {
// todo: handle images
msg := strings.SplitN(m, ": ", 2)
msgs = append(msgs, api.Message{Role: msg[0], Content: msg[1]})
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(msgs); err != nil {
return err
}
layer, err := NewLayer(&b, "application/vnd.ollama.image.messages")
if err != nil {
return err
}
layers.Replace(layer)
}
if len(params) > 0 { if len(params) > 0 {
fn(api.ProgressResponse{Status: "creating parameters layer"}) fn(api.ProgressResponse{Status: "creating parameters layer"})
@ -908,8 +950,8 @@ func ShowModelfile(model *Model) (string, error) {
mt.Model = model mt.Model = model
mt.From = model.ModelPath mt.From = model.ModelPath
if model.OriginalModel != "" { if model.ParentModel != "" {
mt.From = model.OriginalModel mt.From = model.ParentModel
} }
modelFile := `# Modelfile generated by "ollama show" modelFile := `# Modelfile generated by "ollama show"

View file

@ -659,6 +659,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
} }
modelDetails := api.ModelDetails{ modelDetails := api.ModelDetails{
ParentModel: model.ParentModel,
Format: model.Config.ModelFormat, Format: model.Config.ModelFormat,
Family: model.Config.ModelFamily, Family: model.Config.ModelFamily,
Families: model.Config.ModelFamilies, Families: model.Config.ModelFamilies,
@ -674,11 +675,17 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
model.Template = req.Template model.Template = req.Template
} }
msgs := make([]api.Message, 0)
for _, msg := range model.Messages {
msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
}
resp := &api.ShowResponse{ resp := &api.ShowResponse{
License: strings.Join(model.License, "\n"), License: strings.Join(model.License, "\n"),
System: model.System, System: model.System,
Template: model.Template, Template: model.Template,
Details: modelDetails, Details: modelDetails,
Messages: msgs,
} }
var params []string var params []string
@ -1075,7 +1082,13 @@ func ChatHandler(c *gin.Context) {
// an empty request loads the model // an empty request loads the model
if len(req.Messages) == 0 { if len(req.Messages) == 0 {
c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true, Message: api.Message{Role: "assistant"}}) resp := api.ChatResponse{
CreatedAt: time.Now().UTC(),
Model: req.Model,
Done: true,
Message: api.Message{Role: "assistant"},
}
c.JSON(http.StatusOK, resp)
return return
} }