pr comments
This commit is contained in:
parent
907e6c56b3
commit
8228d166ce
1 changed files with 46 additions and 49 deletions
|
@ -55,11 +55,7 @@ func downloadBlob(ctx context.Context, mp ModelPath, digest string, regOpts *Reg
|
||||||
// this is another client requesting the server to download the same blob concurrently
|
// this is another client requesting the server to download the same blob concurrently
|
||||||
return monitorDownload(ctx, mp, regOpts, fileDownload, fn)
|
return monitorDownload(ctx, mp, regOpts, fileDownload, fn)
|
||||||
}
|
}
|
||||||
resp, err := requestDownload(ctx, mp, regOpts, fileDownload)
|
return doDownload(ctx, mp, regOpts, fileDownload, fn)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return doDownload(ctx, fileDownload, resp, fn)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var downloadMu sync.Mutex // mutex to check to resume a download while monitoring
|
var downloadMu sync.Mutex // mutex to check to resume a download while monitoring
|
||||||
|
@ -68,34 +64,28 @@ var downloadMu sync.Mutex // mutex to check to resume a download while monitorin
|
||||||
func monitorDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload, fn func(api.ProgressResponse)) error {
|
func monitorDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload, fn func(api.ProgressResponse)) error {
|
||||||
tick := time.NewTicker(time.Second)
|
tick := time.NewTicker(time.Second)
|
||||||
for range tick.C {
|
for range tick.C {
|
||||||
|
done, resume, err := func() (bool, bool, error) {
|
||||||
downloadMu.Lock()
|
downloadMu.Lock()
|
||||||
|
defer downloadMu.Unlock()
|
||||||
val, downloading := inProgress.Load(f.Digest)
|
val, downloading := inProgress.Load(f.Digest)
|
||||||
if !downloading {
|
if !downloading {
|
||||||
// check once again if the download is complete
|
// check once again if the download is complete
|
||||||
if fi, _ := os.Stat(f.FilePath); fi != nil {
|
if fi, _ := os.Stat(f.FilePath); fi != nil {
|
||||||
downloadMu.Unlock()
|
// successful download while monitoring
|
||||||
// successfull download while monitoring
|
|
||||||
fn(api.ProgressResponse{
|
fn(api.ProgressResponse{
|
||||||
Digest: f.Digest,
|
Digest: f.Digest,
|
||||||
Total: int(fi.Size()),
|
Total: int(fi.Size()),
|
||||||
Completed: int(fi.Size()),
|
Completed: int(fi.Size()),
|
||||||
})
|
})
|
||||||
return nil
|
return true, false, nil
|
||||||
}
|
}
|
||||||
// resume the download
|
// resume the download
|
||||||
resp, err := requestDownload(ctx, mp, regOpts, f)
|
inProgress.Store(f.Digest, f) // store the file download again to claim the resume
|
||||||
if err != nil {
|
return false, true, nil
|
||||||
downloadMu.Unlock()
|
|
||||||
return fmt.Errorf("resume: %w", err)
|
|
||||||
}
|
}
|
||||||
inProgress.Store(f.Digest, f)
|
|
||||||
downloadMu.Unlock()
|
|
||||||
return doDownload(ctx, f, resp, fn)
|
|
||||||
}
|
|
||||||
downloadMu.Unlock()
|
|
||||||
f, ok := val.(*FileDownload)
|
f, ok := val.(*FileDownload)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("invalid type for in progress download: %T", val)
|
return false, false, fmt.Errorf("invalid type for in progress download: %T", val)
|
||||||
}
|
}
|
||||||
fn(api.ProgressResponse{
|
fn(api.ProgressResponse{
|
||||||
Status: fmt.Sprintf("downloading %s", f.Digest),
|
Status: fmt.Sprintf("downloading %s", f.Digest),
|
||||||
|
@ -103,14 +93,26 @@ func monitorDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions
|
||||||
Total: int(f.Total),
|
Total: int(f.Total),
|
||||||
Completed: int(f.Completed),
|
Completed: int(f.Completed),
|
||||||
})
|
})
|
||||||
|
return false, false, nil
|
||||||
|
}()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if done {
|
||||||
|
// done downloading
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if resume {
|
||||||
|
return doDownload(ctx, mp, regOpts, f, fn)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var chunkSize = 1024 * 1024 // 1 MiB in bytes
|
var chunkSize = 1024 * 1024 // 1 MiB in bytes
|
||||||
|
|
||||||
// requestDownload requests a blob from the registry and returns the response
|
// doDownload downloads a blob from the registry and stores it in the blobs directory
|
||||||
func requestDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload) (*http.Response, error) {
|
func doDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload, fn func(api.ProgressResponse)) error {
|
||||||
var size int64
|
var size int64
|
||||||
|
|
||||||
fi, err := os.Stat(f.FilePath + "-partial")
|
fi, err := os.Stat(f.FilePath + "-partial")
|
||||||
|
@ -118,7 +120,7 @@ func requestDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions
|
||||||
case errors.Is(err, os.ErrNotExist):
|
case errors.Is(err, os.ErrNotExist):
|
||||||
// noop, file doesn't exist so create it
|
// noop, file doesn't exist so create it
|
||||||
case err != nil:
|
case err != nil:
|
||||||
return nil, fmt.Errorf("stat: %w", err)
|
return fmt.Errorf("stat: %w", err)
|
||||||
default:
|
default:
|
||||||
size = fi.Size()
|
size = fi.Size()
|
||||||
// Ensure the size is divisible by the chunk size by removing excess bytes
|
// Ensure the size is divisible by the chunk size by removing excess bytes
|
||||||
|
@ -126,7 +128,7 @@ func requestDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions
|
||||||
|
|
||||||
err := os.Truncate(f.FilePath+"-partial", size)
|
err := os.Truncate(f.FilePath+"-partial", size)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("truncate: %w", err)
|
return fmt.Errorf("truncate: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -138,18 +140,18 @@ func requestDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions
|
||||||
resp, err := makeRequest("GET", url, headers, nil, regOpts)
|
resp, err := makeRequest("GET", url, headers, nil, regOpts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("couldn't download blob: %v", err)
|
log.Printf("couldn't download blob: %v", err)
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
// resp MUST be closed by doDownload, which should follow this function
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
|
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
return nil, fmt.Errorf("on download registry responded with code %d: %v", resp.StatusCode, string(body))
|
return fmt.Errorf("on download registry responded with code %d: %v", resp.StatusCode, string(body))
|
||||||
}
|
}
|
||||||
|
|
||||||
err = os.MkdirAll(path.Dir(f.FilePath), 0o700)
|
err = os.MkdirAll(path.Dir(f.FilePath), 0o700)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("make blobs directory: %w", err)
|
return fmt.Errorf("make blobs directory: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
|
remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
|
||||||
|
@ -157,12 +159,7 @@ func requestDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions
|
||||||
f.Total = remaining + f.Completed
|
f.Total = remaining + f.Completed
|
||||||
|
|
||||||
inProgress.Store(f.Digest, f)
|
inProgress.Store(f.Digest, f)
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// doDownload downloads a blob from the registry and stores it in the blobs directory
|
|
||||||
func doDownload(ctx context.Context, f *FileDownload, resp *http.Response, fn func(api.ProgressResponse)) error {
|
|
||||||
defer resp.Body.Close()
|
|
||||||
out, err := os.OpenFile(f.FilePath+"-partial", os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
|
out, err := os.OpenFile(f.FilePath+"-partial", os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("open file: %w", err)
|
return fmt.Errorf("open file: %w", err)
|
||||||
|
|
Loading…
Reference in a new issue