cmd: spinner progress for transfer model data (#6100)

This commit is contained in:
Josh 2024-08-12 11:46:32 -07:00 committed by GitHub
parent 980dd15f81
commit f7e3b9190f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 52 additions and 7 deletions

View file

@ -22,6 +22,7 @@ import (
"runtime" "runtime"
"slices" "slices"
"strings" "strings"
"sync/atomic"
"syscall" "syscall"
"time" "time"
@ -78,6 +79,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
status := "transferring model data" status := "transferring model data"
spinner := progress.NewSpinner(status) spinner := progress.NewSpinner(status)
p.Add(status, spinner) p.Add(status, spinner)
defer p.Stop()
for i := range modelfile.Commands { for i := range modelfile.Commands {
switch modelfile.Commands[i].Name { switch modelfile.Commands[i].Name {
@ -112,7 +114,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
path = tempfile path = tempfile
} }
digest, err := createBlob(cmd, client, path) digest, err := createBlob(cmd, client, path, spinner)
if err != nil { if err != nil {
return err return err
} }
@ -263,13 +265,20 @@ func tempZipFiles(path string) (string, error) {
return tempfile.Name(), nil return tempfile.Name(), nil
} }
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) { func createBlob(cmd *cobra.Command, client *api.Client, path string, spinner *progress.Spinner) (string, error) {
bin, err := os.Open(path) bin, err := os.Open(path)
if err != nil { if err != nil {
return "", err return "", err
} }
defer bin.Close() defer bin.Close()
// Get file info to retrieve the size
fileInfo, err := bin.Stat()
if err != nil {
return "", err
}
fileSize := fileInfo.Size()
hash := sha256.New() hash := sha256.New()
if _, err := io.Copy(hash, bin); err != nil { if _, err := io.Copy(hash, bin); err != nil {
return "", err return "", err
@ -279,13 +288,43 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
return "", err return "", err
} }
var pw progressWriter
status := "transferring model data 0%"
spinner.SetMessage(status)
done := make(chan struct{})
defer close(done)
go func() {
ticker := time.NewTicker(60 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ticker.C:
spinner.SetMessage(fmt.Sprintf("transferring model data %d%%", int(100*pw.n.Load()/fileSize)))
case <-done:
spinner.SetMessage("transferring model data 100%")
return
}
}
}()
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil)) digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil { if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
return "", err return "", err
} }
return digest, nil return digest, nil
} }
type progressWriter struct {
n atomic.Int64
}
func (w *progressWriter) Write(p []byte) (n int, err error) {
w.n.Add(int64(len(p)))
return len(p), nil
}
func RunHandler(cmd *cobra.Command, args []string) error { func RunHandler(cmd *cobra.Command, args []string) error {
interactive := true interactive := true

View file

@ -3,11 +3,12 @@ package progress
import ( import (
"fmt" "fmt"
"strings" "strings"
"sync/atomic"
"time" "time"
) )
type Spinner struct { type Spinner struct {
message string message atomic.Value
messageWidth int messageWidth int
parts []string parts []string
@ -21,20 +22,25 @@ type Spinner struct {
func NewSpinner(message string) *Spinner { func NewSpinner(message string) *Spinner {
s := &Spinner{ s := &Spinner{
message: message,
parts: []string{ parts: []string{
"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏", "⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏",
}, },
started: time.Now(), started: time.Now(),
} }
s.SetMessage(message)
go s.start() go s.start()
return s return s
} }
func (s *Spinner) SetMessage(message string) {
s.message.Store(message)
}
func (s *Spinner) String() string { func (s *Spinner) String() string {
var sb strings.Builder var sb strings.Builder
if len(s.message) > 0 {
message := strings.TrimSpace(s.message) if message, ok := s.message.Load().(string); ok && len(message) > 0 {
message := strings.TrimSpace(message)
if s.messageWidth > 0 && len(message) > s.messageWidth { if s.messageWidth > 0 && len(message) > s.messageWidth {
message = message[:s.messageWidth] message = message[:s.messageWidth]
} }