diff --git a/server/auth.go b/server/auth.go index 80982a43..be8ac7d0 100644 --- a/server/auth.go +++ b/server/auth.go @@ -12,8 +12,10 @@ import ( "io" "log" "net/http" + "net/url" "os" "path" + "strconv" "strings" "time" @@ -43,21 +45,34 @@ func generateNonce(length int) (string, error) { return base64.RawURLEncoding.EncodeToString(nonce), nil } -func (r AuthRedirect) URL() (string, error) { +func (r AuthRedirect) URL() (*url.URL, error) { + redirectURL, err := url.Parse(r.Realm) + if err != nil { + return nil, err + } + + values := redirectURL.Query() + + values.Add("service", r.Service) + + for _, s := range strings.Split(r.Scope, " ") { + values.Add("scope", s) + } + + values.Add("ts", strconv.FormatInt(time.Now().Unix(), 10)) + nonce, err := generateNonce(16) if err != nil { - return "", err + return nil, err } - scopes := []string{} - for _, s := range strings.Split(r.Scope, " ") { - scopes = append(scopes, fmt.Sprintf("scope=%s", s)) - } - scopeStr := strings.Join(scopes, "&") - return fmt.Sprintf("%s?service=%s&%s&ts=%d&nonce=%s", r.Realm, r.Service, scopeStr, time.Now().Unix(), nonce), nil + values.Add("nonce", nonce) + + redirectURL.RawQuery = values.Encode() + return redirectURL, nil } func getAuthToken(ctx context.Context, redirData AuthRedirect, regOpts *RegistryOptions) (string, error) { - url, err := redirData.URL() + redirectURL, err := redirData.URL() if err != nil { return "", err } @@ -77,18 +92,10 @@ func getAuthToken(ctx context.Context, redirData AuthRedirect, regOpts *Registry s := SignatureData{ Method: "GET", - Path: url, + Path: redirectURL.String(), Data: nil, } - if !strings.HasPrefix(s.Path, "http") { - if regOpts.Insecure { - s.Path = "http://" + url - } else { - s.Path = "https://" + url - } - } - sig, err := s.Sign(rawKey) if err != nil { return "", err @@ -96,7 +103,7 @@ func getAuthToken(ctx context.Context, redirData AuthRedirect, regOpts *Registry headers := make(http.Header) headers.Set("Authorization", sig) - resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts) + resp, err := makeRequest(ctx, "GET", redirectURL, headers, nil, regOpts) if err != nil { log.Printf("couldn't get token: %q", err) } diff --git a/server/download.go b/server/download.go index ca3427d0..67dc46ed 100644 --- a/server/download.go +++ b/server/download.go @@ -155,12 +155,13 @@ func doDownload(ctx context.Context, opts downloadOpts, f *FileDownload) error { } } - url := fmt.Sprintf("%s/v2/%s/blobs/%s", opts.mp.Registry, opts.mp.GetNamespaceRepository(), f.Digest) + requestURL := opts.mp.BaseURL() + requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", f.Digest) headers := make(http.Header) headers.Set("Range", fmt.Sprintf("bytes=%d-", size)) - resp, err := makeRequest(ctx, "GET", url, headers, nil, opts.regOpts) + resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, opts.regOpts) if err != nil { log.Printf("couldn't download blob: %v", err) return fmt.Errorf("%w: %w", errDownload, err) diff --git a/server/images.go b/server/images.go index 4c93660d..0927fd76 100644 --- a/server/images.go +++ b/server/images.go @@ -12,6 +12,7 @@ import ( "io" "log" "net/http" + "net/url" "os" "path" "path/filepath" @@ -961,8 +962,8 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu return err } - if strings.HasPrefix(path.Base(location), "sha256:") { - layer.Digest = path.Base(location) + if strings.HasPrefix(path.Base(location.Path), "sha256:") { + layer.Digest = path.Base(location.Path) fn(api.ProgressResponse{ Status: "using existing layer", Digest: layer.Digest, @@ -979,7 +980,8 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu } fn(api.ProgressResponse{Status: "pushing manifest"}) - url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag) + requestURL := mp.BaseURL() + requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag) manifestJSON, err := json.Marshal(manifest) if err != nil { @@ -988,7 +990,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu headers := make(http.Header) headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json") - resp, err := makeRequestWithRetry(ctx, "PUT", url, headers, bytes.NewReader(manifestJSON), regOpts) + resp, err := makeRequestWithRetry(ctx, "PUT", requestURL, headers, bytes.NewReader(manifestJSON), regOpts) if err != nil { return err } @@ -1072,11 +1074,11 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu } func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) { - url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag) + requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag) headers := make(http.Header) headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json") - resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts) + resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, regOpts) if err != nil { log.Printf("couldn't get manifest: %v", err) return nil, err @@ -1137,33 +1139,38 @@ func GetSHA256Digest(r io.Reader) (string, int) { type requestContextKey string -func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (string, error) { - url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", mp.Registry, mp.GetNamespaceRepository()) +func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (*url.URL, error) { + requestURL := mp.BaseURL() + requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/") if layer.From != "" { - url = fmt.Sprintf("%s/v2/%s/blobs/uploads/?mount=%s&from=%s", mp.Registry, mp.GetNamespaceRepository(), layer.Digest, layer.From) + values := requestURL.Query() + values.Add("mount", layer.Digest) + values.Add("from", layer.From) + requestURL.RawQuery = values.Encode() } - resp, err := makeRequestWithRetry(ctx, "POST", url, nil, nil, regOpts) + resp, err := makeRequestWithRetry(ctx, "POST", requestURL, nil, nil, regOpts) if err != nil { log.Printf("couldn't start upload: %v", err) - return "", err + return nil, err } defer resp.Body.Close() // Extract UUID location from header location := resp.Header.Get("Location") if location == "" { - return "", fmt.Errorf("location header is missing in response") + return nil, fmt.Errorf("location header is missing in response") } - return location, nil + return url.Parse(location) } // Function to check if a blob already exists in the Docker registry func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) { - url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), digest) + requestURL := mp.BaseURL() + requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", digest) - resp, err := makeRequest(ctx, "HEAD", url, nil, nil, regOpts) + resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, regOpts) if err != nil { log.Printf("couldn't check for blob: %v", err) return false, err @@ -1174,7 +1181,7 @@ func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpt return resp.StatusCode == http.StatusOK, nil } -func uploadBlobChunked(ctx context.Context, mp ModelPath, url string, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { +func uploadBlobChunked(ctx context.Context, mp ModelPath, requestURL *url.URL, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { // TODO allow resumability // TODO allow canceling uploads via DELETE @@ -1204,7 +1211,7 @@ func uploadBlobChunked(ctx context.Context, mp ModelPath, url string, layer *Lay headers.Set("Content-Type", "application/octet-stream") headers.Set("Content-Length", strconv.Itoa(int(chunk))) headers.Set("Content-Range", fmt.Sprintf("%d-%d", completed, completed+sectionReader.Size()-1)) - resp, err := makeRequestWithRetry(ctx, "PATCH", url, headers, sectionReader, regOpts) + resp, err := makeRequestWithRetry(ctx, "PATCH", requestURL, headers, sectionReader, regOpts) if err != nil && !errors.Is(err, io.EOF) { fn(api.ProgressResponse{ Status: fmt.Sprintf("error uploading chunk: %v", err), @@ -1225,20 +1232,26 @@ func uploadBlobChunked(ctx context.Context, mp ModelPath, url string, layer *Lay Completed: int(completed), }) - url = resp.Header.Get("Location") + requestURL, err = url.Parse(resp.Header.Get("Location")) + if err != nil { + return err + } + if completed >= int64(layer.Size) { break } } - url = fmt.Sprintf("%s&digest=%s", url, layer.Digest) + values := requestURL.Query() + values.Add("digest", layer.Digest) + requestURL.RawQuery = values.Encode() headers := make(http.Header) headers.Set("Content-Type", "application/octet-stream") headers.Set("Content-Length", "0") // finish the upload - resp, err := makeRequest(ctx, "PUT", url, headers, nil, regOpts) + resp, err := makeRequest(ctx, "PUT", requestURL, headers, nil, regOpts) if err != nil { log.Printf("couldn't finish upload: %v", err) return err @@ -1252,10 +1265,10 @@ func uploadBlobChunked(ctx context.Context, mp ModelPath, url string, layer *Lay return nil } -func makeRequestWithRetry(ctx context.Context, method, url string, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) { +func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) { var status string for try := 0; try < MaxRetries; try++ { - resp, err := makeRequest(ctx, method, url, headers, body, regOpts) + resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts) if err != nil { log.Printf("couldn't start upload: %v", err) return nil, err @@ -1291,16 +1304,12 @@ func makeRequestWithRetry(ctx context.Context, method, url string, headers http. return nil, fmt.Errorf("max retry exceeded: %v", status) } -func makeRequest(ctx context.Context, method, url string, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) { - if !strings.HasPrefix(url, "http") { - if regOpts.Insecure { - url = "http://" + url - } else { - url = "https://" + url - } +func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) { + if requestURL.Scheme != "http" && regOpts.Insecure { + requestURL.Scheme = "http" } - req, err := http.NewRequestWithContext(ctx, method, url, body) + req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body) if err != nil { return nil, err } diff --git a/server/modelpath.go b/server/modelpath.go index 0fe67211..aa83e7a6 100644 --- a/server/modelpath.go +++ b/server/modelpath.go @@ -3,6 +3,7 @@ package server import ( "errors" "fmt" + "net/url" "os" "path/filepath" "runtime" @@ -39,13 +40,13 @@ func ParseModelPath(name string) ModelPath { Tag: DefaultTag, } - parts := strings.Split(name, "://") - if len(parts) > 1 { - mp.ProtocolScheme = parts[0] - name = parts[1] + before, after, found := strings.Cut(name, "://") + if found { + mp.ProtocolScheme = before + name = after } - parts = strings.Split(name, "/") + parts := strings.Split(name, "/") switch len(parts) { case 3: mp.Registry = parts[0] @@ -100,6 +101,13 @@ func (mp ModelPath) GetManifestPath(createDir bool) (string, error) { return path, nil } +func (mp ModelPath) BaseURL() *url.URL { + return &url.URL{ + Scheme: mp.ProtocolScheme, + Host: mp.Registry, + } +} + func GetManifestPath() (string, error) { home, err := os.UserHomeDir() if err != nil {