From df146c41e27e6e15c0ab51af9070c87c82aa4fc7 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Mon, 17 Jul 2023 14:21:27 -0700 Subject: [PATCH] separate prompt into template and system --- parser/parser.go | 101 +++++++++++++++++++++++++++-------------------- server/images.go | 85 ++++++++++++++++++++++++--------------- server/routes.go | 12 +----- 3 files changed, 113 insertions(+), 85 deletions(-) diff --git a/parser/parser.go b/parser/parser.go index 5ba87c84..635e7276 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -2,76 +2,91 @@ package parser import ( "bufio" + "bytes" + "errors" "fmt" "io" - "strings" ) type Command struct { Name string - Arg string + Args string +} + +func (c *Command) Reset() { + c.Name = "" + c.Args = "" } func Parse(reader io.Reader) ([]Command, error) { var commands []Command - var foundModel bool + + var command, modelCommand Command scanner := bufio.NewScanner(reader) - multiline := false - var multilineCommand *Command + scanner.Split(scanModelfile) for scanner.Scan() { - line := scanner.Text() - if multiline { - // If we're in a multiline string and the line is """, end the multiline string. - if strings.TrimSpace(line) == `"""` { - multiline = false - commands = append(commands, *multilineCommand) - } else { - // Otherwise, append the line to the multiline string. - multilineCommand.Arg += "\n" + line - } - continue - } - fields := strings.Fields(line) + line := scanner.Bytes() + + fields := bytes.SplitN(line, []byte(" "), 2) if len(fields) == 0 { continue } - command := Command{} - switch strings.ToUpper(fields[0]) { + switch string(bytes.ToUpper(fields[0])) { case "FROM": command.Name = "model" - command.Arg = fields[1] - if command.Arg == "" { - return nil, fmt.Errorf("no model specified in FROM line") - } - foundModel = true - case "PROMPT", "LICENSE": - command.Name = strings.ToLower(fields[0]) - if fields[1] == `"""` { - multiline = true - multilineCommand = &command - multilineCommand.Arg = "" - } else { - command.Arg = strings.Join(fields[1:], " ") - } + command.Args = string(fields[1]) + // copy command for validation + modelCommand = command + case "LICENSE", "TEMPLATE", "SYSTEM": + command.Name = string(bytes.ToLower(fields[0])) + command.Args = string(fields[1]) case "PARAMETER": - command.Name = fields[1] - command.Arg = strings.Join(fields[2:], " ") + fields = bytes.SplitN(fields[1], []byte(" "), 2) + command.Name = string(fields[0]) + command.Args = string(fields[1]) default: continue } - if !multiline { - commands = append(commands, command) - } + + commands = append(commands, command) + command.Reset() } - if !foundModel { + if modelCommand.Args == "" { return nil, fmt.Errorf("no FROM line for the model was specified") } - if multiline { - return nil, fmt.Errorf("unclosed multiline string") - } return commands, scanner.Err() } + +func scanModelfile(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF || len(data) == 0 { + return 0, nil, nil + } + + newline := bytes.IndexByte(data, '\n') + + if start := bytes.Index(data, []byte(`"""`)); start >= 0 && start < newline { + end := bytes.Index(data[start+3:], []byte(`"""`)) + if end < 0 { + return 0, nil, errors.New(`unterminated multiline string: """`) + } + + n := start + 3 + end + 3 + return n, bytes.Replace(data[:n], []byte(`"""`), []byte(""), 2), nil + } + + if start := bytes.Index(data, []byte(`'''`)); start >= 0 && start < newline { + end := bytes.Index(data[start+3:], []byte(`'''`)) + if end < 0 { + return 0, nil, errors.New("unterminated multiline string: '''") + } + + n := start + 3 + end + 3 + return n, bytes.Replace(data[:n], []byte("'''"), []byte(""), 2), nil + } + + return bufio.ScanLines(data, atEOF) +} diff --git a/server/images.go b/server/images.go index 4140283a..16b32f83 100644 --- a/server/images.go +++ b/server/images.go @@ -16,6 +16,7 @@ import ( "reflect" "strconv" "strings" + "text/template" "github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/parser" @@ -24,10 +25,33 @@ import ( type Model struct { Name string `json:"name"` ModelPath string - Prompt string + Template string + System string Options api.Options } +func (m *Model) Prompt(request api.GenerateRequest) (string, error) { + tmpl, err := template.New("").Parse(m.Template) + if err != nil { + return "", err + } + + var vars struct { + System string + Prompt string + } + + vars.System = m.System + vars.Prompt = request.Prompt + + var sb strings.Builder + if err := tmpl.Execute(&sb, vars); err != nil { + return "", err + } + + return sb.String(), nil +} + type ManifestV2 struct { SchemaVersion int `json:"schemaVersion"` MediaType string `json:"mediaType"` @@ -71,20 +95,19 @@ func GetManifest(mp ModelPath) (*ManifestV2, error) { if err != nil { return nil, err } + if _, err = os.Stat(fp); err != nil && !errors.Is(err, os.ErrNotExist) { return nil, fmt.Errorf("couldn't find model '%s'", mp.GetShortTagname()) } var manifest *ManifestV2 - f, err := os.Open(fp) + bts, err := os.ReadFile(fp) if err != nil { return nil, fmt.Errorf("couldn't open file '%s'", fp) } - decoder := json.NewDecoder(f) - err = decoder.Decode(&manifest) - if err != nil { + if err := json.Unmarshal(bts, &manifest); err != nil { return nil, err } @@ -112,12 +135,20 @@ func GetModel(name string) (*Model, error) { switch layer.MediaType { case "application/vnd.ollama.image.model": model.ModelPath = filename - case "application/vnd.ollama.image.prompt": - data, err := os.ReadFile(filename) + case "application/vnd.ollama.image.template": + bts, err := os.ReadFile(filename) if err != nil { return nil, err } - model.Prompt = string(data) + + model.Template = string(bts) + case "application/vnd.ollama.image.system": + bts, err := os.ReadFile(filename) + if err != nil { + return nil, err + } + + model.System = string(bts) case "application/vnd.ollama.image.params": params, err := os.Open(filename) if err != nil { @@ -156,13 +187,13 @@ func CreateModel(name string, path string, fn func(status string)) error { params := make(map[string]string) for _, c := range commands { - log.Printf("[%s] - %s\n", c.Name, c.Arg) + log.Printf("[%s] - %s\n", c.Name, c.Args) switch c.Name { case "model": fn("looking for model") - mf, err := GetManifest(ParseModelPath(c.Arg)) + mf, err := GetManifest(ParseModelPath(c.Args)) if err != nil { - fp := c.Arg + fp := c.Args // If filePath starts with ~/, replace it with the user's home directory. if strings.HasPrefix(fp, "~/") { @@ -183,7 +214,7 @@ func CreateModel(name string, path string, fn func(status string)) error { fn("creating model layer") file, err := os.Open(fp) if err != nil { - fn(fmt.Sprintf("couldn't find model '%s'", c.Arg)) + fn(fmt.Sprintf("couldn't find model '%s'", c.Args)) return fmt.Errorf("failed to open file: %v", err) } defer file.Close() @@ -206,31 +237,21 @@ func CreateModel(name string, path string, fn func(status string)) error { layers = append(layers, newLayer) } } - case "prompt": - fn("creating prompt layer") + case "license", "template", "system": + fn(fmt.Sprintf("creating %s layer", c.Name)) // remove the prompt layer if one exists - layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.prompt") + mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) + layers = removeLayerFromLayers(layers, mediaType) - prompt := strings.NewReader(c.Arg) - l, err := CreateLayer(prompt) + layer, err := CreateLayer(strings.NewReader(c.Args)) if err != nil { - fn(fmt.Sprintf("couldn't create prompt layer: %v", err)) - return fmt.Errorf("failed to create layer: %v", err) + return err } - l.MediaType = "application/vnd.ollama.image.prompt" - layers = append(layers, l) - case "license": - fn("creating license layer") - license := strings.NewReader(c.Arg) - l, err := CreateLayer(license) - if err != nil { - fn(fmt.Sprintf("couldn't create license layer: %v", err)) - return fmt.Errorf("failed to create layer: %v", err) - } - l.MediaType = "application/vnd.ollama.image.license" - layers = append(layers, l) + + layer.MediaType = mediaType + layers = append(layers, layer) default: - params[c.Name] = c.Arg + params[c.Name] = c.Args } } diff --git a/server/routes.go b/server/routes.go index 46e05989..f97ab120 100644 --- a/server/routes.go +++ b/server/routes.go @@ -9,7 +9,6 @@ import ( "os" "path/filepath" "strings" - "text/template" "time" "dario.cat/mergo" @@ -54,19 +53,12 @@ func generate(c *gin.Context) { return } - templ, err := template.New("").Parse(model.Prompt) + prompt, err := model.Prompt(req) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - var sb strings.Builder - if err = templ.Execute(&sb, req); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - req.Prompt = sb.String() - llm, err := llama.New(model.ModelPath, opts) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -77,7 +69,7 @@ func generate(c *gin.Context) { ch := make(chan any) go func() { defer close(ch) - llm.Predict(req.Context, req.Prompt, func(r api.GenerateResponse) { + llm.Predict(req.Context, prompt, func(r api.GenerateResponse) { r.Model = req.Model r.CreatedAt = time.Now().UTC() if r.Done {