From 1c0e092eadc9f56abd745d31ff5c57e91fddd45e Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 14 Nov 2023 16:33:24 -0800 Subject: [PATCH] progress cmd --- cmd/cmd.go | 132 ++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 120 insertions(+), 12 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index 5d31eae2..4df3b004 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -30,6 +30,7 @@ import ( "github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/format" "github.com/jmorganca/ollama/parser" + "github.com/jmorganca/ollama/progress" "github.com/jmorganca/ollama/readline" "github.com/jmorganca/ollama/server" "github.com/jmorganca/ollama/version" @@ -47,6 +48,15 @@ func CreateHandler(cmd *cobra.Command, args []string) error { return err } + p := progress.NewProgress(os.Stderr) + defer p.Stop() + + bars := make(map[string]*progress.Bar) + + status := "transferring context" + spinner := progress.NewSpinner(status) + p.Add(status, spinner) + modelfile, err := os.ReadFile(filename) if err != nil { return err @@ -95,16 +105,38 @@ func CreateHandler(cmd *cobra.Command, args []string) error { } } - request := api.CreateRequest{Name: args[0], Path: filename, Modelfile: string(modelfile)} fn := func(resp api.ProgressResponse) error { - log.Printf("progress(%s): %s", resp.Digest, resp.Status) + if resp.Digest != "" { + spinner.Stop() + + bar, ok := bars[resp.Digest] + if !ok { + bar = progress.NewBar(resp.Status, resp.Total, resp.Completed) + bars[resp.Digest] = bar + p.Add(resp.Digest, bar) + } + + bar.Set(resp.Completed) + } else if status != resp.Status { + spinner.Stop() + + status = resp.Status + spinner = progress.NewSpinner(status) + p.Add(status, spinner) + } + return nil } + request := api.CreateRequest{Name: args[0], Path: filename, Modelfile: string(modelfile)} if err := client.Create(context.Background(), &request, fn); err != nil { return err } + if spinner != nil { + spinner.Stop() + } + return nil } @@ -141,13 +173,53 @@ func PushHandler(cmd *cobra.Command, args []string) error { return err } - request := api.PushRequest{Name: args[0], Insecure: insecure} + p := progress.NewProgress(os.Stderr) + defer p.Stop() + + bars := make(map[string]*progress.Bar) + + var status string + var spinner *progress.Spinner + fn := func(resp api.ProgressResponse) error { - log.Printf("progress(%s): %s", resp.Digest, resp.Status) + if resp.Digest != "" { + if spinner != nil { + spinner.Stop() + spinner = nil + } + + bar, ok := bars[resp.Digest] + if !ok { + bar = progress.NewBar(resp.Status, resp.Total, resp.Completed) + bars[resp.Digest] = bar + p.Add(resp.Digest, bar) + } + + bar.Set(resp.Completed) + } else if status != resp.Status { + if spinner != nil { + spinner.Stop() + spinner = nil + } + + status = resp.Status + spinner = progress.NewSpinner(status) + p.Add(status, spinner) + } + return nil } - return client.Push(context.Background(), &request, fn) + request := api.PushRequest{Name: args[0], Insecure: insecure} + if err := client.Push(context.Background(), &request, fn); err != nil { + return err + } + + if spinner != nil { + spinner.Stop() + } + + return nil } func ListHandler(cmd *cobra.Command, args []string) error { @@ -297,22 +369,58 @@ func PullHandler(cmd *cobra.Command, args []string) error { return err } - return pull(args[0], insecure) -} - -func pull(model string, insecure bool) error { client, err := api.ClientFromEnvironment() if err != nil { return err } - request := api.PullRequest{Name: model, Insecure: insecure} + p := progress.NewProgress(os.Stderr) + defer p.Stop() + + bars := make(map[string]*progress.Bar) + + var status string + var spinner *progress.Spinner + fn := func(resp api.ProgressResponse) error { - log.Printf("progress(%s): %s", resp.Digest, resp.Status) + if resp.Digest != "" { + if spinner != nil { + spinner.Stop() + spinner = nil + } + + bar, ok := bars[resp.Digest] + if !ok { + bar = progress.NewBar(resp.Status, resp.Total, resp.Completed) + bars[resp.Digest] = bar + p.Add(resp.Digest, bar) + } + + bar.Set(resp.Completed) + } else if status != resp.Status { + if spinner != nil { + spinner.Stop() + spinner = nil + } + + status = resp.Status + spinner = progress.NewSpinner(status) + p.Add(status, spinner) + } + return nil } - return client.Pull(context.Background(), &request, fn) + request := api.PullRequest{Name: args[0], Insecure: insecure} + if err := client.Pull(context.Background(), &request, fn); err != nil { + return err + } + + if spinner != nil { + spinner.Stop() + } + + return nil } func RunGenerate(cmd *cobra.Command, args []string) error {