add maximum retries when pushing (#334)

This commit is contained in:
Patrick Devine 2023-08-11 15:41:55 -07:00 committed by GitHub
parent 1556162c90
commit d9cf18e28d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 35 additions and 21 deletions

View file

@ -2,6 +2,7 @@ package server
import ( import (
"bytes" "bytes"
"context"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
@ -50,7 +51,7 @@ func (r AuthRedirect) URL() (string, error) {
return fmt.Sprintf("%s?service=%s&scope=%s&ts=%d&nonce=%s", r.Realm, r.Service, r.Scope, time.Now().Unix(), nonce), nil return fmt.Sprintf("%s?service=%s&scope=%s&ts=%d&nonce=%s", r.Realm, r.Service, r.Scope, time.Now().Unix(), nonce), nil
} }
func getAuthToken(redirData AuthRedirect, regOpts *RegistryOptions) (string, error) { func getAuthToken(ctx context.Context, redirData AuthRedirect, regOpts *RegistryOptions) (string, error) {
url, err := redirData.URL() url, err := redirData.URL()
if err != nil { if err != nil {
return "", err return "", err
@ -92,7 +93,7 @@ func getAuthToken(redirData AuthRedirect, regOpts *RegistryOptions) (string, err
"Authorization": sig, "Authorization": sig,
} }
resp, err := makeRequest("GET", url, headers, nil, regOpts) resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts)
if err != nil { if err != nil {
log.Printf("couldn't get token: %q", err) log.Printf("couldn't get token: %q", err)
} }

View file

@ -137,7 +137,7 @@ func doDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *
"Range": fmt.Sprintf("bytes=%d-", size), "Range": fmt.Sprintf("bytes=%d-", size),
} }
resp, err := makeRequest("GET", url, headers, nil, regOpts) resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts)
if err != nil { if err != nil {
log.Printf("couldn't download blob: %v", err) log.Printf("couldn't download blob: %v", err)
return err return err

View file

@ -24,6 +24,8 @@ import (
"github.com/jmorganca/ollama/vector" "github.com/jmorganca/ollama/vector"
) )
const MaxRetries = 3
type RegistryOptions struct { type RegistryOptions struct {
Insecure bool Insecure bool
Username string Username string
@ -856,7 +858,7 @@ func DeleteModel(name string) error {
return nil return nil
} }
func PushModel(name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name) mp := ParseModelPath(name)
fn(api.ProgressResponse{Status: "retrieving manifest"}) fn(api.ProgressResponse{Status: "retrieving manifest"})
@ -872,7 +874,7 @@ func PushModel(name string, regOpts *RegistryOptions, fn func(api.ProgressRespon
layers = append(layers, &manifest.Config) layers = append(layers, &manifest.Config)
for _, layer := range layers { for _, layer := range layers {
exists, err := checkBlobExistence(mp, layer.Digest, regOpts) exists, err := checkBlobExistence(ctx, mp, layer.Digest, regOpts)
if err != nil { if err != nil {
return err return err
} }
@ -894,13 +896,13 @@ func PushModel(name string, regOpts *RegistryOptions, fn func(api.ProgressRespon
Total: layer.Size, Total: layer.Size,
}) })
location, err := startUpload(mp, regOpts) location, err := startUpload(ctx, mp, regOpts)
if err != nil { if err != nil {
log.Printf("couldn't start upload: %v", err) log.Printf("couldn't start upload: %v", err)
return err return err
} }
err = uploadBlobChunked(mp, location, layer, regOpts, fn) err = uploadBlobChunked(ctx, mp, location, layer, regOpts, fn)
if err != nil { if err != nil {
log.Printf("error uploading blob: %v", err) log.Printf("error uploading blob: %v", err)
return err return err
@ -918,7 +920,7 @@ func PushModel(name string, regOpts *RegistryOptions, fn func(api.ProgressRespon
return err return err
} }
resp, err := makeRequest("PUT", url, headers, bytes.NewReader(manifestJSON), regOpts) resp, err := makeRequest(ctx, "PUT", url, headers, bytes.NewReader(manifestJSON), regOpts)
if err != nil { if err != nil {
return err return err
} }
@ -940,7 +942,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
fn(api.ProgressResponse{Status: "pulling manifest"}) fn(api.ProgressResponse{Status: "pulling manifest"})
manifest, err := pullModelManifest(mp, regOpts) manifest, err := pullModelManifest(ctx, mp, regOpts)
if err != nil { if err != nil {
return fmt.Errorf("pull model manifest: %s", err) return fmt.Errorf("pull model manifest: %s", err)
} }
@ -996,13 +998,13 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
return nil return nil
} }
func pullModelManifest(mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) { func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag) url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
headers := map[string]string{ headers := map[string]string{
"Accept": "application/vnd.docker.distribution.manifest.v2+json", "Accept": "application/vnd.docker.distribution.manifest.v2+json",
} }
resp, err := makeRequest("GET", url, headers, nil, regOpts) resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts)
if err != nil { if err != nil {
log.Printf("couldn't get manifest: %v", err) log.Printf("couldn't get manifest: %v", err)
return nil, err return nil, err
@ -1061,10 +1063,10 @@ func GetSHA256Digest(r io.Reader) (string, int) {
return fmt.Sprintf("sha256:%x", h.Sum(nil)), int(n) return fmt.Sprintf("sha256:%x", h.Sum(nil)), int(n)
} }
func startUpload(mp ModelPath, regOpts *RegistryOptions) (string, error) { func startUpload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (string, error) {
url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", mp.Registry, mp.GetNamespaceRepository()) url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", mp.Registry, mp.GetNamespaceRepository())
resp, err := makeRequest("POST", url, nil, nil, regOpts) resp, err := makeRequest(ctx, "POST", url, nil, nil, regOpts)
if err != nil { if err != nil {
log.Printf("couldn't start upload: %v", err) log.Printf("couldn't start upload: %v", err)
return "", err return "", err
@ -1087,10 +1089,10 @@ func startUpload(mp ModelPath, regOpts *RegistryOptions) (string, error) {
} }
// Function to check if a blob already exists in the Docker registry // Function to check if a blob already exists in the Docker registry
func checkBlobExistence(mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) { func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) {
url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), digest) url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), digest)
resp, err := makeRequest("HEAD", url, nil, nil, regOpts) resp, err := makeRequest(ctx, "HEAD", url, nil, nil, regOpts)
if err != nil { if err != nil {
log.Printf("couldn't check for blob: %v", err) log.Printf("couldn't check for blob: %v", err)
return false, err return false, err
@ -1101,7 +1103,7 @@ func checkBlobExistence(mp ModelPath, digest string, regOpts *RegistryOptions) (
return resp.StatusCode == http.StatusOK, nil return resp.StatusCode == http.StatusOK, nil
} }
func uploadBlobChunked(mp ModelPath, url string, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { func uploadBlobChunked(ctx context.Context, mp ModelPath, url string, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
// TODO allow resumability // TODO allow resumability
// TODO allow canceling uploads via DELETE // TODO allow canceling uploads via DELETE
// TODO allow cross repo blob mount // TODO allow cross repo blob mount
@ -1158,7 +1160,7 @@ func uploadBlobChunked(mp ModelPath, url string, layer *Layer, regOpts *Registry
headers["Content-Length"] = strconv.Itoa(int(layer.Size)) headers["Content-Length"] = strconv.Itoa(int(layer.Size))
// finish the upload // finish the upload
resp, err := makeRequest("PUT", url, headers, r, regOpts) resp, err := makeRequest(ctx, "PUT", url, headers, r, regOpts)
if err != nil { if err != nil {
log.Printf("couldn't finish upload: %v", err) log.Printf("couldn't finish upload: %v", err)
return err return err
@ -1172,7 +1174,16 @@ func uploadBlobChunked(mp ModelPath, url string, layer *Layer, regOpts *Registry
return nil return nil
} }
func makeRequest(method, url string, headers map[string]string, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) { func makeRequest(ctx context.Context, method, url string, headers map[string]string, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
retryCtx := ctx.Value("retries")
var retries int
var ok bool
if retries, ok = retryCtx.(int); ok {
if retries > MaxRetries {
return nil, fmt.Errorf("Maximum retries hit; are you sure you have access to this resource?")
}
}
if !strings.HasPrefix(url, "http") { if !strings.HasPrefix(url, "http") {
if regOpts.Insecure { if regOpts.Insecure {
url = "http://" + url url = "http://" + url
@ -1225,13 +1236,14 @@ func makeRequest(method, url string, headers map[string]string, body io.Reader,
if resp.StatusCode == http.StatusUnauthorized { if resp.StatusCode == http.StatusUnauthorized {
auth := resp.Header.Get("Www-Authenticate") auth := resp.Header.Get("Www-Authenticate")
authRedir := ParseAuthRedirectString(string(auth)) authRedir := ParseAuthRedirectString(string(auth))
token, err := getAuthToken(authRedir, regOpts) token, err := getAuthToken(ctx, authRedir, regOpts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
regOpts.Token = token regOpts.Token = token
bodyCopy = bytes.NewReader(buf.Bytes()) bodyCopy = bytes.NewReader(buf.Bytes())
return makeRequest(method, url, headers, bodyCopy, regOpts) ctx = context.WithValue(ctx, "retries", retries+1)
return makeRequest(ctx, method, url, headers, bodyCopy, regOpts)
} }
return resp, nil return resp, nil

View file

@ -277,7 +277,8 @@ func PushModelHandler(c *gin.Context) {
Password: req.Password, Password: req.Password,
} }
if err := PushModel(req.Name, regOpts, fn); err != nil { ctx := context.Background()
if err := PushModel(ctx, req.Name, regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()