allow the user to cancel generating with ctrl-C (#641)

This commit is contained in:
Patrick Devine 2023-09-28 17:13:01 -07:00 committed by GitHub
parent 4aa0976a2e
commit 76db4a49cf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -13,9 +13,11 @@ import (
"net" "net"
"os" "os"
"os/exec" "os/exec"
"os/signal"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strings" "strings"
"syscall"
"time" "time"
"github.com/dustin/go-humanize" "github.com/dustin/go-humanize"
@ -43,7 +45,7 @@ func (p Painter) Paint(line []rune, _ int) []rune {
if p.IsMultiLine { if p.IsMultiLine {
prompt = "Use \"\"\" to end multi-line input" prompt = "Use \"\"\" to end multi-line input"
} else { } else {
prompt = "Send a message (/? for help, /bye to exit)" prompt = "Send a message (/? for help)"
} }
return []rune(fmt.Sprintf("\033[38;5;245m%s\033[%dD\033[0m", prompt, len(prompt))) return []rune(fmt.Sprintf("\033[38;5;245m%s\033[%dD\033[0m", prompt, len(prompt)))
} }
@ -426,6 +428,19 @@ func generate(cmd *cobra.Command, model, prompt string) error {
wrapTerm = false wrapTerm = false
} }
cancelCtx, cancel := context.WithCancel(context.Background())
defer cancel()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT)
var abort bool
go func() {
<-sigChan
cancel()
abort = true
}()
var currentLineLength int var currentLineLength int
var wordBuffer string var wordBuffer string
@ -465,7 +480,7 @@ func generate(cmd *cobra.Command, model, prompt string) error {
return nil return nil
} }
if err := client.Generate(context.Background(), &request, fn); err != nil { if err := client.Generate(cancelCtx, &request, fn); err != nil {
if strings.Contains(err.Error(), "failed to load model") { if strings.Contains(err.Error(), "failed to load model") {
// tell the user to check the server log, if it exists locally // tell the user to check the server log, if it exists locally
home, nestedErr := os.UserHomeDir() home, nestedErr := os.UserHomeDir()
@ -477,6 +492,9 @@ func generate(cmd *cobra.Command, model, prompt string) error {
if _, nestedErr := os.Stat(logPath); nestedErr == nil { if _, nestedErr := os.Stat(logPath); nestedErr == nil {
err = fmt.Errorf("%w\nFor more details, check the error logs at %s", err, logPath) err = fmt.Errorf("%w\nFor more details, check the error logs at %s", err, logPath)
} }
} else if strings.Contains(err.Error(), "context canceled") && abort {
spinner.Finish()
return nil
} }
return err return err
} }
@ -486,6 +504,9 @@ func generate(cmd *cobra.Command, model, prompt string) error {
} }
if !latest.Done { if !latest.Done {
if abort {
return nil
}
return errors.New("unexpected end of response") return errors.New("unexpected end of response")
} }
@ -568,7 +589,7 @@ func generateInteractive(cmd *cobra.Command, model string) error {
return nil return nil
case errors.Is(err, readline.ErrInterrupt): case errors.Is(err, readline.ErrInterrupt):
if line == "" { if line == "" {
return nil fmt.Println("Use Ctrl-D or /bye to exit.")
} }
continue continue