diff --git a/api/client.go b/api/client.go index 90002e53..f3b2ac80 100644 --- a/api/client.go +++ b/api/client.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net/http" + "strings" ) type Client struct { @@ -65,7 +66,6 @@ func (c *Client) stream(ctx context.Context, method string, path string, reqData if err != nil { break } - callback(bytes.TrimSuffix(line, []byte("\n"))) } @@ -128,10 +128,27 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, callback fu return &res, nil } -func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(token string)) (*PullResponse, error) { +func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(progress PullProgress)) (*PullResponse, error) { var res PullResponse - if err := c.stream(ctx, http.MethodPost, "/api/pull", req, func(token []byte) { - callback(string(token)) + if err := c.stream(ctx, http.MethodPost, "/api/pull", req, func(progressBytes []byte) { + /* + Events have the following format for progress: + event:progress + data:{"total":123,"completed":123,"percent":0.1} + Need to parse out the data part and unmarshal it. + */ + eventParts := strings.Split(string(progressBytes), "data:") + if len(eventParts) < 2 { + // no data part, ignore + return + } + eventData := eventParts[1] + var progress PullProgress + if err := json.Unmarshal([]byte(eventData), &progress); err != nil { + fmt.Println(err) + return + } + callback(progress) }); err != nil { return nil, err } diff --git a/api/types.go b/api/types.go index e98a8d56..5ab4ba33 100644 --- a/api/types.go +++ b/api/types.go @@ -22,6 +22,12 @@ type PullRequest struct { Model string `json:"model"` } +type PullProgress struct { + Total int `json:"total"` + Completed int `json:"completed"` + Percent float64 `json:"percent"` +} + type PullResponse struct { Response string `json:"response"` } diff --git a/cmd/cmd.go b/cmd/cmd.go index 3646905b..36abd667 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -7,7 +7,9 @@ import ( "net" "os" "path" + "sync" + "github.com/gosuri/uiprogress" "github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/server" "github.com/spf13/cobra" @@ -22,6 +24,10 @@ func cacheDir() string { return path.Join(home, ".ollama") } +func bytesToGB(bytes int) float64 { + return float64(bytes) / float64(1<<30) +} + func run(model string) error { client, err := NewAPIClient() if err != nil { @@ -30,8 +36,29 @@ func run(model string) error { pr := api.PullRequest{ Model: model, } - callback := func(progress string) { - fmt.Println(progress) + var bar *uiprogress.Bar + mutex := &sync.Mutex{} + var progressData api.PullProgress + + callback := func(progress api.PullProgress) { + mutex.Lock() + progressData = progress + if bar == nil { + uiprogress.Start() // start rendering + bar = uiprogress.AddBar(int(progress.Total)) // Add a new bar + + // display the total file size and how much has downloaded so far + bar.PrependFunc(func(b *uiprogress.Bar) string { + return fmt.Sprintf("Downloading: %.2f GB / %.2f GB", bytesToGB(progressData.Completed), bytesToGB(progressData.Total)) + }) + + // display completion percentage + bar.AppendFunc(func(b *uiprogress.Bar) string { + return fmt.Sprintf(" %d%%", int((float64(progressData.Completed)/float64(progressData.Total))*100)) + }) + } + bar.Set(int(progress.Completed)) + mutex.Unlock() } _, err = client.Pull(context.Background(), &pr, callback) return err diff --git a/go.mod b/go.mod index ece81ae2..6ca336d1 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.20 require ( github.com/gin-gonic/gin v1.9.1 + github.com/gosuri/uiprogress v0.0.1 github.com/spf13/cobra v1.7.0 ) @@ -17,6 +18,7 @@ require ( github.com/go-playground/validator/v10 v10.14.0 // indirect github.com/goccy/go-json v0.10.2 // indirect github.com/google/go-cmp v0.5.9 // indirect + github.com/gosuri/uilive v0.0.4 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect @@ -32,7 +34,7 @@ require ( golang.org/x/arch v0.3.0 // indirect golang.org/x/crypto v0.10.0 // indirect golang.org/x/net v0.10.0 // indirect - golang.org/x/sys v0.9.0 // indirect + golang.org/x/sys v0.10.0 // indirect golang.org/x/text v0.10.0 // indirect google.golang.org/protobuf v1.30.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index ef8f65ec..065bb0db 100644 --- a/go.sum +++ b/go.sum @@ -28,6 +28,10 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/gosuri/uilive v0.0.4 h1:hUEBpQDj8D8jXgtCdBu7sWsy5sbW/5GhuO8KBwJ2jyY= +github.com/gosuri/uilive v0.0.4/go.mod h1:V/epo5LjjlDE5RJUcqx8dbw+zc93y5Ya3yg8tfZ74VI= +github.com/gosuri/uiprogress v0.0.1 h1:0kpv/XY/qTmFWl/SkaJykZXrBBzwwadmW8fRb7RJSxw= +github.com/gosuri/uiprogress v0.0.1/go.mod h1:C1RTYn4Sc7iEyf6j8ft5dyoZ4212h8G1ol9QQluh5+0= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= @@ -97,8 +101,8 @@ golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= -golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= diff --git a/server/models.go b/server/models.go index d5504390..cfa04002 100644 --- a/server/models.go +++ b/server/models.go @@ -9,15 +9,14 @@ import ( "os" "path" "strconv" + + "github.com/jmorganca/ollama/api" ) // const directoryURL = "https://ollama.ai/api/models" +// TODO const directoryURL = "https://raw.githubusercontent.com/jmorganca/ollama/go/models.json" -type directoryCtxKey string - -var dirCtx directoryCtxKey = "directory" - type Model struct { Name string `json:"name"` DisplayName string `json:"display_name"` @@ -31,7 +30,7 @@ type Model struct { License string `json:"license"` } -func pull(model string, progressCh chan<- string) error { +func pull(model string, progressCh chan<- api.PullProgress) error { remote, err := getRemote(model) if err != nil { return fmt.Errorf("failed to pull model: %w", err) @@ -64,7 +63,7 @@ func getRemote(model string) (*Model, error) { return nil, fmt.Errorf("model not found in directory: %s", model) } -func saveModel(model *Model, progressCh chan<- string) error { +func saveModel(model *Model, progressCh chan<- api.PullProgress) error { // this models cache directory is created by the server on startup home, err := os.UserHomeDir() if err != nil { @@ -130,11 +129,18 @@ func saveModel(model *Model, progressCh chan<- string) error { totalBytes += n // send progress updates - progressCh <- fmt.Sprintf("Downloaded %d out of %d bytes (%.2f%%)", totalBytes, totalSize, float64(totalBytes)/float64(totalSize)*100) + progressCh <- api.PullProgress{ + Total: totalSize, + Completed: totalBytes, + Percent: float64(totalBytes) / float64(totalSize) * 100, + } } - // send completion message - progressCh <- "Download complete!" + progressCh <- api.PullProgress{ + Total: totalSize, + Completed: totalSize, + Percent: 100, + } return nil } diff --git a/server/routes.go b/server/routes.go index 922e4f8a..16c50029 100644 --- a/server/routes.go +++ b/server/routes.go @@ -107,7 +107,7 @@ func Serve(ln net.Listener) error { return } - progressCh := make(chan string) + progressCh := make(chan api.PullProgress) go func() { defer close(progressCh) if err := pull(req.Model, progressCh); err != nil {