diff --git a/pkg/middlewares/redirect/redirect.go b/pkg/middlewares/redirect/redirect.go index 0e87013c3..9224a3252 100644 --- a/pkg/middlewares/redirect/redirect.go +++ b/pkg/middlewares/redirect/redirect.go @@ -132,19 +132,13 @@ func rawURL(req *http.Request) string { uri = match[4] } - if req.TLS != nil || isXForwardedHTTPS(req) { + if req.TLS != nil { scheme = "https" } return strings.Join([]string{scheme, "://", host, port, uri}, "") } -func isXForwardedHTTPS(request *http.Request) bool { - xForwardedProto := request.Header.Get("X-Forwarded-Proto") - - return len(xForwardedProto) > 0 && xForwardedProto == "https" -} - func applyString(in string, out io.Writer, req *http.Request) error { t, err := template.New("t").Parse(in) if err != nil { diff --git a/pkg/middlewares/redirect/redirect_regex_test.go b/pkg/middlewares/redirect/redirect_regex_test.go index ca4edbb79..fd4d71640 100644 --- a/pkg/middlewares/redirect/redirect_regex_test.go +++ b/pkg/middlewares/redirect/redirect_regex_test.go @@ -19,6 +19,7 @@ func TestRedirectRegexHandler(t *testing.T) { config dynamic.RedirectRegex method string url string + headers map[string]string secured bool expectedURL string expectedStatus int @@ -104,6 +105,19 @@ func TestRedirectRegexHandler(t *testing.T) { expectedURL: "https://foo:443", expectedStatus: http.StatusFound, }, + { + desc: "HTTP to HTTPS, with X-Forwarded-Proto", + config: dynamic.RedirectRegex{ + Regex: `http://foo:80`, + Replacement: "https://foo:443", + }, + url: "http://foo:80", + headers: map[string]string{ + "X-Forwarded-Proto": "https", + }, + expectedURL: "https://foo:443", + expectedStatus: http.StatusFound, + }, { desc: "HTTPS to HTTP", config: dynamic.RedirectRegex{ @@ -171,12 +185,18 @@ func TestRedirectRegexHandler(t *testing.T) { if test.method != "" { method = test.method } - r := testhelpers.MustNewRequest(method, test.url, nil) + + req := testhelpers.MustNewRequest(method, test.url, nil) if test.secured { - r.TLS = &tls.ConnectionState{} + req.TLS = &tls.ConnectionState{} } - r.Header.Set("X-Foo", "bar") - handler.ServeHTTP(recorder, r) + + for k, v := range test.headers { + req.Header.Set(k, v) + } + + req.Header.Set("X-Foo", "bar") + handler.ServeHTTP(recorder, req) assert.Equal(t, test.expectedStatus, recorder.Code) switch test.expectedStatus { diff --git a/pkg/middlewares/redirect/redirect_scheme_test.go b/pkg/middlewares/redirect/redirect_scheme_test.go index 384780d53..b0694cd57 100644 --- a/pkg/middlewares/redirect/redirect_scheme_test.go +++ b/pkg/middlewares/redirect/redirect_scheme_test.go @@ -19,6 +19,7 @@ func TestRedirectSchemeHandler(t *testing.T) { config dynamic.RedirectScheme method string url string + headers map[string]string secured bool expectedURL string expectedStatus int @@ -39,6 +40,18 @@ func TestRedirectSchemeHandler(t *testing.T) { expectedURL: "https://foo", expectedStatus: http.StatusFound, }, + { + desc: "HTTP to HTTPS, with X-Forwarded-Proto", + config: dynamic.RedirectScheme{ + Scheme: "https", + }, + url: "http://foo", + headers: map[string]string{ + "X-Forwarded-Proto": "https", + }, + expectedURL: "https://foo", + expectedStatus: http.StatusFound, + }, { desc: "HTTP with port to HTTPS without port", config: dynamic.RedirectScheme{ @@ -197,13 +210,17 @@ func TestRedirectSchemeHandler(t *testing.T) { if test.method != "" { method = test.method } - r := httptest.NewRequest(method, test.url, nil) + req := httptest.NewRequest(method, test.url, nil) + + for k, v := range test.headers { + req.Header.Set(k, v) + } if test.secured { - r.TLS = &tls.ConnectionState{} + req.TLS = &tls.ConnectionState{} } - r.Header.Set("X-Foo", "bar") - handler.ServeHTTP(recorder, r) + req.Header.Set("X-Foo", "bar") + handler.ServeHTTP(recorder, req) assert.Equal(t, test.expectedStatus, recorder.Code) @@ -223,9 +240,9 @@ func TestRedirectSchemeHandler(t *testing.T) { if re.Match([]byte(test.url)) { match := re.FindStringSubmatch(test.url) - r.RequestURI = match[4] + req.RequestURI = match[4] - handler.ServeHTTP(recorder, r) + handler.ServeHTTP(recorder, req) assert.Equal(t, test.expectedStatus, recorder.Code) if test.expectedStatus == http.StatusMovedPermanently ||