From c9f45abef3c8d87c0258073ac64aeca95d117601 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Thu, 6 Jul 2023 14:57:11 -0400 Subject: [PATCH] resumable downloads --- api/client.go | 27 ++++++++++++++++++++------- api/types.go | 4 ---- cmd/cmd.go | 9 ++++++--- server/models.go | 19 ++++++++++++++++--- 4 files changed, 42 insertions(+), 17 deletions(-) diff --git a/api/client.go b/api/client.go index f3b2ac80..b13e41ef 100644 --- a/api/client.go +++ b/api/client.go @@ -9,11 +9,12 @@ import ( "io" "net/http" "strings" + "sync" ) type Client struct { - URL string - HTTP http.Client + URL string + HTTP http.Client } func checkError(resp *http.Response, body []byte) error { @@ -64,7 +65,14 @@ func (c *Client) stream(ctx context.Context, method string, path string, reqData for { line, err := reader.ReadBytes('\n') if err != nil { - break + if err == io.EOF { + break + } else { + return err // Handle other errors + } + } + if err := checkError(res, line); err != nil { + return err } callback(bytes.TrimSuffix(line, []byte("\n"))) } @@ -128,8 +136,9 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, callback fu return &res, nil } -func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(progress PullProgress)) (*PullResponse, error) { - var res PullResponse +func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(progress PullProgress)) error { + var wg sync.WaitGroup + wg.Add(1) if err := c.stream(ctx, http.MethodPost, "/api/pull", req, func(progressBytes []byte) { /* Events have the following format for progress: @@ -148,10 +157,14 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(progr fmt.Println(err) return } + if progress.Completed >= progress.Total { + wg.Done() + } callback(progress) }); err != nil { - return nil, err + return err } - return &res, nil + wg.Wait() + return nil } diff --git a/api/types.go b/api/types.go index 5ab4ba33..de1dd107 100644 --- a/api/types.go +++ b/api/types.go @@ -28,10 +28,6 @@ type PullProgress struct { Percent float64 `json:"percent"` } -type PullResponse struct { - Response string `json:"response"` -} - type GenerateRequest struct { Model string `json:"model"` Prompt string `json:"prompt"` diff --git a/cmd/cmd.go b/cmd/cmd.go index 36abd667..b861dcfb 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -40,7 +40,7 @@ func run(model string) error { mutex := &sync.Mutex{} var progressData api.PullProgress - callback := func(progress api.PullProgress) { + pullCallback := func(progress api.PullProgress) { mutex.Lock() progressData = progress if bar == nil { @@ -60,8 +60,11 @@ func run(model string) error { bar.Set(int(progress.Completed)) mutex.Unlock() } - _, err = client.Pull(context.Background(), &pr, callback) - return err + if err := client.Pull(context.Background(), &pr, pullCallback); err != nil { + return err + } + fmt.Println("Up to date.") + return nil } func serve() error { diff --git a/server/models.go b/server/models.go index cfa04002..a64a60d4 100644 --- a/server/models.go +++ b/server/models.go @@ -79,6 +79,7 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error { panic(err) } // check for resume + alreadyDownloaded := 0 fileInfo, err := os.Stat(fileName) if err != nil { if !os.IsNotExist(err) { @@ -86,7 +87,8 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error { } // file doesn't exist, create it now } else { - req.Header.Add("Range", "bytes="+strconv.FormatInt(fileInfo.Size(), 10)+"-") + alreadyDownloaded = int(fileInfo.Size()) + req.Header.Add("Range", "bytes="+strconv.Itoa(alreadyDownloaded)+"-") } resp, err := client.Do(req) @@ -96,7 +98,17 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error { defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { + if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable { + // already downloaded + progressCh <- api.PullProgress{ + Total: alreadyDownloaded, + Completed: alreadyDownloaded, + Percent: 100, + } + return nil + } + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent { return fmt.Errorf("failed to download model: %s", resp.Status) } @@ -109,7 +121,8 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error { totalSize, _ := strconv.Atoi(resp.Header.Get("Content-Length")) buf := make([]byte, 1024) - totalBytes := 0 + totalBytes := alreadyDownloaded + totalSize += alreadyDownloaded for { n, err := resp.Body.Read(buf)