Merge pull request #92 from jmorganca/create-model-spinner

Create model spinner
This commit is contained in:
Michael Yang 2023-07-18 11:15:45 -07:00 committed by GitHub
commit 885f67a471
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 59 additions and 33 deletions

View file

@ -9,7 +9,6 @@ import (
"net" "net"
"net/http" "net/http"
"os" "os"
"path/filepath"
"strings" "strings"
"time" "time"
@ -24,22 +23,21 @@ import (
"github.com/jmorganca/ollama/server" "github.com/jmorganca/ollama/server"
) )
func cacheDir() string {
home, err := os.UserHomeDir()
if err != nil {
panic(err)
}
return filepath.Join(home, ".ollama")
}
func create(cmd *cobra.Command, args []string) error { func create(cmd *cobra.Command, args []string) error {
filename, _ := cmd.Flags().GetString("file") filename, _ := cmd.Flags().GetString("file")
client := api.NewClient() client := api.NewClient()
var spinner *Spinner
request := api.CreateRequest{Name: args[0], Path: filename} request := api.CreateRequest{Name: args[0], Path: filename}
fn := func(resp api.CreateProgress) error { fn := func(resp api.CreateProgress) error {
fmt.Println(resp.Status) if spinner != nil {
spinner.Stop()
}
spinner = NewSpinner(resp.Status)
go spinner.Spin(100 * time.Millisecond)
return nil return nil
} }
@ -47,6 +45,10 @@ func create(cmd *cobra.Command, args []string) error {
return err return err
} }
if spinner != nil {
spinner.Stop()
}
return nil return nil
} }
@ -176,24 +178,8 @@ func generate(cmd *cobra.Command, model, prompt string) error {
if len(strings.TrimSpace(prompt)) > 0 { if len(strings.TrimSpace(prompt)) > 0 {
client := api.NewClient() client := api.NewClient()
spinner := progressbar.NewOptions(-1, spinner := NewSpinner("")
progressbar.OptionSetWriter(os.Stderr), go spinner.Spin(60 * time.Millisecond)
progressbar.OptionThrottle(60*time.Millisecond),
progressbar.OptionSpinnerType(14),
progressbar.OptionSetRenderBlankState(true),
progressbar.OptionSetElapsedTime(false),
progressbar.OptionClearOnFinish(),
)
go func() {
for range time.Tick(60 * time.Millisecond) {
if spinner.IsFinished() {
break
}
spinner.Add(1)
}
}()
var latest api.GenerateResponse var latest api.GenerateResponse
@ -292,10 +278,6 @@ func NewCLI() *cobra.Command {
CompletionOptions: cobra.CompletionOptions{ CompletionOptions: cobra.CompletionOptions{
DisableDefaultCmd: true, DisableDefaultCmd: true,
}, },
PersistentPreRunE: func(_ *cobra.Command, args []string) error {
// create the models directory and it's parent
return os.MkdirAll(filepath.Join(cacheDir(), "models"), 0o700)
},
} }
cobra.EnableCommandSorting = false cobra.EnableCommandSorting = false

44
cmd/spinner.go Normal file
View file

@ -0,0 +1,44 @@
package cmd
import (
"fmt"
"os"
"time"
"github.com/schollz/progressbar/v3"
)
type Spinner struct {
description string
*progressbar.ProgressBar
}
func NewSpinner(description string) *Spinner {
return &Spinner{
description: description,
ProgressBar: progressbar.NewOptions(-1,
progressbar.OptionSetWriter(os.Stderr),
progressbar.OptionThrottle(60*time.Millisecond),
progressbar.OptionSpinnerType(14),
progressbar.OptionSetRenderBlankState(true),
progressbar.OptionSetElapsedTime(false),
progressbar.OptionClearOnFinish(),
progressbar.OptionSetDescription(description),
),
}
}
func (s *Spinner) Spin(tick time.Duration) {
for range time.Tick(tick) {
if s.IsFinished() {
break
}
s.Add(1)
}
}
func (s *Spinner) Stop() {
s.Finish()
fmt.Println(s.description)
}