diff --git a/server/download.go b/server/download.go index fa559f0c..be3eda7c 100644 --- a/server/download.go +++ b/server/download.go @@ -134,7 +134,6 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error { defer blobDownloadManager.Delete(b.Digest) - ctx, b.CancelFunc = context.WithCancel(ctx) file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0644) @@ -170,7 +169,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis } } - return errors.New("max retries exceeded") + return errMaxRetriesExceeded }) } @@ -308,6 +307,8 @@ type downloadOpts struct { const maxRetries = 3 +var errMaxRetriesExceeded = errors.New("max retries exceeded") + // downloadBlob downloads a blob from the registry and stores it in the blobs directory func downloadBlob(ctx context.Context, opts downloadOpts) error { fp, err := GetBlobsPath(opts.digest) diff --git a/server/images.go b/server/images.go index e249f8f9..6f72f0cf 100644 --- a/server/images.go +++ b/server/images.go @@ -981,46 +981,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu layers = append(layers, &manifest.Config) for _, layer := range layers { - exists, err := checkBlobExistence(ctx, mp, layer.Digest, regOpts) - if err != nil { - return err - } - - if exists { - fn(api.ProgressResponse{ - Status: "using existing layer", - Digest: layer.Digest, - Total: layer.Size, - Completed: layer.Size, - }) - log.Printf("Layer %s already exists", layer.Digest) - continue - } - - fn(api.ProgressResponse{ - Status: "starting upload", - Digest: layer.Digest, - Total: layer.Size, - }) - - location, chunkSize, err := startUpload(ctx, mp, layer, regOpts) - if err != nil { - log.Printf("couldn't start upload: %v", err) - return err - } - - if strings.HasPrefix(filepath.Base(location.Path), "sha256:") { - layer.Digest = filepath.Base(location.Path) - fn(api.ProgressResponse{ - Status: "using existing layer", - Digest: layer.Digest, - Total: layer.Size, - Completed: layer.Size, - }) - continue - } - - if err := uploadBlob(ctx, location, layer, chunkSize, regOpts, fn); err != nil { + if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil { log.Printf("error uploading blob: %v", err) return err } @@ -1218,24 +1179,7 @@ func GetSHA256Digest(r io.Reader) (string, int64) { return fmt.Sprintf("sha256:%x", h.Sum(nil)), n } -// Function to check if a blob already exists in the Docker registry -func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) { - requestURL := mp.BaseURL() - requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", digest) - - resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, regOpts) - if err != nil { - log.Printf("couldn't check for blob: %v", err) - return false, err - } - defer resp.Body.Close() - - // Check for success: If the blob exists, the Docker registry will respond with a 200 OK - return resp.StatusCode < http.StatusBadRequest, nil -} - func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) { - var status string for try := 0; try < maxRetries; try++ { resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts) if err != nil { @@ -1243,8 +1187,6 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR return nil, err } - status = resp.Status - switch { case resp.StatusCode == http.StatusUnauthorized: auth := resp.Header.Get("www-authenticate") @@ -1270,7 +1212,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR } } - return nil, fmt.Errorf("max retry exceeded: %v", status) + return nil, errMaxRetriesExceeded } func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) { diff --git a/server/routes.go b/server/routes.go index 5c52dbfe..9dc3732b 100644 --- a/server/routes.go +++ b/server/routes.go @@ -365,7 +365,9 @@ func PushModelHandler(c *gin.Context) { Insecure: req.Insecure, } - ctx := context.Background() + ctx, cancel := context.WithCancel(c.Request.Context()) + defer cancel() + if err := PushModel(ctx, req.Name, regOpts, fn); err != nil { ch <- gin.H{"error": err.Error()} } diff --git a/server/upload.go b/server/upload.go index 8f655337..ddf4321d 100644 --- a/server/upload.go +++ b/server/upload.go @@ -9,211 +9,344 @@ import ( "net/http" "net/url" "os" - "strconv" "sync" + "sync/atomic" + "time" "github.com/jmorganca/ollama/api" + "github.com/jmorganca/ollama/format" + "golang.org/x/sync/errgroup" ) -const ( - redirectChunkSize int64 = 1024 * 1024 * 1024 - regularChunkSize int64 = 95 * 1024 * 1024 -) +var blobUploadManager sync.Map -func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (*url.URL, int64, error) { - requestURL := mp.BaseURL() - requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/") - if layer.From != "" { - values := requestURL.Query() - values.Add("mount", layer.Digest) - values.Add("from", layer.From) - requestURL.RawQuery = values.Encode() - } +type blobUpload struct { + *Layer - resp, err := makeRequestWithRetry(ctx, "POST", requestURL, nil, nil, regOpts) - if err != nil { - log.Printf("couldn't start upload: %v", err) - return nil, 0, err - } - defer resp.Body.Close() + Total int64 + Completed atomic.Int64 - location := resp.Header.Get("Docker-Upload-Location") - chunkSize := redirectChunkSize - if location == "" { - location = resp.Header.Get("Location") - chunkSize = regularChunkSize - } + Parts []blobUploadPart - locationURL, err := url.Parse(location) - if err != nil { - return nil, 0, err - } + nextURL chan *url.URL - return locationURL, chunkSize, nil + context.CancelFunc + + done bool + err error + references atomic.Int32 } -func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSize int64, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { - // TODO allow resumability - // TODO allow canceling uploads via DELETE +type blobUploadPart struct { + // N is the part number + N int + Offset int64 + Size int64 +} - fp, err := GetBlobsPath(layer.Digest) +const ( + numUploadParts = 64 + minUploadPartSize int64 = 95 * 1000 * 1000 + maxUploadPartSize int64 = 1000 * 1000 * 1000 +) + +func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error { + p, err := GetBlobsPath(b.Digest) if err != nil { return err } - f, err := os.Open(fp) + if b.From != "" { + values := requestURL.Query() + values.Add("mount", b.Digest) + values.Add("from", b.From) + requestURL.RawQuery = values.Encode() + } + + resp, err := makeRequestWithRetry(ctx, "POST", requestURL, nil, nil, opts) + if err != nil { + return err + } + defer resp.Body.Close() + + location := resp.Header.Get("Docker-Upload-Location") + if location == "" { + location = resp.Header.Get("Location") + } + + fi, err := os.Stat(p) + if err != nil { + return err + } + + b.Total = fi.Size() + + var size = b.Total / numUploadParts + switch { + case size < minUploadPartSize: + size = minUploadPartSize + case size > maxUploadPartSize: + size = maxUploadPartSize + } + + var offset int64 + for offset < fi.Size() { + if offset+size > fi.Size() { + size = fi.Size() - offset + } + + b.Parts = append(b.Parts, blobUploadPart{N: len(b.Parts), Offset: offset, Size: size}) + offset += size + } + + log.Printf("uploading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(size)) + + requestURL, err = url.Parse(location) + if err != nil { + return err + } + + b.nextURL = make(chan *url.URL, 1) + b.nextURL <- requestURL + return nil +} + +func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) { + b.err = b.run(ctx, opts) +} + +func (b *blobUpload) run(ctx context.Context, opts *RegistryOptions) error { + defer blobUploadManager.Delete(b.Digest) + ctx, b.CancelFunc = context.WithCancel(ctx) + + p, err := GetBlobsPath(b.Digest) + if err != nil { + return err + } + + f, err := os.Open(p) if err != nil { return err } defer f.Close() - pw := ProgressWriter{ - status: fmt.Sprintf("uploading %s", layer.Digest), - digest: layer.Digest, - total: layer.Size, - fn: fn, + g, inner := errgroup.WithContext(ctx) + g.SetLimit(numUploadParts) + for i := range b.Parts { + part := &b.Parts[i] + requestURL := <-b.nextURL + g.Go(func() error { + for try := 0; try < maxRetries; try++ { + r := io.NewSectionReader(f, part.Offset, part.Size) + err := b.uploadChunk(inner, http.MethodPatch, requestURL, r, part, opts) + switch { + case errors.Is(err, context.Canceled): + return err + case errors.Is(err, errMaxRetriesExceeded): + return err + case err != nil: + log.Printf("%s part %d attempt %d failed: %v, retrying", b.Digest[7:19], part.N, try, err) + continue + } + + return nil + } + + return errMaxRetriesExceeded + }) } - for offset := int64(0); offset < layer.Size; { - chunk := layer.Size - offset - if chunk > chunkSize { - chunk = chunkSize - } - - resp, err := uploadBlobChunk(ctx, http.MethodPatch, requestURL, f, offset, chunk, regOpts, &pw) - if err != nil { - fn(api.ProgressResponse{ - Status: fmt.Sprintf("error uploading chunk: %v", err), - Digest: layer.Digest, - Total: layer.Size, - Completed: offset, - }) - - return err - } - - offset += chunk - location := resp.Header.Get("Docker-Upload-Location") - if location == "" { - location = resp.Header.Get("Location") - } - - requestURL, err = url.Parse(location) - if err != nil { - return err - } + if err := g.Wait(); err != nil { + return err } + requestURL := <-b.nextURL + values := requestURL.Query() - values.Add("digest", layer.Digest) + values.Add("digest", b.Digest) requestURL.RawQuery = values.Encode() headers := make(http.Header) headers.Set("Content-Type", "application/octet-stream") headers.Set("Content-Length", "0") - // finish the upload - resp, err := makeRequest(ctx, "PUT", requestURL, headers, nil, regOpts) + resp, err := makeRequest(ctx, "PUT", requestURL, headers, nil, opts) if err != nil { - log.Printf("couldn't finish upload: %v", err) return err } defer resp.Body.Close() - if resp.StatusCode >= http.StatusBadRequest { - body, _ := io.ReadAll(resp.Body) - return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body)) - } + b.done = true return nil } -func uploadBlobChunk(ctx context.Context, method string, requestURL *url.URL, r io.ReaderAt, offset, limit int64, opts *RegistryOptions, pw *ProgressWriter) (*http.Response, error) { - sectionReader := io.NewSectionReader(r, offset, limit) - +func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL *url.URL, rs io.ReadSeeker, part *blobUploadPart, opts *RegistryOptions) error { headers := make(http.Header) headers.Set("Content-Type", "application/octet-stream") - headers.Set("Content-Length", strconv.Itoa(int(limit))) + headers.Set("Content-Length", fmt.Sprintf("%d", part.Size)) headers.Set("X-Redirect-Uploads", "1") if method == http.MethodPatch { - headers.Set("Content-Range", fmt.Sprintf("%d-%d", offset, offset+sectionReader.Size()-1)) + headers.Set("Content-Range", fmt.Sprintf("%d-%d", part.Offset, part.Offset+part.Size-1)) } - for try := 0; try < maxRetries; try++ { - resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sectionReader, pw), opts) - if err != nil && !errors.Is(err, io.EOF) { - return nil, err + buw := blobUploadWriter{blobUpload: b} + resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(rs, &buw), opts) + if err != nil { + return err + } + defer resp.Body.Close() + + location := resp.Header.Get("Docker-Upload-Location") + if location == "" { + location = resp.Header.Get("Location") + } + + nextURL, err := url.Parse(location) + if err != nil { + return err + } + + switch { + case resp.StatusCode == http.StatusTemporaryRedirect: + b.nextURL <- nextURL + + redirectURL, err := resp.Location() + if err != nil { + return err } - defer resp.Body.Close() - switch { - case resp.StatusCode == http.StatusTemporaryRedirect: - location, err := resp.Location() - if err != nil { - return nil, err - } - - pw.completed = offset - if _, err := uploadBlobChunk(ctx, http.MethodPut, location, r, offset, limit, nil, pw); err != nil { - // retry - log.Printf("retrying redirected upload: %v", err) + for try := 0; try < maxRetries; try++ { + rs.Seek(0, io.SeekStart) + b.Completed.Add(-buw.written) + err := b.uploadChunk(ctx, http.MethodPut, redirectURL, rs, part, nil) + switch { + case errors.Is(err, context.Canceled): + return err + case errors.Is(err, errMaxRetriesExceeded): + return err + case err != nil: + log.Printf("%s part %d attempt %d failed: %v, retrying", b.Digest[7:19], part.N, try, err) continue } - return resp, nil - case resp.StatusCode == http.StatusUnauthorized: - auth := resp.Header.Get("www-authenticate") - authRedir := ParseAuthRedirectString(auth) - token, err := getAuthToken(ctx, authRedir) - if err != nil { - return nil, err - } - - opts.Token = token - - pw.completed = offset - sectionReader = io.NewSectionReader(r, offset, limit) - continue - case resp.StatusCode >= http.StatusBadRequest: - body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body) + return nil } - return resp, nil + return errMaxRetriesExceeded + + case resp.StatusCode == http.StatusUnauthorized: + auth := resp.Header.Get("www-authenticate") + authRedir := ParseAuthRedirectString(auth) + token, err := getAuthToken(ctx, authRedir) + if err != nil { + return err + } + + opts.Token = token + fallthrough + case resp.StatusCode >= http.StatusBadRequest: + body, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + + rs.Seek(0, io.SeekStart) + b.Completed.Add(-buw.written) + return fmt.Errorf("http status %d %s: %s", resp.StatusCode, resp.Status, body) } - return nil, fmt.Errorf("max retries exceeded") + if method == http.MethodPatch { + b.nextURL <- nextURL + } + + return nil } -type ProgressWriter struct { - status string - digest string - bucket int64 - completed int64 - total int64 - fn func(api.ProgressResponse) - mu sync.Mutex +func (b *blobUpload) acquire() { + b.references.Add(1) } -func (pw *ProgressWriter) Write(b []byte) (int, error) { - pw.mu.Lock() - defer pw.mu.Unlock() +func (b *blobUpload) release() { + if b.references.Add(-1) == 0 { + b.CancelFunc() + } +} - n := len(b) - pw.bucket += int64(n) +func (b *blobUpload) Wait(ctx context.Context, fn func(api.ProgressResponse)) error { + b.acquire() + defer b.release() - // throttle status updates to not spam the client - if pw.bucket >= 1024*1024 || pw.completed+pw.bucket >= pw.total { - pw.completed += pw.bucket - pw.fn(api.ProgressResponse{ - Status: pw.status, - Digest: pw.digest, - Total: pw.total, - Completed: pw.completed, + ticker := time.NewTicker(60 * time.Millisecond) + for { + select { + case <-ticker.C: + case <-ctx.Done(): + return ctx.Err() + } + + fn(api.ProgressResponse{ + Status: fmt.Sprintf("uploading %s", b.Digest), + Digest: b.Digest, + Total: b.Total, + Completed: b.Completed.Load(), }) - pw.bucket = 0 + if b.done || b.err != nil { + return b.err + } } +} +type blobUploadWriter struct { + written int64 + *blobUpload +} + +func (b *blobUploadWriter) Write(p []byte) (n int, err error) { + n = len(p) + b.written += int64(n) + b.Completed.Add(int64(n)) return n, nil } + +func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *RegistryOptions, fn func(api.ProgressResponse)) error { + requestURL := mp.BaseURL() + requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest) + + resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, opts) + if err != nil { + return err + } + defer resp.Body.Close() + + switch resp.StatusCode { + case http.StatusNotFound: + case http.StatusOK: + fn(api.ProgressResponse{ + Status: fmt.Sprintf("uploading %s", layer.Digest), + Digest: layer.Digest, + Total: layer.Size, + Completed: layer.Size, + }) + + return nil + default: + return fmt.Errorf("unexpected status code %d", resp.StatusCode) + } + + data, ok := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer}) + upload := data.(*blobUpload) + if !ok { + requestURL := mp.BaseURL() + requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/") + if err := upload.Prepare(ctx, requestURL, opts); err != nil { + blobUploadManager.Delete(layer.Digest) + return err + } + + go upload.Run(context.Background(), opts) + } + + return upload.Wait(ctx, fn) +}