diff --git a/server/auth.go b/server/auth.go index f24c46af..edf7865b 100644 --- a/server/auth.go +++ b/server/auth.go @@ -2,6 +2,7 @@ package server import ( "bytes" + "context" "crypto/rand" "crypto/sha256" "encoding/base64" @@ -50,7 +51,7 @@ func (r AuthRedirect) URL() (string, error) { return fmt.Sprintf("%s?service=%s&scope=%s&ts=%d&nonce=%s", r.Realm, r.Service, r.Scope, time.Now().Unix(), nonce), nil } -func getAuthToken(redirData AuthRedirect, regOpts *RegistryOptions) (string, error) { +func getAuthToken(ctx context.Context, redirData AuthRedirect, regOpts *RegistryOptions) (string, error) { url, err := redirData.URL() if err != nil { return "", err @@ -92,7 +93,7 @@ func getAuthToken(redirData AuthRedirect, regOpts *RegistryOptions) (string, err "Authorization": sig, } - resp, err := makeRequest("GET", url, headers, nil, regOpts) + resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts) if err != nil { log.Printf("couldn't get token: %q", err) } diff --git a/server/download.go b/server/download.go index 7aad599d..932e29a0 100644 --- a/server/download.go +++ b/server/download.go @@ -137,7 +137,7 @@ func doDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f * "Range": fmt.Sprintf("bytes=%d-", size), } - resp, err := makeRequest("GET", url, headers, nil, regOpts) + resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts) if err != nil { log.Printf("couldn't download blob: %v", err) return err diff --git a/server/images.go b/server/images.go index a08cb8fb..2177fe6d 100644 --- a/server/images.go +++ b/server/images.go @@ -24,6 +24,8 @@ import ( "github.com/jmorganca/ollama/vector" ) +const MaxRetries = 3 + type RegistryOptions struct { Insecure bool Username string @@ -856,7 +858,7 @@ func DeleteModel(name string) error { return nil } -func PushModel(name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { +func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { mp := ParseModelPath(name) fn(api.ProgressResponse{Status: "retrieving manifest"}) @@ -872,7 +874,7 @@ func PushModel(name string, regOpts *RegistryOptions, fn func(api.ProgressRespon layers = append(layers, &manifest.Config) for _, layer := range layers { - exists, err := checkBlobExistence(mp, layer.Digest, regOpts) + exists, err := checkBlobExistence(ctx, mp, layer.Digest, regOpts) if err != nil { return err } @@ -894,13 +896,13 @@ func PushModel(name string, regOpts *RegistryOptions, fn func(api.ProgressRespon Total: layer.Size, }) - location, err := startUpload(mp, regOpts) + location, err := startUpload(ctx, mp, regOpts) if err != nil { log.Printf("couldn't start upload: %v", err) return err } - err = uploadBlobChunked(mp, location, layer, regOpts, fn) + err = uploadBlobChunked(ctx, mp, location, layer, regOpts, fn) if err != nil { log.Printf("error uploading blob: %v", err) return err @@ -918,7 +920,7 @@ func PushModel(name string, regOpts *RegistryOptions, fn func(api.ProgressRespon return err } - resp, err := makeRequest("PUT", url, headers, bytes.NewReader(manifestJSON), regOpts) + resp, err := makeRequest(ctx, "PUT", url, headers, bytes.NewReader(manifestJSON), regOpts) if err != nil { return err } @@ -940,7 +942,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu fn(api.ProgressResponse{Status: "pulling manifest"}) - manifest, err := pullModelManifest(mp, regOpts) + manifest, err := pullModelManifest(ctx, mp, regOpts) if err != nil { return fmt.Errorf("pull model manifest: %s", err) } @@ -996,13 +998,13 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu return nil } -func pullModelManifest(mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) { +func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) { url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag) headers := map[string]string{ "Accept": "application/vnd.docker.distribution.manifest.v2+json", } - resp, err := makeRequest("GET", url, headers, nil, regOpts) + resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts) if err != nil { log.Printf("couldn't get manifest: %v", err) return nil, err @@ -1061,10 +1063,10 @@ func GetSHA256Digest(r io.Reader) (string, int) { return fmt.Sprintf("sha256:%x", h.Sum(nil)), int(n) } -func startUpload(mp ModelPath, regOpts *RegistryOptions) (string, error) { +func startUpload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (string, error) { url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", mp.Registry, mp.GetNamespaceRepository()) - resp, err := makeRequest("POST", url, nil, nil, regOpts) + resp, err := makeRequest(ctx, "POST", url, nil, nil, regOpts) if err != nil { log.Printf("couldn't start upload: %v", err) return "", err @@ -1087,10 +1089,10 @@ func startUpload(mp ModelPath, regOpts *RegistryOptions) (string, error) { } // Function to check if a blob already exists in the Docker registry -func checkBlobExistence(mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) { +func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) { url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), digest) - resp, err := makeRequest("HEAD", url, nil, nil, regOpts) + resp, err := makeRequest(ctx, "HEAD", url, nil, nil, regOpts) if err != nil { log.Printf("couldn't check for blob: %v", err) return false, err @@ -1101,7 +1103,7 @@ func checkBlobExistence(mp ModelPath, digest string, regOpts *RegistryOptions) ( return resp.StatusCode == http.StatusOK, nil } -func uploadBlobChunked(mp ModelPath, url string, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { +func uploadBlobChunked(ctx context.Context, mp ModelPath, url string, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { // TODO allow resumability // TODO allow canceling uploads via DELETE // TODO allow cross repo blob mount @@ -1158,7 +1160,7 @@ func uploadBlobChunked(mp ModelPath, url string, layer *Layer, regOpts *Registry headers["Content-Length"] = strconv.Itoa(int(layer.Size)) // finish the upload - resp, err := makeRequest("PUT", url, headers, r, regOpts) + resp, err := makeRequest(ctx, "PUT", url, headers, r, regOpts) if err != nil { log.Printf("couldn't finish upload: %v", err) return err @@ -1172,7 +1174,16 @@ func uploadBlobChunked(mp ModelPath, url string, layer *Layer, regOpts *Registry return nil } -func makeRequest(method, url string, headers map[string]string, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) { +func makeRequest(ctx context.Context, method, url string, headers map[string]string, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) { + retryCtx := ctx.Value("retries") + var retries int + var ok bool + if retries, ok = retryCtx.(int); ok { + if retries > MaxRetries { + return nil, fmt.Errorf("Maximum retries hit; are you sure you have access to this resource?") + } + } + if !strings.HasPrefix(url, "http") { if regOpts.Insecure { url = "http://" + url @@ -1225,13 +1236,14 @@ func makeRequest(method, url string, headers map[string]string, body io.Reader, if resp.StatusCode == http.StatusUnauthorized { auth := resp.Header.Get("Www-Authenticate") authRedir := ParseAuthRedirectString(string(auth)) - token, err := getAuthToken(authRedir, regOpts) + token, err := getAuthToken(ctx, authRedir, regOpts) if err != nil { return nil, err } regOpts.Token = token bodyCopy = bytes.NewReader(buf.Bytes()) - return makeRequest(method, url, headers, bodyCopy, regOpts) + ctx = context.WithValue(ctx, "retries", retries+1) + return makeRequest(ctx, method, url, headers, bodyCopy, regOpts) } return resp, nil diff --git a/server/routes.go b/server/routes.go index fd9214d1..3e13328b 100644 --- a/server/routes.go +++ b/server/routes.go @@ -277,7 +277,8 @@ func PushModelHandler(c *gin.Context) { Password: req.Password, } - if err := PushModel(req.Name, regOpts, fn); err != nil { + ctx := context.Background() + if err := PushModel(ctx, req.Name, regOpts, fn); err != nil { ch <- gin.H{"error": err.Error()} } }()