Convert the REPL to use /api/chat for interactive responses (#1936)
This commit is contained in:
parent
40a0a90a88
commit
565f8a3c44
2 changed files with 155 additions and 72 deletions
178
cmd/cmd.go
178
cmd/cmd.go
|
@ -35,8 +35,6 @@ import (
|
||||||
"github.com/jmorganca/ollama/version"
|
"github.com/jmorganca/ollama/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ImageData []byte
|
|
||||||
|
|
||||||
func CreateHandler(cmd *cobra.Command, args []string) error {
|
func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||||
filename, _ := cmd.Flags().GetString("file")
|
filename, _ := cmd.Flags().GetString("file")
|
||||||
filename, err := filepath.Abs(filename)
|
filename, err := filepath.Abs(filename)
|
||||||
|
@ -415,11 +413,10 @@ func PullHandler(cmd *cobra.Command, args []string) error {
|
||||||
func RunGenerate(cmd *cobra.Command, args []string) error {
|
func RunGenerate(cmd *cobra.Command, args []string) error {
|
||||||
interactive := true
|
interactive := true
|
||||||
|
|
||||||
opts := generateOptions{
|
opts := runOptions{
|
||||||
Model: args[0],
|
Model: args[0],
|
||||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||||
Options: map[string]interface{}{},
|
Options: map[string]interface{}{},
|
||||||
Images: []ImageData{},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
format, err := cmd.Flags().GetString("format")
|
format, err := cmd.Flags().GetString("format")
|
||||||
|
@ -460,18 +457,135 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
|
||||||
|
|
||||||
type generateContextKey string
|
type generateContextKey string
|
||||||
|
|
||||||
type generateOptions struct {
|
type runOptions struct {
|
||||||
Model string
|
Model string
|
||||||
Prompt string
|
Prompt string
|
||||||
|
Messages []api.Message
|
||||||
WordWrap bool
|
WordWrap bool
|
||||||
Format string
|
Format string
|
||||||
System string
|
System string
|
||||||
Template string
|
Template string
|
||||||
Images []ImageData
|
Images []api.ImageData
|
||||||
Options map[string]interface{}
|
Options map[string]interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func generate(cmd *cobra.Command, opts generateOptions) error {
|
type displayResponseState struct {
|
||||||
|
lineLength int
|
||||||
|
wordBuffer string
|
||||||
|
}
|
||||||
|
|
||||||
|
func displayResponse(content string, wordWrap bool, state *displayResponseState) {
|
||||||
|
termWidth, _, _ := term.GetSize(int(os.Stdout.Fd()))
|
||||||
|
if wordWrap && termWidth >= 10 {
|
||||||
|
for _, ch := range content {
|
||||||
|
if state.lineLength+1 > termWidth-5 {
|
||||||
|
if len(state.wordBuffer) > termWidth-10 {
|
||||||
|
fmt.Printf("%s%c", state.wordBuffer, ch)
|
||||||
|
state.wordBuffer = ""
|
||||||
|
state.lineLength = 0
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// backtrack the length of the last word and clear to the end of the line
|
||||||
|
fmt.Printf("\x1b[%dD\x1b[K\n", len(state.wordBuffer))
|
||||||
|
fmt.Printf("%s%c", state.wordBuffer, ch)
|
||||||
|
state.lineLength = len(state.wordBuffer) + 1
|
||||||
|
} else {
|
||||||
|
fmt.Print(string(ch))
|
||||||
|
state.lineLength += 1
|
||||||
|
|
||||||
|
switch ch {
|
||||||
|
case ' ':
|
||||||
|
state.wordBuffer = ""
|
||||||
|
case '\n':
|
||||||
|
state.lineLength = 0
|
||||||
|
default:
|
||||||
|
state.wordBuffer += string(ch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fmt.Printf("%s%s", state.wordBuffer, content)
|
||||||
|
if len(state.wordBuffer) > 0 {
|
||||||
|
state.wordBuffer = ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||||
|
client, err := api.ClientFromEnvironment()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
p := progress.NewProgress(os.Stderr)
|
||||||
|
defer p.StopAndClear()
|
||||||
|
|
||||||
|
spinner := progress.NewSpinner("")
|
||||||
|
p.Add("", spinner)
|
||||||
|
|
||||||
|
cancelCtx, cancel := context.WithCancel(cmd.Context())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
sigChan := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigChan, syscall.SIGINT)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-sigChan
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
var state *displayResponseState = &displayResponseState{}
|
||||||
|
var latest api.ChatResponse
|
||||||
|
var fullResponse strings.Builder
|
||||||
|
var role string
|
||||||
|
|
||||||
|
fn := func(response api.ChatResponse) error {
|
||||||
|
p.StopAndClear()
|
||||||
|
|
||||||
|
latest = response
|
||||||
|
|
||||||
|
role = response.Message.Role
|
||||||
|
content := response.Message.Content
|
||||||
|
fullResponse.WriteString(content)
|
||||||
|
|
||||||
|
displayResponse(content, opts.WordWrap, state)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &api.ChatRequest{
|
||||||
|
Model: opts.Model,
|
||||||
|
Messages: opts.Messages,
|
||||||
|
Format: opts.Format,
|
||||||
|
Options: opts.Options,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := client.Chat(cancelCtx, req, fn); err != nil {
|
||||||
|
if errors.Is(err, context.Canceled) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(opts.Messages) > 0 {
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Println()
|
||||||
|
}
|
||||||
|
|
||||||
|
verbose, err := cmd.Flags().GetBool("verbose")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if verbose {
|
||||||
|
latest.Summary()
|
||||||
|
}
|
||||||
|
|
||||||
|
return &api.Message{Role: role, Content: fullResponse.String()}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generate(cmd *cobra.Command, opts runOptions) error {
|
||||||
client, err := api.ClientFromEnvironment()
|
client, err := api.ClientFromEnvironment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -490,11 +604,6 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
|
||||||
generateContext = []int{}
|
generateContext = []int{}
|
||||||
}
|
}
|
||||||
|
|
||||||
termWidth, _, err := term.GetSize(int(os.Stdout.Fd()))
|
|
||||||
if err != nil {
|
|
||||||
opts.WordWrap = false
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(cmd.Context())
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
@ -506,57 +615,19 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
|
||||||
cancel()
|
cancel()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var currentLineLength int
|
var state *displayResponseState = &displayResponseState{}
|
||||||
var wordBuffer string
|
|
||||||
|
|
||||||
fn := func(response api.GenerateResponse) error {
|
fn := func(response api.GenerateResponse) error {
|
||||||
p.StopAndClear()
|
p.StopAndClear()
|
||||||
|
|
||||||
latest = response
|
latest = response
|
||||||
|
content := response.Response
|
||||||
|
|
||||||
termWidth, _, _ = term.GetSize(int(os.Stdout.Fd()))
|
displayResponse(content, opts.WordWrap, state)
|
||||||
if opts.WordWrap && termWidth >= 10 {
|
|
||||||
for _, ch := range response.Response {
|
|
||||||
if currentLineLength+1 > termWidth-5 {
|
|
||||||
if len(wordBuffer) > termWidth-10 {
|
|
||||||
fmt.Printf("%s%c", wordBuffer, ch)
|
|
||||||
wordBuffer = ""
|
|
||||||
currentLineLength = 0
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// backtrack the length of the last word and clear to the end of the line
|
|
||||||
fmt.Printf("\x1b[%dD\x1b[K\n", len(wordBuffer))
|
|
||||||
fmt.Printf("%s%c", wordBuffer, ch)
|
|
||||||
currentLineLength = len(wordBuffer) + 1
|
|
||||||
} else {
|
|
||||||
fmt.Print(string(ch))
|
|
||||||
currentLineLength += 1
|
|
||||||
|
|
||||||
switch ch {
|
|
||||||
case ' ':
|
|
||||||
wordBuffer = ""
|
|
||||||
case '\n':
|
|
||||||
currentLineLength = 0
|
|
||||||
default:
|
|
||||||
wordBuffer += string(ch)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
fmt.Printf("%s%s", wordBuffer, response.Response)
|
|
||||||
if len(wordBuffer) > 0 {
|
|
||||||
wordBuffer = ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
images := make([]api.ImageData, 0)
|
|
||||||
for _, i := range opts.Images {
|
|
||||||
images = append(images, api.ImageData(i))
|
|
||||||
}
|
|
||||||
request := api.GenerateRequest{
|
request := api.GenerateRequest{
|
||||||
Model: opts.Model,
|
Model: opts.Model,
|
||||||
Prompt: opts.Prompt,
|
Prompt: opts.Prompt,
|
||||||
|
@ -565,7 +636,6 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
|
||||||
System: opts.System,
|
System: opts.System,
|
||||||
Template: opts.Template,
|
Template: opts.Template,
|
||||||
Options: opts.Options,
|
Options: opts.Options,
|
||||||
Images: images,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := client.Generate(ctx, &request, fn); err != nil {
|
if err := client.Generate(ctx, &request, fn); err != nil {
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
@ -43,16 +42,16 @@ func modelIsMultiModal(cmd *cobra.Command, name string) bool {
|
||||||
return slices.Contains(resp.Details.Families, "clip")
|
return slices.Contains(resp.Details.Families, "clip")
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
|
func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||||
multiModal := modelIsMultiModal(cmd, opts.Model)
|
multiModal := modelIsMultiModal(cmd, opts.Model)
|
||||||
|
|
||||||
// load the model
|
// load the model
|
||||||
loadOpts := generateOptions{
|
loadOpts := runOptions{
|
||||||
Model: opts.Model,
|
Model: opts.Model,
|
||||||
Prompt: "",
|
Prompt: "",
|
||||||
Images: []ImageData{},
|
Messages: []api.Message{},
|
||||||
}
|
}
|
||||||
if err := generate(cmd, loadOpts); err != nil {
|
if _, err := chat(cmd, loadOpts); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -141,6 +140,7 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
|
||||||
|
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
var multiline MultilineState
|
var multiline MultilineState
|
||||||
|
opts.Messages = make([]api.Message, 0)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
line, err := scanner.Readline()
|
line, err := scanner.Readline()
|
||||||
|
@ -409,22 +409,26 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if sb.Len() > 0 && multiline == MultilineNone {
|
if sb.Len() > 0 && multiline == MultilineNone {
|
||||||
opts.Prompt = sb.String()
|
newMessage := api.Message{Role: "user", Content: sb.String()}
|
||||||
|
|
||||||
if multiModal {
|
if multiModal {
|
||||||
newPrompt, images, err := extractFileData(sb.String())
|
msg, images, err := extractFileData(sb.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
opts.Prompt = newPrompt
|
newMessage.Content = msg
|
||||||
|
|
||||||
// reset the context if we find another image
|
// reset the context if we find another image
|
||||||
if len(images) > 0 {
|
if len(images) > 0 {
|
||||||
opts.Images = images
|
newMessage.Images = append(newMessage.Images, images...)
|
||||||
ctx := cmd.Context()
|
// reset the context for the new image
|
||||||
ctx = context.WithValue(ctx, generateContextKey("context"), []int{})
|
opts.Messages = []api.Message{}
|
||||||
cmd.SetContext(ctx)
|
} else {
|
||||||
|
if len(opts.Messages) > 1 {
|
||||||
|
newMessage.Images = append(newMessage.Images, opts.Messages[len(opts.Messages)-2].Images...)
|
||||||
}
|
}
|
||||||
if len(opts.Images) == 0 {
|
}
|
||||||
|
if len(newMessage.Images) == 0 {
|
||||||
fmt.Println("This model requires you to add a jpeg, png, or svg image.")
|
fmt.Println("This model requires you to add a jpeg, png, or svg image.")
|
||||||
fmt.Println()
|
fmt.Println()
|
||||||
sb.Reset()
|
sb.Reset()
|
||||||
|
@ -432,9 +436,18 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := generate(cmd, opts); err != nil {
|
if opts.System != "" {
|
||||||
|
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
|
||||||
|
}
|
||||||
|
opts.Messages = append(opts.Messages, newMessage)
|
||||||
|
|
||||||
|
assistant, err := chat(cmd, opts)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if assistant != nil {
|
||||||
|
opts.Messages = append(opts.Messages, *assistant)
|
||||||
|
}
|
||||||
|
|
||||||
sb.Reset()
|
sb.Reset()
|
||||||
}
|
}
|
||||||
|
@ -476,9 +489,9 @@ func extractFileNames(input string) []string {
|
||||||
return re.FindAllString(input, -1)
|
return re.FindAllString(input, -1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractFileData(input string) (string, []ImageData, error) {
|
func extractFileData(input string) (string, []api.ImageData, error) {
|
||||||
filePaths := extractFileNames(input)
|
filePaths := extractFileNames(input)
|
||||||
var imgs []ImageData
|
var imgs []api.ImageData
|
||||||
|
|
||||||
for _, fp := range filePaths {
|
for _, fp := range filePaths {
|
||||||
nfp := normalizeFilePath(fp)
|
nfp := normalizeFilePath(fp)
|
||||||
|
|
Loading…
Reference in a new issue