diff --git a/progress/bar.go b/progress/bar.go new file mode 100644 index 00000000..9a81fd0f --- /dev/null +++ b/progress/bar.go @@ -0,0 +1,118 @@ +package progress + +import ( + "fmt" + "os" + "strings" + "time" + + "github.com/jmorganca/ollama/format" + "golang.org/x/term" +) + +type Bar struct { + message string + messageWidth int + + maxValue int64 + initialValue int64 + currentValue int64 + + started time.Time + stopped time.Time +} + +func NewBar(message string, maxValue, initialValue int64) *Bar { + return &Bar{ + message: message, + messageWidth: -1, + maxValue: maxValue, + initialValue: initialValue, + currentValue: initialValue, + started: time.Now(), + } +} + +func (b *Bar) String() string { + termWidth, _, err := term.GetSize(int(os.Stderr.Fd())) + if err != nil { + panic(err) + } + + var pre, mid, suf strings.Builder + + if b.message != "" { + message := strings.TrimSpace(b.message) + if b.messageWidth > 0 && len(message) > b.messageWidth { + message = message[:b.messageWidth] + } + + fmt.Fprintf(&pre, "%s", message) + if b.messageWidth-pre.Len() >= 0 { + pre.WriteString(strings.Repeat(" ", b.messageWidth-pre.Len())) + } + + pre.WriteString(" ") + } + + fmt.Fprintf(&pre, "%.1f%% ", b.percent()) + + fmt.Fprintf(&suf, "(%s/%s, %s/s, %s)", + format.HumanBytes(b.currentValue), + format.HumanBytes(b.maxValue), + format.HumanBytes(int64(b.rate())), + b.elapsed()) + + mid.WriteString("[") + + // pad 3 for last = or > and "] " + f := termWidth - pre.Len() - mid.Len() - suf.Len() - 3 + n := int(float64(f) * b.percent() / 100) + if n > 0 { + mid.WriteString(strings.Repeat("=", n)) + } + + if b.currentValue >= b.maxValue { + mid.WriteString("=") + } else { + mid.WriteString(">") + } + + if f-n > 0 { + mid.WriteString(strings.Repeat(" ", f-n)) + } + + mid.WriteString("] ") + + return pre.String() + mid.String() + suf.String() +} + +func (b *Bar) Set(value int64) { + if value >= b.maxValue { + value = b.maxValue + b.stopped = time.Now() + } + + b.currentValue = value +} + +func (b *Bar) percent() float64 { + if b.maxValue > 0 { + return float64(b.currentValue) / float64(b.maxValue) * 100 + } + + return 0 +} + +func (b *Bar) rate() float64 { + return (float64(b.currentValue) - float64(b.initialValue)) / b.elapsed().Seconds() +} + +func (b *Bar) elapsed() time.Duration { + stopped := b.stopped + if stopped.IsZero() { + stopped = time.Now() + } + + return stopped.Sub(b.started).Round(time.Second) +} diff --git a/progress/progress.go b/progress/progress.go new file mode 100644 index 00000000..5002b8d2 --- /dev/null +++ b/progress/progress.go @@ -0,0 +1,65 @@ +package progress + +import ( + "fmt" + "io" + "sync" + "time" +) + +type State interface { + String() string +} + +type Progress struct { + mu sync.Mutex + pos int + w io.Writer + + ticker *time.Ticker + states []State +} + +func NewProgress(w io.Writer) *Progress { + p := &Progress{pos: -1, w: w} + go p.start() + return p +} + +func (p *Progress) Stop() { + if p.ticker != nil { + p.ticker.Stop() + p.ticker = nil + p.render() + } +} + +func (p *Progress) Add(key string, state State) { + p.mu.Lock() + defer p.mu.Unlock() + + p.states = append(p.states, state) +} + +func (p *Progress) render() error { + p.mu.Lock() + defer p.mu.Unlock() + + fmt.Fprintf(p.w, "\033[%dA", p.pos) + for _, state := range p.states { + fmt.Fprintln(p.w, state.String()) + } + + if len(p.states) > 0 { + p.pos = len(p.states) + } + + return nil +} + +func (p *Progress) start() { + p.ticker = time.NewTicker(100 * time.Millisecond) + for range p.ticker.C { + p.render() + } +} diff --git a/progress/spinner.go b/progress/spinner.go new file mode 100644 index 00000000..bc46bc02 --- /dev/null +++ b/progress/spinner.go @@ -0,0 +1,102 @@ +package progress + +import ( + "fmt" + "os" + "strings" + "time" + + "golang.org/x/term" +) + +type Spinner struct { + message string + messageWidth int + + parts []string + + value int + + ticker *time.Ticker + started time.Time + stopped time.Time +} + +func NewSpinner(message string) *Spinner { + s := &Spinner{ + message: message, + parts: []string{ + "⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏", + }, + started: time.Now(), + } + go s.start() + return s +} + +func (s *Spinner) String() string { + termWidth, _, err := term.GetSize(int(os.Stderr.Fd())) + if err != nil { + panic(err) + } + + var pre strings.Builder + if len(s.message) > 0 { + message := strings.TrimSpace(s.message) + if s.messageWidth > 0 && len(message) > s.messageWidth { + message = message[:s.messageWidth] + } + + fmt.Fprintf(&pre, "%s", message) + if s.messageWidth-pre.Len() >= 0 { + pre.WriteString(strings.Repeat(" ", s.messageWidth-pre.Len())) + } + + pre.WriteString(" ") + } + + var pad int + if s.stopped.IsZero() { + // spinner has a string length of 3 but a rune length of 1 + // in order to align correctly, we need to pad with (3 - 1) = 2 spaces + spinner := s.parts[s.value] + pre.WriteString(spinner) + pad = len(spinner) - len([]rune(spinner)) + } + + var suf strings.Builder + fmt.Fprintf(&suf, "(%s)", s.elapsed()) + + var mid strings.Builder + f := termWidth - pre.Len() - mid.Len() - suf.Len() + pad + if f > 0 { + mid.WriteString(strings.Repeat(" ", f)) + } + + return pre.String() + mid.String() + suf.String() +} + +func (s *Spinner) start() { + s.ticker = time.NewTicker(100 * time.Millisecond) + for range s.ticker.C { + s.value = (s.value + 1) % len(s.parts) + if !s.stopped.IsZero() { + return + } + } +} + +func (s *Spinner) Stop() { + if s.stopped.IsZero() { + s.stopped = time.Now() + } +} + +func (s *Spinner) elapsed() time.Duration { + stopped := s.stopped + if stopped.IsZero() { + stopped = time.Now() + } + + return stopped.Sub(s.started).Round(time.Second) +}