display pull progress

This commit is contained in:
Bruce MacDonald 2023-07-06 14:18:40 -04:00 committed by Jeffrey Morgan
parent 580fe8951c
commit 7cf5905063
7 changed files with 81 additions and 19 deletions

View file

@ -8,6 +8,7 @@ import (
"fmt"
"io"
"net/http"
"strings"
)
type Client struct {
@ -65,7 +66,6 @@ func (c *Client) stream(ctx context.Context, method string, path string, reqData
if err != nil {
break
}
callback(bytes.TrimSuffix(line, []byte("\n")))
}
@ -128,10 +128,27 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, callback fu
return &res, nil
}
func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(token string)) (*PullResponse, error) {
func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(progress PullProgress)) (*PullResponse, error) {
var res PullResponse
if err := c.stream(ctx, http.MethodPost, "/api/pull", req, func(token []byte) {
callback(string(token))
if err := c.stream(ctx, http.MethodPost, "/api/pull", req, func(progressBytes []byte) {
/*
Events have the following format for progress:
event:progress
data:{"total":123,"completed":123,"percent":0.1}
Need to parse out the data part and unmarshal it.
*/
eventParts := strings.Split(string(progressBytes), "data:")
if len(eventParts) < 2 {
// no data part, ignore
return
}
eventData := eventParts[1]
var progress PullProgress
if err := json.Unmarshal([]byte(eventData), &progress); err != nil {
fmt.Println(err)
return
}
callback(progress)
}); err != nil {
return nil, err
}

View file

@ -22,6 +22,12 @@ type PullRequest struct {
Model string `json:"model"`
}
type PullProgress struct {
Total int `json:"total"`
Completed int `json:"completed"`
Percent float64 `json:"percent"`
}
type PullResponse struct {
Response string `json:"response"`
}

View file

@ -7,7 +7,9 @@ import (
"net"
"os"
"path"
"sync"
"github.com/gosuri/uiprogress"
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/server"
"github.com/spf13/cobra"
@ -22,6 +24,10 @@ func cacheDir() string {
return path.Join(home, ".ollama")
}
func bytesToGB(bytes int) float64 {
return float64(bytes) / float64(1<<30)
}
func run(model string) error {
client, err := NewAPIClient()
if err != nil {
@ -30,8 +36,29 @@ func run(model string) error {
pr := api.PullRequest{
Model: model,
}
callback := func(progress string) {
fmt.Println(progress)
var bar *uiprogress.Bar
mutex := &sync.Mutex{}
var progressData api.PullProgress
callback := func(progress api.PullProgress) {
mutex.Lock()
progressData = progress
if bar == nil {
uiprogress.Start() // start rendering
bar = uiprogress.AddBar(int(progress.Total)) // Add a new bar
// display the total file size and how much has downloaded so far
bar.PrependFunc(func(b *uiprogress.Bar) string {
return fmt.Sprintf("Downloading: %.2f GB / %.2f GB", bytesToGB(progressData.Completed), bytesToGB(progressData.Total))
})
// display completion percentage
bar.AppendFunc(func(b *uiprogress.Bar) string {
return fmt.Sprintf(" %d%%", int((float64(progressData.Completed)/float64(progressData.Total))*100))
})
}
bar.Set(int(progress.Completed))
mutex.Unlock()
}
_, err = client.Pull(context.Background(), &pr, callback)
return err

4
go.mod
View file

@ -4,6 +4,7 @@ go 1.20
require (
github.com/gin-gonic/gin v1.9.1
github.com/gosuri/uiprogress v0.0.1
github.com/spf13/cobra v1.7.0
)
@ -17,6 +18,7 @@ require (
github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/gosuri/uilive v0.0.4 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
@ -32,7 +34,7 @@ require (
golang.org/x/arch v0.3.0 // indirect
golang.org/x/crypto v0.10.0 // indirect
golang.org/x/net v0.10.0 // indirect
golang.org/x/sys v0.9.0 // indirect
golang.org/x/sys v0.10.0 // indirect
golang.org/x/text v0.10.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect

8
go.sum
View file

@ -28,6 +28,10 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/gosuri/uilive v0.0.4 h1:hUEBpQDj8D8jXgtCdBu7sWsy5sbW/5GhuO8KBwJ2jyY=
github.com/gosuri/uilive v0.0.4/go.mod h1:V/epo5LjjlDE5RJUcqx8dbw+zc93y5Ya3yg8tfZ74VI=
github.com/gosuri/uiprogress v0.0.1 h1:0kpv/XY/qTmFWl/SkaJykZXrBBzwwadmW8fRb7RJSxw=
github.com/gosuri/uiprogress v0.0.1/go.mod h1:C1RTYn4Sc7iEyf6j8ft5dyoZ4212h8G1ol9QQluh5+0=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
@ -97,8 +101,8 @@ golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=

View file

@ -9,15 +9,14 @@ import (
"os"
"path"
"strconv"
"github.com/jmorganca/ollama/api"
)
// const directoryURL = "https://ollama.ai/api/models"
// TODO
const directoryURL = "https://raw.githubusercontent.com/jmorganca/ollama/go/models.json"
type directoryCtxKey string
var dirCtx directoryCtxKey = "directory"
type Model struct {
Name string `json:"name"`
DisplayName string `json:"display_name"`
@ -31,7 +30,7 @@ type Model struct {
License string `json:"license"`
}
func pull(model string, progressCh chan<- string) error {
func pull(model string, progressCh chan<- api.PullProgress) error {
remote, err := getRemote(model)
if err != nil {
return fmt.Errorf("failed to pull model: %w", err)
@ -64,7 +63,7 @@ func getRemote(model string) (*Model, error) {
return nil, fmt.Errorf("model not found in directory: %s", model)
}
func saveModel(model *Model, progressCh chan<- string) error {
func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
// this models cache directory is created by the server on startup
home, err := os.UserHomeDir()
if err != nil {
@ -130,11 +129,18 @@ func saveModel(model *Model, progressCh chan<- string) error {
totalBytes += n
// send progress updates
progressCh <- fmt.Sprintf("Downloaded %d out of %d bytes (%.2f%%)", totalBytes, totalSize, float64(totalBytes)/float64(totalSize)*100)
progressCh <- api.PullProgress{
Total: totalSize,
Completed: totalBytes,
Percent: float64(totalBytes) / float64(totalSize) * 100,
}
}
// send completion message
progressCh <- "Download complete!"
progressCh <- api.PullProgress{
Total: totalSize,
Completed: totalSize,
Percent: 100,
}
return nil
}

View file

@ -107,7 +107,7 @@ func Serve(ln net.Listener) error {
return
}
progressCh := make(chan string)
progressCh := make(chan api.PullProgress)
go func() {
defer close(progressCh)
if err := pull(req.Model, progressCh); err != nil {