pr comments

This commit is contained in:
Bruce MacDonald 2023-08-01 15:34:52 -04:00
parent 907e6c56b3
commit 8228d166ce

View file

@ -55,11 +55,7 @@ func downloadBlob(ctx context.Context, mp ModelPath, digest string, regOpts *Reg
// this is another client requesting the server to download the same blob concurrently // this is another client requesting the server to download the same blob concurrently
return monitorDownload(ctx, mp, regOpts, fileDownload, fn) return monitorDownload(ctx, mp, regOpts, fileDownload, fn)
} }
resp, err := requestDownload(ctx, mp, regOpts, fileDownload) return doDownload(ctx, mp, regOpts, fileDownload, fn)
if err != nil {
return err
}
return doDownload(ctx, fileDownload, resp, fn)
} }
var downloadMu sync.Mutex // mutex to check to resume a download while monitoring var downloadMu sync.Mutex // mutex to check to resume a download while monitoring
@ -68,49 +64,55 @@ var downloadMu sync.Mutex // mutex to check to resume a download while monitorin
func monitorDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload, fn func(api.ProgressResponse)) error { func monitorDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload, fn func(api.ProgressResponse)) error {
tick := time.NewTicker(time.Second) tick := time.NewTicker(time.Second)
for range tick.C { for range tick.C {
downloadMu.Lock() done, resume, err := func() (bool, bool, error) {
val, downloading := inProgress.Load(f.Digest) downloadMu.Lock()
if !downloading { defer downloadMu.Unlock()
// check once again if the download is complete val, downloading := inProgress.Load(f.Digest)
if fi, _ := os.Stat(f.FilePath); fi != nil { if !downloading {
downloadMu.Unlock() // check once again if the download is complete
// successfull download while monitoring if fi, _ := os.Stat(f.FilePath); fi != nil {
fn(api.ProgressResponse{ // successful download while monitoring
Digest: f.Digest, fn(api.ProgressResponse{
Total: int(fi.Size()), Digest: f.Digest,
Completed: int(fi.Size()), Total: int(fi.Size()),
}) Completed: int(fi.Size()),
return nil })
return true, false, nil
}
// resume the download
inProgress.Store(f.Digest, f) // store the file download again to claim the resume
return false, true, nil
} }
// resume the download f, ok := val.(*FileDownload)
resp, err := requestDownload(ctx, mp, regOpts, f) if !ok {
if err != nil { return false, false, fmt.Errorf("invalid type for in progress download: %T", val)
downloadMu.Unlock()
return fmt.Errorf("resume: %w", err)
} }
inProgress.Store(f.Digest, f) fn(api.ProgressResponse{
downloadMu.Unlock() Status: fmt.Sprintf("downloading %s", f.Digest),
return doDownload(ctx, f, resp, fn) Digest: f.Digest,
Total: int(f.Total),
Completed: int(f.Completed),
})
return false, false, nil
}()
if err != nil {
return err
} }
downloadMu.Unlock() if done {
f, ok := val.(*FileDownload) // done downloading
if !ok { return nil
return fmt.Errorf("invalid type for in progress download: %T", val) }
if resume {
return doDownload(ctx, mp, regOpts, f, fn)
} }
fn(api.ProgressResponse{
Status: fmt.Sprintf("downloading %s", f.Digest),
Digest: f.Digest,
Total: int(f.Total),
Completed: int(f.Completed),
})
} }
return nil return nil
} }
var chunkSize = 1024 * 1024 // 1 MiB in bytes var chunkSize = 1024 * 1024 // 1 MiB in bytes
// requestDownload requests a blob from the registry and returns the response // doDownload downloads a blob from the registry and stores it in the blobs directory
func requestDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload) (*http.Response, error) { func doDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload, fn func(api.ProgressResponse)) error {
var size int64 var size int64
fi, err := os.Stat(f.FilePath + "-partial") fi, err := os.Stat(f.FilePath + "-partial")
@ -118,7 +120,7 @@ func requestDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions
case errors.Is(err, os.ErrNotExist): case errors.Is(err, os.ErrNotExist):
// noop, file doesn't exist so create it // noop, file doesn't exist so create it
case err != nil: case err != nil:
return nil, fmt.Errorf("stat: %w", err) return fmt.Errorf("stat: %w", err)
default: default:
size = fi.Size() size = fi.Size()
// Ensure the size is divisible by the chunk size by removing excess bytes // Ensure the size is divisible by the chunk size by removing excess bytes
@ -126,7 +128,7 @@ func requestDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions
err := os.Truncate(f.FilePath+"-partial", size) err := os.Truncate(f.FilePath+"-partial", size)
if err != nil { if err != nil {
return nil, fmt.Errorf("truncate: %w", err) return fmt.Errorf("truncate: %w", err)
} }
} }
@ -138,18 +140,18 @@ func requestDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions
resp, err := makeRequest("GET", url, headers, nil, regOpts) resp, err := makeRequest("GET", url, headers, nil, regOpts)
if err != nil { if err != nil {
log.Printf("couldn't download blob: %v", err) log.Printf("couldn't download blob: %v", err)
return nil, err return err
} }
// resp MUST be closed by doDownload, which should follow this function defer resp.Body.Close()
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent { if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("on download registry responded with code %d: %v", resp.StatusCode, string(body)) return fmt.Errorf("on download registry responded with code %d: %v", resp.StatusCode, string(body))
} }
err = os.MkdirAll(path.Dir(f.FilePath), 0o700) err = os.MkdirAll(path.Dir(f.FilePath), 0o700)
if err != nil { if err != nil {
return nil, fmt.Errorf("make blobs directory: %w", err) return fmt.Errorf("make blobs directory: %w", err)
} }
remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
@ -157,12 +159,7 @@ func requestDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions
f.Total = remaining + f.Completed f.Total = remaining + f.Completed
inProgress.Store(f.Digest, f) inProgress.Store(f.Digest, f)
return resp, nil
}
// doDownload downloads a blob from the registry and stores it in the blobs directory
func doDownload(ctx context.Context, f *FileDownload, resp *http.Response, fn func(api.ProgressResponse)) error {
defer resp.Body.Close()
out, err := os.OpenFile(f.FilePath+"-partial", os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) out, err := os.OpenFile(f.FilePath+"-partial", os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
if err != nil { if err != nil {
return fmt.Errorf("open file: %w", err) return fmt.Errorf("open file: %w", err)