Convert the REPL to use /api/chat for interactive responses (#1936)

This commit is contained in:
Patrick Devine 2024-01-12 12:05:52 -08:00 committed by GitHub
parent 40a0a90a88
commit 565f8a3c44
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 155 additions and 72 deletions

View file

@ -35,8 +35,6 @@ import (
"github.com/jmorganca/ollama/version" "github.com/jmorganca/ollama/version"
) )
type ImageData []byte
func CreateHandler(cmd *cobra.Command, args []string) error { func CreateHandler(cmd *cobra.Command, args []string) error {
filename, _ := cmd.Flags().GetString("file") filename, _ := cmd.Flags().GetString("file")
filename, err := filepath.Abs(filename) filename, err := filepath.Abs(filename)
@ -415,11 +413,10 @@ func PullHandler(cmd *cobra.Command, args []string) error {
func RunGenerate(cmd *cobra.Command, args []string) error { func RunGenerate(cmd *cobra.Command, args []string) error {
interactive := true interactive := true
opts := generateOptions{ opts := runOptions{
Model: args[0], Model: args[0],
WordWrap: os.Getenv("TERM") == "xterm-256color", WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]interface{}{}, Options: map[string]interface{}{},
Images: []ImageData{},
} }
format, err := cmd.Flags().GetString("format") format, err := cmd.Flags().GetString("format")
@ -460,18 +457,135 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
type generateContextKey string type generateContextKey string
type generateOptions struct { type runOptions struct {
Model string Model string
Prompt string Prompt string
Messages []api.Message
WordWrap bool WordWrap bool
Format string Format string
System string System string
Template string Template string
Images []ImageData Images []api.ImageData
Options map[string]interface{} Options map[string]interface{}
} }
func generate(cmd *cobra.Command, opts generateOptions) error { type displayResponseState struct {
lineLength int
wordBuffer string
}
func displayResponse(content string, wordWrap bool, state *displayResponseState) {
termWidth, _, _ := term.GetSize(int(os.Stdout.Fd()))
if wordWrap && termWidth >= 10 {
for _, ch := range content {
if state.lineLength+1 > termWidth-5 {
if len(state.wordBuffer) > termWidth-10 {
fmt.Printf("%s%c", state.wordBuffer, ch)
state.wordBuffer = ""
state.lineLength = 0
continue
}
// backtrack the length of the last word and clear to the end of the line
fmt.Printf("\x1b[%dD\x1b[K\n", len(state.wordBuffer))
fmt.Printf("%s%c", state.wordBuffer, ch)
state.lineLength = len(state.wordBuffer) + 1
} else {
fmt.Print(string(ch))
state.lineLength += 1
switch ch {
case ' ':
state.wordBuffer = ""
case '\n':
state.lineLength = 0
default:
state.wordBuffer += string(ch)
}
}
}
} else {
fmt.Printf("%s%s", state.wordBuffer, content)
if len(state.wordBuffer) > 0 {
state.wordBuffer = ""
}
}
}
func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
client, err := api.ClientFromEnvironment()
if err != nil {
return nil, err
}
p := progress.NewProgress(os.Stderr)
defer p.StopAndClear()
spinner := progress.NewSpinner("")
p.Add("", spinner)
cancelCtx, cancel := context.WithCancel(cmd.Context())
defer cancel()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT)
go func() {
<-sigChan
cancel()
}()
var state *displayResponseState = &displayResponseState{}
var latest api.ChatResponse
var fullResponse strings.Builder
var role string
fn := func(response api.ChatResponse) error {
p.StopAndClear()
latest = response
role = response.Message.Role
content := response.Message.Content
fullResponse.WriteString(content)
displayResponse(content, opts.WordWrap, state)
return nil
}
req := &api.ChatRequest{
Model: opts.Model,
Messages: opts.Messages,
Format: opts.Format,
Options: opts.Options,
}
if err := client.Chat(cancelCtx, req, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil, nil
}
return nil, err
}
if len(opts.Messages) > 0 {
fmt.Println()
fmt.Println()
}
verbose, err := cmd.Flags().GetBool("verbose")
if err != nil {
return nil, err
}
if verbose {
latest.Summary()
}
return &api.Message{Role: role, Content: fullResponse.String()}, nil
}
func generate(cmd *cobra.Command, opts runOptions) error {
client, err := api.ClientFromEnvironment() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
return err return err
@ -490,11 +604,6 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
generateContext = []int{} generateContext = []int{}
} }
termWidth, _, err := term.GetSize(int(os.Stdout.Fd()))
if err != nil {
opts.WordWrap = false
}
ctx, cancel := context.WithCancel(cmd.Context()) ctx, cancel := context.WithCancel(cmd.Context())
defer cancel() defer cancel()
@ -506,57 +615,19 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
cancel() cancel()
}() }()
var currentLineLength int var state *displayResponseState = &displayResponseState{}
var wordBuffer string
fn := func(response api.GenerateResponse) error { fn := func(response api.GenerateResponse) error {
p.StopAndClear() p.StopAndClear()
latest = response latest = response
content := response.Response
termWidth, _, _ = term.GetSize(int(os.Stdout.Fd())) displayResponse(content, opts.WordWrap, state)
if opts.WordWrap && termWidth >= 10 {
for _, ch := range response.Response {
if currentLineLength+1 > termWidth-5 {
if len(wordBuffer) > termWidth-10 {
fmt.Printf("%s%c", wordBuffer, ch)
wordBuffer = ""
currentLineLength = 0
continue
}
// backtrack the length of the last word and clear to the end of the line
fmt.Printf("\x1b[%dD\x1b[K\n", len(wordBuffer))
fmt.Printf("%s%c", wordBuffer, ch)
currentLineLength = len(wordBuffer) + 1
} else {
fmt.Print(string(ch))
currentLineLength += 1
switch ch {
case ' ':
wordBuffer = ""
case '\n':
currentLineLength = 0
default:
wordBuffer += string(ch)
}
}
}
} else {
fmt.Printf("%s%s", wordBuffer, response.Response)
if len(wordBuffer) > 0 {
wordBuffer = ""
}
}
return nil return nil
} }
images := make([]api.ImageData, 0)
for _, i := range opts.Images {
images = append(images, api.ImageData(i))
}
request := api.GenerateRequest{ request := api.GenerateRequest{
Model: opts.Model, Model: opts.Model,
Prompt: opts.Prompt, Prompt: opts.Prompt,
@ -565,7 +636,6 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
System: opts.System, System: opts.System,
Template: opts.Template, Template: opts.Template,
Options: opts.Options, Options: opts.Options,
Images: images,
} }
if err := client.Generate(ctx, &request, fn); err != nil { if err := client.Generate(ctx, &request, fn); err != nil {

View file

@ -1,7 +1,6 @@
package cmd package cmd
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -43,16 +42,16 @@ func modelIsMultiModal(cmd *cobra.Command, name string) bool {
return slices.Contains(resp.Details.Families, "clip") return slices.Contains(resp.Details.Families, "clip")
} }
func generateInteractive(cmd *cobra.Command, opts generateOptions) error { func generateInteractive(cmd *cobra.Command, opts runOptions) error {
multiModal := modelIsMultiModal(cmd, opts.Model) multiModal := modelIsMultiModal(cmd, opts.Model)
// load the model // load the model
loadOpts := generateOptions{ loadOpts := runOptions{
Model: opts.Model, Model: opts.Model,
Prompt: "", Prompt: "",
Images: []ImageData{}, Messages: []api.Message{},
} }
if err := generate(cmd, loadOpts); err != nil { if _, err := chat(cmd, loadOpts); err != nil {
return err return err
} }
@ -141,6 +140,7 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
var sb strings.Builder var sb strings.Builder
var multiline MultilineState var multiline MultilineState
opts.Messages = make([]api.Message, 0)
for { for {
line, err := scanner.Readline() line, err := scanner.Readline()
@ -409,22 +409,26 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
} }
if sb.Len() > 0 && multiline == MultilineNone { if sb.Len() > 0 && multiline == MultilineNone {
opts.Prompt = sb.String() newMessage := api.Message{Role: "user", Content: sb.String()}
if multiModal { if multiModal {
newPrompt, images, err := extractFileData(sb.String()) msg, images, err := extractFileData(sb.String())
if err != nil { if err != nil {
return err return err
} }
opts.Prompt = newPrompt newMessage.Content = msg
// reset the context if we find another image // reset the context if we find another image
if len(images) > 0 { if len(images) > 0 {
opts.Images = images newMessage.Images = append(newMessage.Images, images...)
ctx := cmd.Context() // reset the context for the new image
ctx = context.WithValue(ctx, generateContextKey("context"), []int{}) opts.Messages = []api.Message{}
cmd.SetContext(ctx) } else {
if len(opts.Messages) > 1 {
newMessage.Images = append(newMessage.Images, opts.Messages[len(opts.Messages)-2].Images...)
}
} }
if len(opts.Images) == 0 { if len(newMessage.Images) == 0 {
fmt.Println("This model requires you to add a jpeg, png, or svg image.") fmt.Println("This model requires you to add a jpeg, png, or svg image.")
fmt.Println() fmt.Println()
sb.Reset() sb.Reset()
@ -432,9 +436,18 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
} }
} }
if err := generate(cmd, opts); err != nil { if opts.System != "" {
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
}
opts.Messages = append(opts.Messages, newMessage)
assistant, err := chat(cmd, opts)
if err != nil {
return err return err
} }
if assistant != nil {
opts.Messages = append(opts.Messages, *assistant)
}
sb.Reset() sb.Reset()
} }
@ -476,9 +489,9 @@ func extractFileNames(input string) []string {
return re.FindAllString(input, -1) return re.FindAllString(input, -1)
} }
func extractFileData(input string) (string, []ImageData, error) { func extractFileData(input string) (string, []api.ImageData, error) {
filePaths := extractFileNames(input) filePaths := extractFileNames(input)
var imgs []ImageData var imgs []api.ImageData
for _, fp := range filePaths { for _, fp := range filePaths {
nfp := normalizeFilePath(fp) nfp := normalizeFilePath(fp)