resumable downloads

This commit is contained in:
Bruce MacDonald 2023-07-06 14:57:11 -04:00 committed by Jeffrey Morgan
parent 7ee75e19ec
commit c9f45abef3
4 changed files with 42 additions and 17 deletions

View file

@ -9,11 +9,12 @@ import (
"io"
"net/http"
"strings"
"sync"
)
type Client struct {
URL string
HTTP http.Client
URL string
HTTP http.Client
}
func checkError(resp *http.Response, body []byte) error {
@ -64,7 +65,14 @@ func (c *Client) stream(ctx context.Context, method string, path string, reqData
for {
line, err := reader.ReadBytes('\n')
if err != nil {
break
if err == io.EOF {
break
} else {
return err // Handle other errors
}
}
if err := checkError(res, line); err != nil {
return err
}
callback(bytes.TrimSuffix(line, []byte("\n")))
}
@ -128,8 +136,9 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, callback fu
return &res, nil
}
func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(progress PullProgress)) (*PullResponse, error) {
var res PullResponse
func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(progress PullProgress)) error {
var wg sync.WaitGroup
wg.Add(1)
if err := c.stream(ctx, http.MethodPost, "/api/pull", req, func(progressBytes []byte) {
/*
Events have the following format for progress:
@ -148,10 +157,14 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(progr
fmt.Println(err)
return
}
if progress.Completed >= progress.Total {
wg.Done()
}
callback(progress)
}); err != nil {
return nil, err
return err
}
return &res, nil
wg.Wait()
return nil
}

View file

@ -28,10 +28,6 @@ type PullProgress struct {
Percent float64 `json:"percent"`
}
type PullResponse struct {
Response string `json:"response"`
}
type GenerateRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`

View file

@ -40,7 +40,7 @@ func run(model string) error {
mutex := &sync.Mutex{}
var progressData api.PullProgress
callback := func(progress api.PullProgress) {
pullCallback := func(progress api.PullProgress) {
mutex.Lock()
progressData = progress
if bar == nil {
@ -60,8 +60,11 @@ func run(model string) error {
bar.Set(int(progress.Completed))
mutex.Unlock()
}
_, err = client.Pull(context.Background(), &pr, callback)
return err
if err := client.Pull(context.Background(), &pr, pullCallback); err != nil {
return err
}
fmt.Println("Up to date.")
return nil
}
func serve() error {

View file

@ -79,6 +79,7 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
panic(err)
}
// check for resume
alreadyDownloaded := 0
fileInfo, err := os.Stat(fileName)
if err != nil {
if !os.IsNotExist(err) {
@ -86,7 +87,8 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
}
// file doesn't exist, create it now
} else {
req.Header.Add("Range", "bytes="+strconv.FormatInt(fileInfo.Size(), 10)+"-")
alreadyDownloaded = int(fileInfo.Size())
req.Header.Add("Range", "bytes="+strconv.Itoa(alreadyDownloaded)+"-")
}
resp, err := client.Do(req)
@ -96,7 +98,17 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable {
// already downloaded
progressCh <- api.PullProgress{
Total: alreadyDownloaded,
Completed: alreadyDownloaded,
Percent: 100,
}
return nil
}
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
return fmt.Errorf("failed to download model: %s", resp.Status)
}
@ -109,7 +121,8 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
totalSize, _ := strconv.Atoi(resp.Header.Get("Content-Length"))
buf := make([]byte, 1024)
totalBytes := 0
totalBytes := alreadyDownloaded
totalSize += alreadyDownloaded
for {
n, err := resp.Body.Read(buf)