pr comments

This commit is contained in:
Bruce MacDonald 2023-08-01 15:34:52 -04:00
parent 907e6c56b3
commit 8228d166ce

View file

@ -55,11 +55,7 @@ func downloadBlob(ctx context.Context, mp ModelPath, digest string, regOpts *Reg
// this is another client requesting the server to download the same blob concurrently
return monitorDownload(ctx, mp, regOpts, fileDownload, fn)
}
resp, err := requestDownload(ctx, mp, regOpts, fileDownload)
if err != nil {
return err
}
return doDownload(ctx, fileDownload, resp, fn)
return doDownload(ctx, mp, regOpts, fileDownload, fn)
}
var downloadMu sync.Mutex // mutex to check to resume a download while monitoring
@ -68,34 +64,28 @@ var downloadMu sync.Mutex // mutex to check to resume a download while monitorin
func monitorDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload, fn func(api.ProgressResponse)) error {
tick := time.NewTicker(time.Second)
for range tick.C {
done, resume, err := func() (bool, bool, error) {
downloadMu.Lock()
defer downloadMu.Unlock()
val, downloading := inProgress.Load(f.Digest)
if !downloading {
// check once again if the download is complete
if fi, _ := os.Stat(f.FilePath); fi != nil {
downloadMu.Unlock()
// successfull download while monitoring
// successful download while monitoring
fn(api.ProgressResponse{
Digest: f.Digest,
Total: int(fi.Size()),
Completed: int(fi.Size()),
})
return nil
return true, false, nil
}
// resume the download
resp, err := requestDownload(ctx, mp, regOpts, f)
if err != nil {
downloadMu.Unlock()
return fmt.Errorf("resume: %w", err)
inProgress.Store(f.Digest, f) // store the file download again to claim the resume
return false, true, nil
}
inProgress.Store(f.Digest, f)
downloadMu.Unlock()
return doDownload(ctx, f, resp, fn)
}
downloadMu.Unlock()
f, ok := val.(*FileDownload)
if !ok {
return fmt.Errorf("invalid type for in progress download: %T", val)
return false, false, fmt.Errorf("invalid type for in progress download: %T", val)
}
fn(api.ProgressResponse{
Status: fmt.Sprintf("downloading %s", f.Digest),
@ -103,14 +93,26 @@ func monitorDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions
Total: int(f.Total),
Completed: int(f.Completed),
})
return false, false, nil
}()
if err != nil {
return err
}
if done {
// done downloading
return nil
}
if resume {
return doDownload(ctx, mp, regOpts, f, fn)
}
}
return nil
}
var chunkSize = 1024 * 1024 // 1 MiB in bytes
// requestDownload requests a blob from the registry and returns the response
func requestDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload) (*http.Response, error) {
// doDownload downloads a blob from the registry and stores it in the blobs directory
func doDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload, fn func(api.ProgressResponse)) error {
var size int64
fi, err := os.Stat(f.FilePath + "-partial")
@ -118,7 +120,7 @@ func requestDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions
case errors.Is(err, os.ErrNotExist):
// noop, file doesn't exist so create it
case err != nil:
return nil, fmt.Errorf("stat: %w", err)
return fmt.Errorf("stat: %w", err)
default:
size = fi.Size()
// Ensure the size is divisible by the chunk size by removing excess bytes
@ -126,7 +128,7 @@ func requestDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions
err := os.Truncate(f.FilePath+"-partial", size)
if err != nil {
return nil, fmt.Errorf("truncate: %w", err)
return fmt.Errorf("truncate: %w", err)
}
}
@ -138,18 +140,18 @@ func requestDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions
resp, err := makeRequest("GET", url, headers, nil, regOpts)
if err != nil {
log.Printf("couldn't download blob: %v", err)
return nil, err
return err
}
// resp MUST be closed by doDownload, which should follow this function
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("on download registry responded with code %d: %v", resp.StatusCode, string(body))
return fmt.Errorf("on download registry responded with code %d: %v", resp.StatusCode, string(body))
}
err = os.MkdirAll(path.Dir(f.FilePath), 0o700)
if err != nil {
return nil, fmt.Errorf("make blobs directory: %w", err)
return fmt.Errorf("make blobs directory: %w", err)
}
remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
@ -157,12 +159,7 @@ func requestDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions
f.Total = remaining + f.Completed
inProgress.Store(f.Digest, f)
return resp, nil
}
// doDownload downloads a blob from the registry and stores it in the blobs directory
func doDownload(ctx context.Context, f *FileDownload, resp *http.Response, fn func(api.ProgressResponse)) error {
defer resp.Body.Close()
out, err := os.OpenFile(f.FilePath+"-partial", os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
if err != nil {
return fmt.Errorf("open file: %w", err)