progress cmd

This commit is contained in:
Michael Yang 2023-11-14 16:33:24 -08:00
parent c4a3ccd7ac
commit 1c0e092ead

View file

@ -30,6 +30,7 @@ import (
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/format"
"github.com/jmorganca/ollama/parser"
"github.com/jmorganca/ollama/progress"
"github.com/jmorganca/ollama/readline"
"github.com/jmorganca/ollama/server"
"github.com/jmorganca/ollama/version"
@ -47,6 +48,15 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err
}
p := progress.NewProgress(os.Stderr)
defer p.Stop()
bars := make(map[string]*progress.Bar)
status := "transferring context"
spinner := progress.NewSpinner(status)
p.Add(status, spinner)
modelfile, err := os.ReadFile(filename)
if err != nil {
return err
@ -95,16 +105,38 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
}
}
request := api.CreateRequest{Name: args[0], Path: filename, Modelfile: string(modelfile)}
fn := func(resp api.ProgressResponse) error {
log.Printf("progress(%s): %s", resp.Digest, resp.Status)
if resp.Digest != "" {
spinner.Stop()
bar, ok := bars[resp.Digest]
if !ok {
bar = progress.NewBar(resp.Status, resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
bar.Set(resp.Completed)
} else if status != resp.Status {
spinner.Stop()
status = resp.Status
spinner = progress.NewSpinner(status)
p.Add(status, spinner)
}
return nil
}
request := api.CreateRequest{Name: args[0], Path: filename, Modelfile: string(modelfile)}
if err := client.Create(context.Background(), &request, fn); err != nil {
return err
}
if spinner != nil {
spinner.Stop()
}
return nil
}
@ -141,13 +173,53 @@ func PushHandler(cmd *cobra.Command, args []string) error {
return err
}
request := api.PushRequest{Name: args[0], Insecure: insecure}
p := progress.NewProgress(os.Stderr)
defer p.Stop()
bars := make(map[string]*progress.Bar)
var status string
var spinner *progress.Spinner
fn := func(resp api.ProgressResponse) error {
log.Printf("progress(%s): %s", resp.Digest, resp.Status)
if resp.Digest != "" {
if spinner != nil {
spinner.Stop()
spinner = nil
}
bar, ok := bars[resp.Digest]
if !ok {
bar = progress.NewBar(resp.Status, resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
bar.Set(resp.Completed)
} else if status != resp.Status {
if spinner != nil {
spinner.Stop()
spinner = nil
}
status = resp.Status
spinner = progress.NewSpinner(status)
p.Add(status, spinner)
}
return nil
}
return client.Push(context.Background(), &request, fn)
request := api.PushRequest{Name: args[0], Insecure: insecure}
if err := client.Push(context.Background(), &request, fn); err != nil {
return err
}
if spinner != nil {
spinner.Stop()
}
return nil
}
func ListHandler(cmd *cobra.Command, args []string) error {
@ -297,22 +369,58 @@ func PullHandler(cmd *cobra.Command, args []string) error {
return err
}
return pull(args[0], insecure)
}
func pull(model string, insecure bool) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
request := api.PullRequest{Name: model, Insecure: insecure}
p := progress.NewProgress(os.Stderr)
defer p.Stop()
bars := make(map[string]*progress.Bar)
var status string
var spinner *progress.Spinner
fn := func(resp api.ProgressResponse) error {
log.Printf("progress(%s): %s", resp.Digest, resp.Status)
if resp.Digest != "" {
if spinner != nil {
spinner.Stop()
spinner = nil
}
bar, ok := bars[resp.Digest]
if !ok {
bar = progress.NewBar(resp.Status, resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
bar.Set(resp.Completed)
} else if status != resp.Status {
if spinner != nil {
spinner.Stop()
spinner = nil
}
status = resp.Status
spinner = progress.NewSpinner(status)
p.Add(status, spinner)
}
return nil
}
return client.Pull(context.Background(), &request, fn)
request := api.PullRequest{Name: args[0], Insecure: insecure}
if err := client.Pull(context.Background(), &request, fn); err != nil {
return err
}
if spinner != nil {
spinner.Stop()
}
return nil
}
func RunGenerate(cmd *cobra.Command, args []string) error {