From b0c12e24220226216cbc49ea930f2c9dcb9edadc Mon Sep 17 00:00:00 2001 From: Ludovic Fernandez Date: Wed, 13 Dec 2017 17:02:04 +0100 Subject: [PATCH] Fix: frontend redirect --- server/server.go | 75 +++++++++------ server/server_test.go | 213 +++++++++++++++++++++++++++++++++++------- 2 files changed, 223 insertions(+), 65 deletions(-) diff --git a/server/server.go b/server/server.go index 3fa94268b..33e8d2b27 100644 --- a/server/server.go +++ b/server/server.go @@ -49,6 +49,10 @@ import ( "golang.org/x/net/http2" ) +const ( + defaultRedirectRegex = `^(?:https?:\/\/)?([\w\._-]+)(?::\d+)?(.*)$` +) + var ( httpServerLogger = stdlog.New(log.WriterLevel(logrus.DebugLevel), "", 0) ) @@ -940,7 +944,7 @@ func (s *Server) loadConfig(configurations types.Configurations, globalConfigura if entryPoint.Redirect != nil { if redirectHandlers[entryPointName] != nil { n.Use(redirectHandlers[entryPointName]) - } else if handler, err := s.loadEntryPointConfig(entryPointName, entryPoint); err != nil { + } else if handler, err := s.buildEntryPointRedirect(entryPointName, entryPoint); err != nil { log.Errorf("Error loading entrypoint configuration for frontend %s: %v", frontendName, err) log.Errorf("Skipping frontend %s...", frontendName) continue frontend @@ -1121,25 +1125,19 @@ func (s *Server) loadConfig(configurations types.Configurations, globalConfigura ipWhitelistMiddleware, err := configureIPWhitelistMiddleware(frontend.WhitelistSourceRange) if err != nil { - log.Fatalf("Error creating IP Whitelister: %s", err) + log.Errorf("Error creating IP Whitelister: %s", err) } else if ipWhitelistMiddleware != nil { n.Use(ipWhitelistMiddleware) log.Infof("Configured IP Whitelists: %s", frontend.WhitelistSourceRange) } if len(frontend.Redirect) > 0 { - proto := "http" - if s.globalConfiguration.EntryPoints[frontend.Redirect].TLS != nil { - proto = "https" - } - - regex, replacement, err := s.buildRedirect(proto, entryPoint) - rewrite, err := middlewares.NewRewrite(regex, replacement, true) + rewrite, err := s.buildRedirectRewrite(entryPointName, frontend.Redirect) if err != nil { - log.Fatalf("Error creating Frontend Redirect: %v", err) + log.Errorf("Error creating Frontend Redirect: %v", err) } n.Use(rewrite) - log.Debugf("Creating frontend %s redirect to %s", frontendName, proto) + log.Debugf("Frontend %s redirect created", frontendName) } if len(frontend.BasicAuth) > 0 { @@ -1289,38 +1287,57 @@ func (s *Server) wireFrontendBackend(serverRoute *serverRoute, handler http.Hand serverRoute.route.Handler(handler) } -func (s *Server) loadEntryPointConfig(entryPointName string, entryPoint *configuration.EntryPoint) (negroni.Handler, error) { +func (s *Server) buildEntryPointRedirect(srcEntryPointName string, entryPoint *configuration.EntryPoint) (*middlewares.Rewrite, error) { + if len(entryPoint.Redirect.EntryPoint) > 0 { + return s.buildRedirectRewrite(srcEntryPointName, entryPoint.Redirect.EntryPoint) + } + regex := entryPoint.Redirect.Regex replacement := entryPoint.Redirect.Replacement - var err error - if len(entryPoint.Redirect.EntryPoint) > 0 { - var protocol = "http" - if s.globalConfiguration.EntryPoints[entryPoint.Redirect.EntryPoint].TLS != nil { - protocol = "https" - } - regex, replacement, err = s.buildRedirect(protocol, entryPoint) - } rewrite, err := middlewares.NewRewrite(regex, replacement, true) if err != nil { return nil, err } - log.Debugf("Creating entryPoint redirect %s -> %s : %s -> %s", entryPointName, entryPoint.Redirect.EntryPoint, regex, replacement) + log.Debugf("Creating entryPoint redirect %s -> %s : %s -> %s", srcEntryPointName, entryPoint.Redirect.EntryPoint, regex, replacement) return rewrite, nil } -func (s *Server) buildRedirect(protocol string, entryPoint *configuration.EntryPoint) (string, string, error) { - regex := `^(?:https?:\/\/)?([\w\._-]+)(?::\d+)?(.*)$` - if s.globalConfiguration.EntryPoints[entryPoint.Redirect.EntryPoint] == nil { - return "", "", fmt.Errorf("unknown target entrypoint %q", entryPoint.Redirect.EntryPoint) +func (s *Server) buildRedirectRewrite(srcEntryPointName string, redirectEntryPoint string) (*middlewares.Rewrite, error) { + regex, replacement, err := s.buildRedirect(redirectEntryPoint) + if err != nil { + return nil, err } - r, _ := regexp.Compile(`(:\d+)`) - match := r.FindStringSubmatch(s.globalConfiguration.EntryPoints[entryPoint.Redirect.EntryPoint].Address) + + rewrite, err := middlewares.NewRewrite(regex, replacement, true) + if err != nil { + // Impossible case because error is always nil + return nil, err + } + log.Debugf("Creating entryPoint redirect %s -> %s : %s -> %s", srcEntryPointName, redirectEntryPoint, regex, replacement) + + return rewrite, nil +} + +func (s *Server) buildRedirect(entryPointName string) (string, string, error) { + entryPoint := s.globalConfiguration.EntryPoints[entryPointName] + if entryPoint == nil { + return "", "", fmt.Errorf("unknown target entrypoint %q", entryPointName) + } + + exp := regexp.MustCompile(`(:\d+)`) + match := exp.FindStringSubmatch(entryPoint.Address) if len(match) == 0 { - return "", "", fmt.Errorf("bad Address format %q", s.globalConfiguration.EntryPoints[entryPoint.Redirect.EntryPoint].Address) + return "", "", fmt.Errorf("bad Address format %q", entryPoint.Address) } + + var protocol = "http" + if s.globalConfiguration.EntryPoints[entryPointName].TLS != nil { + protocol = "https" + } + replacement := protocol + "://$1" + match[0] + "$2" - return regex, replacement, nil + return defaultRedirectRegex, replacement, nil } func (s *Server) buildDefaultHTTPRouter() *mux.Router { diff --git a/server/server_test.go b/server/server_test.go index c93587370..ac8737b90 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -903,55 +903,196 @@ func TestServerResponseEmptyBackend(t *testing.T) { } } -func TestServerLoadConfigBuildRedirect(t *testing.T) { +func TestBuildEntryPointRedirect(t *testing.T) { + srv := Server{ + globalConfiguration: configuration.GlobalConfiguration{ + EntryPoints: configuration.EntryPoints{ + "http": &configuration.EntryPoint{Address: ":80"}, + "https": &configuration.EntryPoint{Address: ":443", TLS: &tls.TLS{}}, + }, + }, + } + testCases := []struct { - desc string - replacementProtocol string - globalConfiguration configuration.GlobalConfiguration - originEntryPointName string - expectedReplacement string + desc string + srcEntryPointName string + url string + entryPoint *configuration.EntryPoint + expectedURL string }{ { - desc: "Redirect endpoint http to https with HTTPS protocol", - replacementProtocol: "https", - originEntryPointName: "http", - globalConfiguration: configuration.GlobalConfiguration{ - EntryPoints: configuration.EntryPoints{ - "http": &configuration.EntryPoint{ - Address: ":80", - Redirect: &configuration.Redirect{ - EntryPoint: "https", - }, - }, - "https": &configuration.EntryPoint{ - Address: ":443", - TLS: &tls.TLS{}, - }, + desc: "redirect regex", + srcEntryPointName: "http", + url: "http://foo.com", + entryPoint: &configuration.EntryPoint{ + Address: ":80", + Redirect: &configuration.Redirect{ + Regex: `^(?:http?:\/\/)(foo)(\.com)$`, + Replacement: "https://$1{{\"bar\"}}$2", }, }, + expectedURL: "https://foobar.com", + }, + { + desc: "redirect entry point", + srcEntryPointName: "http", + url: "http://foo:80", + entryPoint: &configuration.EntryPoint{ + Address: ":80", + Redirect: &configuration.Redirect{ + EntryPoint: "https", + }, + }, + expectedURL: "https://foo:443", + }, + { + desc: "redirect entry point with regex (ignored)", + srcEntryPointName: "http", + url: "http://foo.com:80", + entryPoint: &configuration.EntryPoint{ + Address: ":80", + Redirect: &configuration.Redirect{ + EntryPoint: "https", + Regex: `^(?:http?:\/\/)(foo)(\.com)$`, + Replacement: "https://$1{{\"bar\"}}$2", + }, + }, + expectedURL: "https://foo.com:443", + }, + } + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + rewrite, err := srv.buildEntryPointRedirect(test.srcEntryPointName, test.entryPoint) + require.NoError(t, err) + + req := testhelpers.MustNewRequest(http.MethodGet, test.url, nil) + recorder := httptest.NewRecorder() + + rewrite.ServeHTTP(recorder, req, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Location", "fail") + })) + + location, err := recorder.Result().Location() + require.NoError(t, err) + assert.Equal(t, test.expectedURL, location.String()) + }) + } +} + +func TestServerBuildRedirectRewrite(t *testing.T) { + srv := Server{ + globalConfiguration: configuration.GlobalConfiguration{ + EntryPoints: configuration.EntryPoints{ + "http": &configuration.EntryPoint{Address: ":80"}, + "https": &configuration.EntryPoint{Address: ":443", TLS: &tls.TLS{}}, + }, + }, + } + + testCases := []struct { + desc string + srcEntryPointName string + redirectEntryPoint string + url string + expectedURL string + errorExpected bool + }{ + { + desc: "existing redirect entry point", + srcEntryPointName: "http", + redirectEntryPoint: "https", + url: "http://foo:80", + expectedURL: "https://foo:443", + }, + { + desc: "non-existing redirect entry point", + srcEntryPointName: "http", + redirectEntryPoint: "foo", + url: "http://foo:80", + errorExpected: true, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + rewrite, err := srv.buildRedirectRewrite(test.srcEntryPointName, test.redirectEntryPoint) + if test.errorExpected { + require.Error(t, err) + } else { + require.NoError(t, err) + + recorder := httptest.NewRecorder() + r := testhelpers.MustNewRequest(http.MethodGet, test.url, nil) + rewrite.ServeHTTP(recorder, r, nil) + + location, err := recorder.Result().Location() + require.NoError(t, err) + + assert.Equal(t, test.expectedURL, location.String()) + } + }) + } +} + +func TestServerBuildRedirect(t *testing.T) { + testCases := []struct { + desc string + globalConfiguration configuration.GlobalConfiguration + redirectEntryPointName string + expectedReplacement string + errorExpected bool + }{ + { + desc: "Redirect endpoint http to https with HTTPS protocol", + redirectEntryPointName: "https", + globalConfiguration: configuration.GlobalConfiguration{ + EntryPoints: configuration.EntryPoints{ + "http": &configuration.EntryPoint{Address: ":80"}, + "https": &configuration.EntryPoint{Address: ":443", TLS: &tls.TLS{}}, + }, + }, expectedReplacement: "https://$1:443$2", }, { - desc: "Redirect endpoint http to http02 with HTTP protocol", - replacementProtocol: "http", - originEntryPointName: "http", + desc: "Redirect endpoint http to http02 with HTTP protocol", + redirectEntryPointName: "http02", globalConfiguration: configuration.GlobalConfiguration{ EntryPoints: configuration.EntryPoints{ - "http": &configuration.EntryPoint{ - Address: ":80", - Redirect: &configuration.Redirect{ - EntryPoint: "http02", - }, - }, - "http02": &configuration.EntryPoint{ - Address: ":88", - }, + "http": &configuration.EntryPoint{Address: ":80"}, + "http02": &configuration.EntryPoint{Address: ":88"}, }, }, - expectedReplacement: "http://$1:88$2", }, + { + desc: "Redirect endpoint to non-existent entry point", + redirectEntryPointName: "foobar", + globalConfiguration: configuration.GlobalConfiguration{ + EntryPoints: configuration.EntryPoints{ + "http": &configuration.EntryPoint{Address: ":80"}, + "http02": &configuration.EntryPoint{Address: ":88"}, + }, + }, + errorExpected: true, + }, + { + desc: "Redirect endpoint to an entry point with a malformed address", + redirectEntryPointName: "http02", + globalConfiguration: configuration.GlobalConfiguration{ + EntryPoints: configuration.EntryPoints{ + "http": &configuration.EntryPoint{Address: ":80"}, + "http02": &configuration.EntryPoint{Address: "88"}, + }, + }, + errorExpected: true, + }, } for _, test := range testCases { @@ -961,9 +1102,9 @@ func TestServerLoadConfigBuildRedirect(t *testing.T) { srv := Server{globalConfiguration: test.globalConfiguration} - _, replacement, err := srv.buildRedirect(test.replacementProtocol, srv.globalConfiguration.EntryPoints[test.originEntryPointName]) + _, replacement, err := srv.buildRedirect(test.redirectEntryPointName) - require.NoError(t, err, "build redirect sent an unexpected error") + require.Equal(t, test.errorExpected, err != nil, "Expected an error but don't have error, or Expected no error but have an error: %v", err) assert.Equal(t, test.expectedReplacement, replacement, "build redirect does not return the right replacement pattern") }) }