resumable downloads
This commit is contained in:
parent
7ee75e19ec
commit
c9f45abef3
4 changed files with 42 additions and 17 deletions
|
@ -9,6 +9,7 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
|
@ -64,7 +65,14 @@ func (c *Client) stream(ctx context.Context, method string, path string, reqData
|
|||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
} else {
|
||||
return err // Handle other errors
|
||||
}
|
||||
}
|
||||
if err := checkError(res, line); err != nil {
|
||||
return err
|
||||
}
|
||||
callback(bytes.TrimSuffix(line, []byte("\n")))
|
||||
}
|
||||
|
@ -128,8 +136,9 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, callback fu
|
|||
return &res, nil
|
||||
}
|
||||
|
||||
func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(progress PullProgress)) (*PullResponse, error) {
|
||||
var res PullResponse
|
||||
func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(progress PullProgress)) error {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
if err := c.stream(ctx, http.MethodPost, "/api/pull", req, func(progressBytes []byte) {
|
||||
/*
|
||||
Events have the following format for progress:
|
||||
|
@ -148,10 +157,14 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(progr
|
|||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
if progress.Completed >= progress.Total {
|
||||
wg.Done()
|
||||
}
|
||||
callback(progress)
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
return &res, nil
|
||||
wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -28,10 +28,6 @@ type PullProgress struct {
|
|||
Percent float64 `json:"percent"`
|
||||
}
|
||||
|
||||
type PullResponse struct {
|
||||
Response string `json:"response"`
|
||||
}
|
||||
|
||||
type GenerateRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
|
|
|
@ -40,7 +40,7 @@ func run(model string) error {
|
|||
mutex := &sync.Mutex{}
|
||||
var progressData api.PullProgress
|
||||
|
||||
callback := func(progress api.PullProgress) {
|
||||
pullCallback := func(progress api.PullProgress) {
|
||||
mutex.Lock()
|
||||
progressData = progress
|
||||
if bar == nil {
|
||||
|
@ -60,8 +60,11 @@ func run(model string) error {
|
|||
bar.Set(int(progress.Completed))
|
||||
mutex.Unlock()
|
||||
}
|
||||
_, err = client.Pull(context.Background(), &pr, callback)
|
||||
if err := client.Pull(context.Background(), &pr, pullCallback); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("Up to date.")
|
||||
return nil
|
||||
}
|
||||
|
||||
func serve() error {
|
||||
|
|
|
@ -79,6 +79,7 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
|
|||
panic(err)
|
||||
}
|
||||
// check for resume
|
||||
alreadyDownloaded := 0
|
||||
fileInfo, err := os.Stat(fileName)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
|
@ -86,7 +87,8 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
|
|||
}
|
||||
// file doesn't exist, create it now
|
||||
} else {
|
||||
req.Header.Add("Range", "bytes="+strconv.FormatInt(fileInfo.Size(), 10)+"-")
|
||||
alreadyDownloaded = int(fileInfo.Size())
|
||||
req.Header.Add("Range", "bytes="+strconv.Itoa(alreadyDownloaded)+"-")
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
|
@ -96,7 +98,17 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
|
|||
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable {
|
||||
// already downloaded
|
||||
progressCh <- api.PullProgress{
|
||||
Total: alreadyDownloaded,
|
||||
Completed: alreadyDownloaded,
|
||||
Percent: 100,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
|
||||
return fmt.Errorf("failed to download model: %s", resp.Status)
|
||||
}
|
||||
|
||||
|
@ -109,7 +121,8 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
|
|||
totalSize, _ := strconv.Atoi(resp.Header.Get("Content-Length"))
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
totalBytes := 0
|
||||
totalBytes := alreadyDownloaded
|
||||
totalSize += alreadyDownloaded
|
||||
|
||||
for {
|
||||
n, err := resp.Body.Read(buf)
|
||||
|
|
Loading…
Reference in a new issue