calculate and verify md5 checksum

This commit is contained in:
Michael Yang 2023-10-27 10:11:28 -07:00
parent 186f685224
commit 115fc56eb7

View file

@ -2,13 +2,16 @@ package server
import ( import (
"context" "context"
"crypto/md5"
"errors" "errors"
"fmt" "fmt"
"hash"
"io" "io"
"log" "log"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -42,6 +45,7 @@ type blobUploadPart struct {
N int N int
Offset int64 Offset int64
Size int64 Size int64
hash.Hash
} }
const ( const (
@ -96,7 +100,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *Reg
} }
// set part.N to the current number of parts // set part.N to the current number of parts
b.Parts = append(b.Parts, blobUploadPart{N: len(b.Parts), Offset: offset, Size: size}) b.Parts = append(b.Parts, blobUploadPart{N: len(b.Parts), Offset: offset, Size: size, Hash: md5.New()})
offset += size offset += size
} }
@ -167,8 +171,16 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
requestURL := <-b.nextURL requestURL := <-b.nextURL
var sb strings.Builder
for _, part := range b.Parts {
sb.Write(part.Sum(nil))
}
md5sum := md5.Sum([]byte(sb.String()))
values := requestURL.Query() values := requestURL.Query()
values.Add("digest", b.Digest) values.Add("digest", b.Digest)
values.Add("etag", fmt.Sprintf("%x-%d", md5sum, len(b.Parts)))
requestURL.RawQuery = values.Encode() requestURL.RawQuery = values.Encode()
headers := make(http.Header) headers := make(http.Header)
@ -196,7 +208,7 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
} }
buw := blobUploadWriter{blobUpload: b} buw := blobUploadWriter{blobUpload: b}
resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(rs, &buw), opts) resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(rs, io.MultiWriter(&buw, part.Hash)), opts)
if err != nil { if err != nil {
return err return err
} }
@ -225,6 +237,7 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
rs.Seek(0, io.SeekStart) rs.Seek(0, io.SeekStart)
b.Completed.Add(-buw.written) b.Completed.Add(-buw.written)
buw.written = 0 buw.written = 0
part.Hash = md5.New()
err := b.uploadChunk(ctx, http.MethodPut, redirectURL, rs, part, nil) err := b.uploadChunk(ctx, http.MethodPut, redirectURL, rs, part, nil)
switch { switch {
case errors.Is(err, context.Canceled): case errors.Is(err, context.Canceled):