display pull progress
This commit is contained in:
parent
580fe8951c
commit
7cf5905063
7 changed files with 81 additions and 19 deletions
|
@ -8,6 +8,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
|
@ -65,7 +66,6 @@ func (c *Client) stream(ctx context.Context, method string, path string, reqData
|
|||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
callback(bytes.TrimSuffix(line, []byte("\n")))
|
||||
}
|
||||
|
||||
|
@ -128,10 +128,27 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, callback fu
|
|||
return &res, nil
|
||||
}
|
||||
|
||||
func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(token string)) (*PullResponse, error) {
|
||||
func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(progress PullProgress)) (*PullResponse, error) {
|
||||
var res PullResponse
|
||||
if err := c.stream(ctx, http.MethodPost, "/api/pull", req, func(token []byte) {
|
||||
callback(string(token))
|
||||
if err := c.stream(ctx, http.MethodPost, "/api/pull", req, func(progressBytes []byte) {
|
||||
/*
|
||||
Events have the following format for progress:
|
||||
event:progress
|
||||
data:{"total":123,"completed":123,"percent":0.1}
|
||||
Need to parse out the data part and unmarshal it.
|
||||
*/
|
||||
eventParts := strings.Split(string(progressBytes), "data:")
|
||||
if len(eventParts) < 2 {
|
||||
// no data part, ignore
|
||||
return
|
||||
}
|
||||
eventData := eventParts[1]
|
||||
var progress PullProgress
|
||||
if err := json.Unmarshal([]byte(eventData), &progress); err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
callback(progress)
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -22,6 +22,12 @@ type PullRequest struct {
|
|||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
type PullProgress struct {
|
||||
Total int `json:"total"`
|
||||
Completed int `json:"completed"`
|
||||
Percent float64 `json:"percent"`
|
||||
}
|
||||
|
||||
type PullResponse struct {
|
||||
Response string `json:"response"`
|
||||
}
|
||||
|
|
31
cmd/cmd.go
31
cmd/cmd.go
|
@ -7,7 +7,9 @@ import (
|
|||
"net"
|
||||
"os"
|
||||
"path"
|
||||
"sync"
|
||||
|
||||
"github.com/gosuri/uiprogress"
|
||||
"github.com/jmorganca/ollama/api"
|
||||
"github.com/jmorganca/ollama/server"
|
||||
"github.com/spf13/cobra"
|
||||
|
@ -22,6 +24,10 @@ func cacheDir() string {
|
|||
return path.Join(home, ".ollama")
|
||||
}
|
||||
|
||||
func bytesToGB(bytes int) float64 {
|
||||
return float64(bytes) / float64(1<<30)
|
||||
}
|
||||
|
||||
func run(model string) error {
|
||||
client, err := NewAPIClient()
|
||||
if err != nil {
|
||||
|
@ -30,8 +36,29 @@ func run(model string) error {
|
|||
pr := api.PullRequest{
|
||||
Model: model,
|
||||
}
|
||||
callback := func(progress string) {
|
||||
fmt.Println(progress)
|
||||
var bar *uiprogress.Bar
|
||||
mutex := &sync.Mutex{}
|
||||
var progressData api.PullProgress
|
||||
|
||||
callback := func(progress api.PullProgress) {
|
||||
mutex.Lock()
|
||||
progressData = progress
|
||||
if bar == nil {
|
||||
uiprogress.Start() // start rendering
|
||||
bar = uiprogress.AddBar(int(progress.Total)) // Add a new bar
|
||||
|
||||
// display the total file size and how much has downloaded so far
|
||||
bar.PrependFunc(func(b *uiprogress.Bar) string {
|
||||
return fmt.Sprintf("Downloading: %.2f GB / %.2f GB", bytesToGB(progressData.Completed), bytesToGB(progressData.Total))
|
||||
})
|
||||
|
||||
// display completion percentage
|
||||
bar.AppendFunc(func(b *uiprogress.Bar) string {
|
||||
return fmt.Sprintf(" %d%%", int((float64(progressData.Completed)/float64(progressData.Total))*100))
|
||||
})
|
||||
}
|
||||
bar.Set(int(progress.Completed))
|
||||
mutex.Unlock()
|
||||
}
|
||||
_, err = client.Pull(context.Background(), &pr, callback)
|
||||
return err
|
||||
|
|
4
go.mod
4
go.mod
|
@ -4,6 +4,7 @@ go 1.20
|
|||
|
||||
require (
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
github.com/gosuri/uiprogress v0.0.1
|
||||
github.com/spf13/cobra v1.7.0
|
||||
)
|
||||
|
||||
|
@ -17,6 +18,7 @@ require (
|
|||
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/go-cmp v0.5.9 // indirect
|
||||
github.com/gosuri/uilive v0.0.4 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||
|
@ -32,7 +34,7 @@ require (
|
|||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/crypto v0.10.0 // indirect
|
||||
golang.org/x/net v0.10.0 // indirect
|
||||
golang.org/x/sys v0.9.0 // indirect
|
||||
golang.org/x/sys v0.10.0 // indirect
|
||||
golang.org/x/text v0.10.0 // indirect
|
||||
google.golang.org/protobuf v1.30.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
|
|
8
go.sum
8
go.sum
|
@ -28,6 +28,10 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
|||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/gosuri/uilive v0.0.4 h1:hUEBpQDj8D8jXgtCdBu7sWsy5sbW/5GhuO8KBwJ2jyY=
|
||||
github.com/gosuri/uilive v0.0.4/go.mod h1:V/epo5LjjlDE5RJUcqx8dbw+zc93y5Ya3yg8tfZ74VI=
|
||||
github.com/gosuri/uiprogress v0.0.1 h1:0kpv/XY/qTmFWl/SkaJykZXrBBzwwadmW8fRb7RJSxw=
|
||||
github.com/gosuri/uiprogress v0.0.1/go.mod h1:C1RTYn4Sc7iEyf6j8ft5dyoZ4212h8G1ol9QQluh5+0=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
|
@ -97,8 +101,8 @@ golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
|
||||
golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
|
|
|
@ -9,15 +9,14 @@ import (
|
|||
"os"
|
||||
"path"
|
||||
"strconv"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
)
|
||||
|
||||
// const directoryURL = "https://ollama.ai/api/models"
|
||||
// TODO
|
||||
const directoryURL = "https://raw.githubusercontent.com/jmorganca/ollama/go/models.json"
|
||||
|
||||
type directoryCtxKey string
|
||||
|
||||
var dirCtx directoryCtxKey = "directory"
|
||||
|
||||
type Model struct {
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"display_name"`
|
||||
|
@ -31,7 +30,7 @@ type Model struct {
|
|||
License string `json:"license"`
|
||||
}
|
||||
|
||||
func pull(model string, progressCh chan<- string) error {
|
||||
func pull(model string, progressCh chan<- api.PullProgress) error {
|
||||
remote, err := getRemote(model)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to pull model: %w", err)
|
||||
|
@ -64,7 +63,7 @@ func getRemote(model string) (*Model, error) {
|
|||
return nil, fmt.Errorf("model not found in directory: %s", model)
|
||||
}
|
||||
|
||||
func saveModel(model *Model, progressCh chan<- string) error {
|
||||
func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
|
||||
// this models cache directory is created by the server on startup
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
|
@ -130,11 +129,18 @@ func saveModel(model *Model, progressCh chan<- string) error {
|
|||
totalBytes += n
|
||||
|
||||
// send progress updates
|
||||
progressCh <- fmt.Sprintf("Downloaded %d out of %d bytes (%.2f%%)", totalBytes, totalSize, float64(totalBytes)/float64(totalSize)*100)
|
||||
progressCh <- api.PullProgress{
|
||||
Total: totalSize,
|
||||
Completed: totalBytes,
|
||||
Percent: float64(totalBytes) / float64(totalSize) * 100,
|
||||
}
|
||||
}
|
||||
|
||||
// send completion message
|
||||
progressCh <- "Download complete!"
|
||||
progressCh <- api.PullProgress{
|
||||
Total: totalSize,
|
||||
Completed: totalSize,
|
||||
Percent: 100,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -107,7 +107,7 @@ func Serve(ln net.Listener) error {
|
|||
return
|
||||
}
|
||||
|
||||
progressCh := make(chan string)
|
||||
progressCh := make(chan api.PullProgress)
|
||||
go func() {
|
||||
defer close(progressCh)
|
||||
if err := pull(req.Model, progressCh); err != nil {
|
||||
|
|
Loading…
Reference in a new issue