diff --git a/cmd/cmd.go b/cmd/cmd.go index 4df3b004..bc128d74 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -472,6 +472,13 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st return err } + p := progress.NewProgress(os.Stderr) + defer p.Stop() + + spinner := progress.NewSpinner("") + defer spinner.Stop() + p.Add("", spinner) + var latest api.GenerateResponse generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int) @@ -502,6 +509,9 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext, Format: format} fn := func(response api.GenerateResponse) error { + spinner.Stop() + p.StopAndClear() + latest = response if wordWrap { diff --git a/progress/progress.go b/progress/progress.go index 5002b8d2..0aa1e911 100644 --- a/progress/progress.go +++ b/progress/progress.go @@ -3,8 +3,12 @@ package progress import ( "fmt" "io" + "os" + "strings" "sync" "time" + + "golang.org/x/term" ) type State interface { @@ -26,12 +30,34 @@ func NewProgress(w io.Writer) *Progress { return p } -func (p *Progress) Stop() { +func (p *Progress) Stop() bool { if p.ticker != nil { p.ticker.Stop() p.ticker = nil p.render() + return true } + + return false +} + +func (p *Progress) StopAndClear() bool { + stopped := p.Stop() + if stopped { + termWidth, _, err := term.GetSize(int(os.Stderr.Fd())) + if err != nil { + panic(err) + } + + // clear the progress bar by: + // 1. reset to beginning of line + // 2. move up to the first line of the progress bar + // 3. fill the terminal width with spaces + // 4. reset to beginning of line + fmt.Fprintf(p.w, "\r\033[%dA%s\r", p.pos, strings.Repeat(" ", termWidth)) + } + + return stopped } func (p *Progress) Add(key string, state State) {