resumable downloads
This commit is contained in:
parent
7ee75e19ec
commit
c9f45abef3
4 changed files with 42 additions and 17 deletions
|
@ -9,11 +9,12 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
URL string
|
URL string
|
||||||
HTTP http.Client
|
HTTP http.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkError(resp *http.Response, body []byte) error {
|
func checkError(resp *http.Response, body []byte) error {
|
||||||
|
@ -64,7 +65,14 @@ func (c *Client) stream(ctx context.Context, method string, path string, reqData
|
||||||
for {
|
for {
|
||||||
line, err := reader.ReadBytes('\n')
|
line, err := reader.ReadBytes('\n')
|
||||||
if err != nil {
|
if err != nil {
|
||||||
break
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
return err // Handle other errors
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := checkError(res, line); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
callback(bytes.TrimSuffix(line, []byte("\n")))
|
callback(bytes.TrimSuffix(line, []byte("\n")))
|
||||||
}
|
}
|
||||||
|
@ -128,8 +136,9 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, callback fu
|
||||||
return &res, nil
|
return &res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(progress PullProgress)) (*PullResponse, error) {
|
func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(progress PullProgress)) error {
|
||||||
var res PullResponse
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
if err := c.stream(ctx, http.MethodPost, "/api/pull", req, func(progressBytes []byte) {
|
if err := c.stream(ctx, http.MethodPost, "/api/pull", req, func(progressBytes []byte) {
|
||||||
/*
|
/*
|
||||||
Events have the following format for progress:
|
Events have the following format for progress:
|
||||||
|
@ -148,10 +157,14 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(progr
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if progress.Completed >= progress.Total {
|
||||||
|
wg.Done()
|
||||||
|
}
|
||||||
callback(progress)
|
callback(progress)
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &res, nil
|
wg.Wait()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,10 +28,6 @@ type PullProgress struct {
|
||||||
Percent float64 `json:"percent"`
|
Percent float64 `json:"percent"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type PullResponse struct {
|
|
||||||
Response string `json:"response"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type GenerateRequest struct {
|
type GenerateRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
|
|
|
@ -40,7 +40,7 @@ func run(model string) error {
|
||||||
mutex := &sync.Mutex{}
|
mutex := &sync.Mutex{}
|
||||||
var progressData api.PullProgress
|
var progressData api.PullProgress
|
||||||
|
|
||||||
callback := func(progress api.PullProgress) {
|
pullCallback := func(progress api.PullProgress) {
|
||||||
mutex.Lock()
|
mutex.Lock()
|
||||||
progressData = progress
|
progressData = progress
|
||||||
if bar == nil {
|
if bar == nil {
|
||||||
|
@ -60,8 +60,11 @@ func run(model string) error {
|
||||||
bar.Set(int(progress.Completed))
|
bar.Set(int(progress.Completed))
|
||||||
mutex.Unlock()
|
mutex.Unlock()
|
||||||
}
|
}
|
||||||
_, err = client.Pull(context.Background(), &pr, callback)
|
if err := client.Pull(context.Background(), &pr, pullCallback); err != nil {
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
|
fmt.Println("Up to date.")
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func serve() error {
|
func serve() error {
|
||||||
|
|
|
@ -79,6 +79,7 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
// check for resume
|
// check for resume
|
||||||
|
alreadyDownloaded := 0
|
||||||
fileInfo, err := os.Stat(fileName)
|
fileInfo, err := os.Stat(fileName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !os.IsNotExist(err) {
|
if !os.IsNotExist(err) {
|
||||||
|
@ -86,7 +87,8 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
|
||||||
}
|
}
|
||||||
// file doesn't exist, create it now
|
// file doesn't exist, create it now
|
||||||
} else {
|
} else {
|
||||||
req.Header.Add("Range", "bytes="+strconv.FormatInt(fileInfo.Size(), 10)+"-")
|
alreadyDownloaded = int(fileInfo.Size())
|
||||||
|
req.Header.Add("Range", "bytes="+strconv.Itoa(alreadyDownloaded)+"-")
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
|
@ -96,7 +98,17 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
|
||||||
|
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable {
|
||||||
|
// already downloaded
|
||||||
|
progressCh <- api.PullProgress{
|
||||||
|
Total: alreadyDownloaded,
|
||||||
|
Completed: alreadyDownloaded,
|
||||||
|
Percent: 100,
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
|
||||||
return fmt.Errorf("failed to download model: %s", resp.Status)
|
return fmt.Errorf("failed to download model: %s", resp.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -109,7 +121,8 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
|
||||||
totalSize, _ := strconv.Atoi(resp.Header.Get("Content-Length"))
|
totalSize, _ := strconv.Atoi(resp.Header.Get("Content-Length"))
|
||||||
|
|
||||||
buf := make([]byte, 1024)
|
buf := make([]byte, 1024)
|
||||||
totalBytes := 0
|
totalBytes := alreadyDownloaded
|
||||||
|
totalSize += alreadyDownloaded
|
||||||
|
|
||||||
for {
|
for {
|
||||||
n, err := resp.Body.Read(buf)
|
n, err := resp.Body.Read(buf)
|
||||||
|
|
Loading…
Add table
Reference in a new issue