resumable downloads

This commit is contained in:
Bruce MacDonald 2023-07-06 14:57:11 -04:00 committed by Jeffrey Morgan
parent 7ee75e19ec
commit c9f45abef3
4 changed files with 42 additions and 17 deletions

View file

@ -9,11 +9,12 @@ import (
"io" "io"
"net/http" "net/http"
"strings" "strings"
"sync"
) )
type Client struct { type Client struct {
URL string URL string
HTTP http.Client HTTP http.Client
} }
func checkError(resp *http.Response, body []byte) error { func checkError(resp *http.Response, body []byte) error {
@ -64,7 +65,14 @@ func (c *Client) stream(ctx context.Context, method string, path string, reqData
for { for {
line, err := reader.ReadBytes('\n') line, err := reader.ReadBytes('\n')
if err != nil { if err != nil {
break if err == io.EOF {
break
} else {
return err // Handle other errors
}
}
if err := checkError(res, line); err != nil {
return err
} }
callback(bytes.TrimSuffix(line, []byte("\n"))) callback(bytes.TrimSuffix(line, []byte("\n")))
} }
@ -128,8 +136,9 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, callback fu
return &res, nil return &res, nil
} }
func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(progress PullProgress)) (*PullResponse, error) { func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(progress PullProgress)) error {
var res PullResponse var wg sync.WaitGroup
wg.Add(1)
if err := c.stream(ctx, http.MethodPost, "/api/pull", req, func(progressBytes []byte) { if err := c.stream(ctx, http.MethodPost, "/api/pull", req, func(progressBytes []byte) {
/* /*
Events have the following format for progress: Events have the following format for progress:
@ -148,10 +157,14 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(progr
fmt.Println(err) fmt.Println(err)
return return
} }
if progress.Completed >= progress.Total {
wg.Done()
}
callback(progress) callback(progress)
}); err != nil { }); err != nil {
return nil, err return err
} }
return &res, nil wg.Wait()
return nil
} }

View file

@ -28,10 +28,6 @@ type PullProgress struct {
Percent float64 `json:"percent"` Percent float64 `json:"percent"`
} }
type PullResponse struct {
Response string `json:"response"`
}
type GenerateRequest struct { type GenerateRequest struct {
Model string `json:"model"` Model string `json:"model"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`

View file

@ -40,7 +40,7 @@ func run(model string) error {
mutex := &sync.Mutex{} mutex := &sync.Mutex{}
var progressData api.PullProgress var progressData api.PullProgress
callback := func(progress api.PullProgress) { pullCallback := func(progress api.PullProgress) {
mutex.Lock() mutex.Lock()
progressData = progress progressData = progress
if bar == nil { if bar == nil {
@ -60,8 +60,11 @@ func run(model string) error {
bar.Set(int(progress.Completed)) bar.Set(int(progress.Completed))
mutex.Unlock() mutex.Unlock()
} }
_, err = client.Pull(context.Background(), &pr, callback) if err := client.Pull(context.Background(), &pr, pullCallback); err != nil {
return err return err
}
fmt.Println("Up to date.")
return nil
} }
func serve() error { func serve() error {

View file

@ -79,6 +79,7 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
panic(err) panic(err)
} }
// check for resume // check for resume
alreadyDownloaded := 0
fileInfo, err := os.Stat(fileName) fileInfo, err := os.Stat(fileName)
if err != nil { if err != nil {
if !os.IsNotExist(err) { if !os.IsNotExist(err) {
@ -86,7 +87,8 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
} }
// file doesn't exist, create it now // file doesn't exist, create it now
} else { } else {
req.Header.Add("Range", "bytes="+strconv.FormatInt(fileInfo.Size(), 10)+"-") alreadyDownloaded = int(fileInfo.Size())
req.Header.Add("Range", "bytes="+strconv.Itoa(alreadyDownloaded)+"-")
} }
resp, err := client.Do(req) resp, err := client.Do(req)
@ -96,7 +98,17 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable {
// already downloaded
progressCh <- api.PullProgress{
Total: alreadyDownloaded,
Completed: alreadyDownloaded,
Percent: 100,
}
return nil
}
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
return fmt.Errorf("failed to download model: %s", resp.Status) return fmt.Errorf("failed to download model: %s", resp.Status)
} }
@ -109,7 +121,8 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
totalSize, _ := strconv.Atoi(resp.Header.Get("Content-Length")) totalSize, _ := strconv.Atoi(resp.Header.Get("Content-Length"))
buf := make([]byte, 1024) buf := make([]byte, 1024)
totalBytes := 0 totalBytes := alreadyDownloaded
totalSize += alreadyDownloaded
for { for {
n, err := resp.Body.Read(buf) n, err := resp.Body.Read(buf)