diff --git a/cmd/traefik/traefik.go b/cmd/traefik/traefik.go index 9fe1854fc..eb7637dab 100644 --- a/cmd/traefik/traefik.go +++ b/cmd/traefik/traefik.go @@ -21,6 +21,7 @@ import ( cmdVersion "github.com/containous/traefik/cmd/version" "github.com/containous/traefik/collector" "github.com/containous/traefik/configuration" + "github.com/containous/traefik/configuration/router" "github.com/containous/traefik/job" "github.com/containous/traefik/log" "github.com/containous/traefik/provider/acme" @@ -177,12 +178,31 @@ func runCmd(globalConfiguration *configuration.GlobalConfiguration, configFile s store := acme.NewLocalStore(acme.Get().Storage) acme.Get().Store = &store } - svr := server.NewServer(*globalConfiguration, configuration.NewProviderAggregator(globalConfiguration)) + + entryPoints := map[string]server.EntryPoint{} + for entryPointName, config := range globalConfiguration.EntryPoints { + internalRouter := router.NewInternalRouterAggregator(*globalConfiguration, entryPointName) + if acme.IsEnabled() && acme.Get().HTTPChallenge != nil && acme.Get().HTTPChallenge.EntryPoint == entryPointName { + internalRouter.AddRouter(acme.Get()) + } + + entryPoints[entryPointName] = server.EntryPoint{ + InternalRouter: internalRouter, + Configuration: config, + } + } + + svr := server.NewServer(*globalConfiguration, configuration.NewProviderAggregator(globalConfiguration), entryPoints) if acme.IsEnabled() && acme.Get().OnHostRule { acme.Get().SetConfigListenerChan(make(chan types.Configuration)) svr.AddListener(acme.Get().ListenConfiguration) } ctx := cmd.ContextWithSignal(context.Background()) + + if globalConfiguration.Ping != nil { + globalConfiguration.Ping.WithContext(ctx) + } + svr.StartWithContext(ctx) defer svr.Close() diff --git a/configuration/configuration_test.go b/configuration/configuration_test.go index e8c5e7b47..b9e7206d5 100644 --- a/configuration/configuration_test.go +++ b/configuration/configuration_test.go @@ -58,7 +58,6 @@ func TestSetEffectiveConfigurationGraceTimeout(t *testing.T) { gc.SetEffectiveConfiguration(defaultConfigFile) assert.Equal(t, test.wantGraceTimeout, time.Duration(gc.LifeCycle.GraceTimeOut)) - }) } } diff --git a/configuration/router/internal_router.go b/configuration/router/internal_router.go new file mode 100644 index 000000000..0a4ca50df --- /dev/null +++ b/configuration/router/internal_router.go @@ -0,0 +1,132 @@ +package router + +import ( + "github.com/containous/mux" + "github.com/containous/traefik/configuration" + "github.com/containous/traefik/log" + "github.com/containous/traefik/metrics" + "github.com/containous/traefik/middlewares" + mauth "github.com/containous/traefik/middlewares/auth" + "github.com/containous/traefik/types" + "github.com/urfave/negroni" +) + +// NewInternalRouterAggregator Create a new internalRouterAggregator +func NewInternalRouterAggregator(globalConfiguration configuration.GlobalConfiguration, entryPointName string) *InternalRouterAggregator { + var serverMiddlewares []negroni.Handler + + if globalConfiguration.EntryPoints[entryPointName].WhiteList != nil { + ipWhitelistMiddleware, err := middlewares.NewIPWhiteLister( + globalConfiguration.EntryPoints[entryPointName].WhiteList.SourceRange, + globalConfiguration.EntryPoints[entryPointName].WhiteList.UseXForwardedFor) + if err != nil { + log.Fatalf("Error creating whitelist middleware: %s", err) + } + if ipWhitelistMiddleware != nil { + serverMiddlewares = append(serverMiddlewares, ipWhitelistMiddleware) + } + } + + if globalConfiguration.EntryPoints[entryPointName].Auth != nil { + authMiddleware, err := mauth.NewAuthenticator(globalConfiguration.EntryPoints[entryPointName].Auth, nil) + if err != nil { + log.Fatalf("Error creating authenticator middleware: %s", err) + } + serverMiddlewares = append(serverMiddlewares, authMiddleware) + } + + router := InternalRouterAggregator{} + routerWithPrefix := InternalRouterAggregator{} + routerWithPrefixAndMiddleware := InternalRouterAggregator{} + + if globalConfiguration.Metrics != nil && globalConfiguration.Metrics.Prometheus != nil && globalConfiguration.Metrics.Prometheus.EntryPoint == entryPointName { + routerWithPrefixAndMiddleware.AddRouter(metrics.PrometheusHandler{}) + } + + if globalConfiguration.Rest != nil && globalConfiguration.Rest.EntryPoint == entryPointName { + routerWithPrefixAndMiddleware.AddRouter(globalConfiguration.Rest) + } + + if globalConfiguration.API != nil && globalConfiguration.API.EntryPoint == entryPointName { + routerWithPrefixAndMiddleware.AddRouter(globalConfiguration.API) + } + + if globalConfiguration.Ping != nil && globalConfiguration.Ping.EntryPoint == entryPointName { + routerWithPrefix.AddRouter(globalConfiguration.Ping) + } + + if globalConfiguration.ACME != nil && globalConfiguration.ACME.HTTPChallenge != nil && globalConfiguration.ACME.HTTPChallenge.EntryPoint == entryPointName { + router.AddRouter(globalConfiguration.ACME) + } + + realRouterWithMiddleware := WithMiddleware{router: &routerWithPrefixAndMiddleware, routerMiddlewares: serverMiddlewares} + if globalConfiguration.Web != nil && globalConfiguration.Web.Path != "" { + router.AddRouter(&WithPrefix{PathPrefix: globalConfiguration.Web.Path, Router: &routerWithPrefix}) + router.AddRouter(&WithPrefix{PathPrefix: globalConfiguration.Web.Path, Router: &realRouterWithMiddleware}) + } else { + router.AddRouter(&routerWithPrefix) + router.AddRouter(&realRouterWithMiddleware) + } + + return &router +} + +// WithMiddleware router with internal middleware +type WithMiddleware struct { + router types.InternalRouter + routerMiddlewares []negroni.Handler +} + +// AddRoutes Add routes to the router +func (wm *WithMiddleware) AddRoutes(systemRouter *mux.Router) { + realRouter := systemRouter.PathPrefix("/").Subrouter() + + wm.router.AddRoutes(realRouter) + + if len(wm.routerMiddlewares) > 0 { + realRouter.Walk(wrapRoute(wm.routerMiddlewares)) + } +} + +// WithPrefix router which add a prefix +type WithPrefix struct { + Router types.InternalRouter + PathPrefix string +} + +// AddRoutes Add routes to the router +func (wp *WithPrefix) AddRoutes(systemRouter *mux.Router) { + realRouter := systemRouter.PathPrefix("/").Subrouter() + if wp.PathPrefix != "" { + realRouter = systemRouter.PathPrefix(wp.PathPrefix).Subrouter() + realRouter.StrictSlash(true) + realRouter.SkipClean(true) + } + wp.Router.AddRoutes(realRouter) +} + +// InternalRouterAggregator InternalRouter that aggregate other internalRouter +type InternalRouterAggregator struct { + internalRouters []types.InternalRouter +} + +// AddRouter add a router in the aggregator +func (r *InternalRouterAggregator) AddRouter(router types.InternalRouter) { + r.internalRouters = append(r.internalRouters, router) +} + +// AddRoutes Add routes to the router +func (r *InternalRouterAggregator) AddRoutes(systemRouter *mux.Router) { + for _, router := range r.internalRouters { + router.AddRoutes(systemRouter) + } +} + +// wrapRoute with middlewares +func wrapRoute(middlewares []negroni.Handler) func(*mux.Route, *mux.Router, []*mux.Route) error { + return func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error { + middles := append(middlewares, negroni.Wrap(route.GetHandler())) + route.Handler(negroni.New(middles...)) + return nil + } +} diff --git a/configuration/router/internal_router_test.go b/configuration/router/internal_router_test.go new file mode 100644 index 000000000..eae7aa3b1 --- /dev/null +++ b/configuration/router/internal_router_test.go @@ -0,0 +1,346 @@ +package router + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/containous/mux" + "github.com/containous/traefik/acme" + "github.com/containous/traefik/api" + "github.com/containous/traefik/configuration" + "github.com/containous/traefik/ping" + acmeprovider "github.com/containous/traefik/provider/acme" + "github.com/containous/traefik/safe" + "github.com/containous/traefik/types" + "github.com/stretchr/testify/assert" + "github.com/urfave/negroni" +) + +func TestNewInternalRouterAggregatorWithWebPath(t *testing.T) { + currentConfiguration := &safe.Safe{} + currentConfiguration.Set(types.Configurations{}) + + globalConfiguration := configuration.GlobalConfiguration{ + Web: &configuration.WebCompatibility{ + Path: "/prefix", + }, + API: &api.Handler{ + EntryPoint: "traefik", + CurrentConfigurations: currentConfiguration, + }, + Ping: &ping.Handler{ + EntryPoint: "traefik", + }, + ACME: &acme.ACME{ + HTTPChallenge: &acmeprovider.HTTPChallenge{ + EntryPoint: "traefik", + }, + }, + EntryPoints: configuration.EntryPoints{ + "traefik": &configuration.EntryPoint{}, + }, + } + + testCases := []struct { + desc string + testedURL string + expectedStatusCode int + }{ + { + desc: "Ping without prefix", + testedURL: "/ping", + expectedStatusCode: 502, + }, + { + desc: "Ping with prefix", + testedURL: "/prefix/ping", + expectedStatusCode: 200, + }, + { + desc: "acme without prefix", + testedURL: "/.well-known/acme-challenge/token", + expectedStatusCode: 404, + }, + { + desc: "api without prefix", + testedURL: "/api", + expectedStatusCode: 502, + }, + { + desc: "api with prefix", + testedURL: "/prefix/api", + expectedStatusCode: 200, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + router := NewInternalRouterAggregator(globalConfiguration, "traefik") + + internalMuxRouter := mux.NewRouter() + router.AddRoutes(internalMuxRouter) + internalMuxRouter.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadGateway) + }) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, test.testedURL, nil) + internalMuxRouter.ServeHTTP(recorder, request) + + assert.Equal(t, test.expectedStatusCode, recorder.Code) + }) + } +} + +func TestNewInternalRouterAggregatorWithAuth(t *testing.T) { + currentConfiguration := &safe.Safe{} + currentConfiguration.Set(types.Configurations{}) + + globalConfiguration := configuration.GlobalConfiguration{ + API: &api.Handler{ + EntryPoint: "traefik", + CurrentConfigurations: currentConfiguration, + }, + Ping: &ping.Handler{ + EntryPoint: "traefik", + }, + ACME: &acme.ACME{ + HTTPChallenge: &acmeprovider.HTTPChallenge{ + EntryPoint: "traefik", + }, + }, + EntryPoints: configuration.EntryPoints{ + "traefik": &configuration.EntryPoint{ + Auth: &types.Auth{ + Basic: &types.Basic{ + Users: types.Users{"test:test"}, + }, + }, + }, + }, + } + + testCases := []struct { + desc string + testedURL string + expectedStatusCode int + }{ + { + desc: "Wrong url", + testedURL: "/wrong", + expectedStatusCode: 502, + }, + { + desc: "Ping without auth", + testedURL: "/ping", + expectedStatusCode: 200, + }, + { + desc: "acme without auth", + testedURL: "/.well-known/acme-challenge/token", + expectedStatusCode: 404, + }, + { + desc: "api with auth", + testedURL: "/api", + expectedStatusCode: 401, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + router := NewInternalRouterAggregator(globalConfiguration, "traefik") + + internalMuxRouter := mux.NewRouter() + router.AddRoutes(internalMuxRouter) + internalMuxRouter.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadGateway) + }) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, test.testedURL, nil) + internalMuxRouter.ServeHTTP(recorder, request) + + assert.Equal(t, test.expectedStatusCode, recorder.Code) + }) + } +} + +func TestNewInternalRouterAggregatorWithAuthAndPrefix(t *testing.T) { + currentConfiguration := &safe.Safe{} + currentConfiguration.Set(types.Configurations{}) + + globalConfiguration := configuration.GlobalConfiguration{ + Web: &configuration.WebCompatibility{ + Path: "/prefix", + }, + API: &api.Handler{ + EntryPoint: "traefik", + CurrentConfigurations: currentConfiguration, + }, + Ping: &ping.Handler{ + EntryPoint: "traefik", + }, + ACME: &acme.ACME{ + HTTPChallenge: &acmeprovider.HTTPChallenge{ + EntryPoint: "traefik", + }, + }, + EntryPoints: configuration.EntryPoints{ + "traefik": &configuration.EntryPoint{ + Auth: &types.Auth{ + Basic: &types.Basic{ + Users: types.Users{"test:test"}, + }, + }, + }, + }, + } + + testCases := []struct { + desc string + testedURL string + expectedStatusCode int + }{ + { + desc: "Ping without prefix", + testedURL: "/ping", + expectedStatusCode: 502, + }, + { + desc: "Ping without auth and with prefix", + testedURL: "/prefix/ping", + expectedStatusCode: 200, + }, + { + desc: "acme without auth and without prefix", + testedURL: "/.well-known/acme-challenge/token", + expectedStatusCode: 404, + }, + { + desc: "api with auth and prefix", + testedURL: "/prefix/api", + expectedStatusCode: 401, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + router := NewInternalRouterAggregator(globalConfiguration, "traefik") + + internalMuxRouter := mux.NewRouter() + router.AddRoutes(internalMuxRouter) + internalMuxRouter.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadGateway) + }) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, test.testedURL, nil) + internalMuxRouter.ServeHTTP(recorder, request) + + assert.Equal(t, test.expectedStatusCode, recorder.Code) + }) + } +} + +type MockInternalRouterFunc func(systemRouter *mux.Router) + +func (m MockInternalRouterFunc) AddRoutes(systemRouter *mux.Router) { + m(systemRouter) +} + +func TestWithMiddleware(t *testing.T) { + router := WithMiddleware{ + router: MockInternalRouterFunc(func(systemRouter *mux.Router) { + systemRouter.Handle("/test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("router")) + })) + }), + routerMiddlewares: []negroni.Handler{ + negroni.HandlerFunc(func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + rw.Write([]byte("before middleware1|")) + next.ServeHTTP(rw, r) + rw.Write([]byte("|after middleware1")) + + }), + negroni.HandlerFunc(func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + rw.Write([]byte("before middleware2|")) + next.ServeHTTP(rw, r) + rw.Write([]byte("|after middleware2")) + }), + }, + } + + internalMuxRouter := mux.NewRouter() + router.AddRoutes(internalMuxRouter) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, "/test", nil) + internalMuxRouter.ServeHTTP(recorder, request) + + obtained := string(recorder.Body.Bytes()) + + assert.Equal(t, "before middleware1|before middleware2|router|after middleware2|after middleware1", obtained) +} + +func TestWithPrefix(t *testing.T) { + testCases := []struct { + desc string + prefix string + testedURL string + expectedStatusCode int + }{ + { + desc: "No prefix", + testedURL: "/test", + expectedStatusCode: 200, + }, + { + desc: "With prefix and wrong url", + prefix: "/prefix", + testedURL: "/test", + expectedStatusCode: 404, + }, + { + desc: "With prefix", + prefix: "/prefix", + testedURL: "/prefix/test", + expectedStatusCode: 200, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + router := WithPrefix{ + Router: MockInternalRouterFunc(func(systemRouter *mux.Router) { + systemRouter.Handle("/test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + }), + + PathPrefix: test.prefix, + } + internalMuxRouter := mux.NewRouter() + router.AddRoutes(internalMuxRouter) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, test.testedURL, nil) + internalMuxRouter.ServeHTTP(recorder, request) + + assert.Equal(t, test.expectedStatusCode, recorder.Code) + }) + } +} diff --git a/ping/ping.go b/ping/ping.go index 7f5c51e8b..1e7ffa860 100644 --- a/ping/ping.go +++ b/ping/ping.go @@ -1,9 +1,9 @@ package ping import ( + "context" "fmt" "net/http" - "sync" "github.com/containous/mux" ) @@ -12,26 +12,22 @@ import ( type Handler struct { EntryPoint string `description:"Ping entryPoint" export:"true"` terminating bool - lock sync.RWMutex } -// SetTerminating causes the ping endpoint to serve non 200 responses. -func (g *Handler) SetTerminating() { - g.lock.Lock() - defer g.lock.Unlock() - - g.terminating = true +// WithContext causes the ping endpoint to serve non 200 responses. +func (h *Handler) WithContext(ctx context.Context) { + go func() { + <-ctx.Done() + h.terminating = true + }() } // AddRoutes add ping routes on a router -func (g *Handler) AddRoutes(router *mux.Router) { +func (h *Handler) AddRoutes(router *mux.Router) { router.Methods(http.MethodGet, http.MethodHead).Path("/ping"). HandlerFunc(func(response http.ResponseWriter, request *http.Request) { - g.lock.RLock() - defer g.lock.RUnlock() - statusCode := http.StatusOK - if g.terminating { + if h.terminating { statusCode = http.StatusServiceUnavailable } response.WriteHeader(statusCode) diff --git a/server/server.go b/server/server.go index 145f9d8c6..ae16ccd28 100644 --- a/server/server.go +++ b/server/server.go @@ -24,6 +24,7 @@ import ( "github.com/containous/mux" "github.com/containous/traefik/cluster" "github.com/containous/traefik/configuration" + "github.com/containous/traefik/configuration/router" "github.com/containous/traefik/healthcheck" "github.com/containous/traefik/log" "github.com/containous/traefik/metrics" @@ -75,6 +76,13 @@ type Server struct { metricsRegistry metrics.Registry provider provider.Provider configurationListeners []func(types.Configuration) + entryPoints map[string]EntryPoint +} + +// EntryPoint entryPoint information (configuration + internalRouter) +type EntryPoint struct { + InternalRouter types.InternalRouter + Configuration *configuration.EntryPoint } type serverEntryPoints map[string]*serverEntryPoint @@ -88,9 +96,10 @@ type serverEntryPoint struct { } // NewServer returns an initialized Server. -func NewServer(globalConfiguration configuration.GlobalConfiguration, provider provider.Provider) *Server { +func NewServer(globalConfiguration configuration.GlobalConfiguration, provider provider.Provider, entrypoints map[string]EntryPoint) *Server { server := new(Server) + server.entryPoints = entrypoints server.provider = provider server.serverEntryPoints = make(map[string]*serverEntryPoint) server.configurationChan = make(chan types.ConfigMessage, 100) @@ -210,9 +219,6 @@ func (s *Server) StartWithContext(ctx context.Context) { <-ctx.Done() log.Info("I have to go...") reqAcceptGraceTimeOut := time.Duration(s.globalConfiguration.LifeCycle.RequestAcceptGraceTimeout) - if s.globalConfiguration.Ping != nil && reqAcceptGraceTimeOut > 0 { - s.globalConfiguration.Ping.SetTerminating() - } if reqAcceptGraceTimeOut > 0 { log.Infof("Waiting %s for incoming requests to cease", reqAcceptGraceTimeOut) time.Sleep(reqAcceptGraceTimeOut) @@ -291,17 +297,16 @@ func (s *Server) stopLeadership() { } func (s *Server) startHTTPServers() { - s.serverEntryPoints = s.buildEntryPoints(s.globalConfiguration) + s.serverEntryPoints = s.buildEntryPoints() for newServerEntryPointName, newServerEntryPoint := range s.serverEntryPoints { serverEntryPoint := s.setupServerEntryPoint(newServerEntryPointName, newServerEntryPoint) - go s.startServer(serverEntryPoint, s.globalConfiguration) + go s.startServer(serverEntryPoint) } } func (s *Server) setupServerEntryPoint(newServerEntryPointName string, newServerEntryPoint *serverEntryPoint) *serverEntryPoint { serverMiddlewares := []negroni.Handler{middlewares.NegroniRecoverHandler()} - serverInternalMiddlewares := []negroni.Handler{middlewares.NegroniRecoverHandler()} if s.tracingMiddleware.IsEnabled() { serverMiddlewares = append(serverMiddlewares, s.tracingMiddleware.NewEntryPoint(newServerEntryPointName)) @@ -310,9 +315,11 @@ func (s *Server) setupServerEntryPoint(newServerEntryPointName string, newServer if s.accessLoggerMiddleware != nil { serverMiddlewares = append(serverMiddlewares, s.accessLoggerMiddleware) } + if s.metricsRegistry.IsEnabled() { serverMiddlewares = append(serverMiddlewares, middlewares.NewEntryPointMetricsMiddleware(s.metricsRegistry, newServerEntryPointName)) } + if s.globalConfiguration.API != nil { if s.globalConfiguration.API.Stats == nil { s.globalConfiguration.API.Stats = thoas_stats.New() @@ -326,34 +333,33 @@ func (s *Server) setupServerEntryPoint(newServerEntryPointName string, newServer } } - if s.globalConfiguration.EntryPoints[newServerEntryPointName].Auth != nil { - authMiddleware, err := mauth.NewAuthenticator(s.globalConfiguration.EntryPoints[newServerEntryPointName].Auth, s.tracingMiddleware) + if s.entryPoints[newServerEntryPointName].Configuration.Auth != nil { + authMiddleware, err := mauth.NewAuthenticator(s.entryPoints[newServerEntryPointName].Configuration.Auth, s.tracingMiddleware) if err != nil { log.Fatal("Error starting server: ", err) } serverMiddlewares = append(serverMiddlewares, s.wrapNegroniHandlerWithAccessLog(authMiddleware, fmt.Sprintf("Auth for entrypoint %s", newServerEntryPointName))) - serverInternalMiddlewares = append(serverInternalMiddlewares, authMiddleware) } - if s.globalConfiguration.EntryPoints[newServerEntryPointName].Compress { + if s.entryPoints[newServerEntryPointName].Configuration.Compress { serverMiddlewares = append(serverMiddlewares, &middlewares.Compress{}) } ipWhitelistMiddleware, err := buildIPWhiteLister( - s.globalConfiguration.EntryPoints[newServerEntryPointName].WhiteList, - s.globalConfiguration.EntryPoints[newServerEntryPointName].WhitelistSourceRange) + s.entryPoints[newServerEntryPointName].Configuration.WhiteList, + s.entryPoints[newServerEntryPointName].Configuration.WhitelistSourceRange) if err != nil { log.Fatal("Error starting server: ", err) } if ipWhitelistMiddleware != nil { serverMiddlewares = append(serverMiddlewares, s.wrapNegroniHandlerWithAccessLog(ipWhitelistMiddleware, fmt.Sprintf("ipwhitelister for entrypoint %s", newServerEntryPointName))) - serverInternalMiddlewares = append(serverInternalMiddlewares, ipWhitelistMiddleware) } - newSrv, listener, err := s.prepareServer(newServerEntryPointName, s.globalConfiguration.EntryPoints[newServerEntryPointName], newServerEntryPoint.httpRouter, serverMiddlewares, serverInternalMiddlewares) + newSrv, listener, err := s.prepareServer(newServerEntryPointName, s.entryPoints[newServerEntryPointName].Configuration, newServerEntryPoint.httpRouter, serverMiddlewares) if err != nil { log.Fatal("Error preparing server: ", err) } + serverEntryPoint := s.serverEntryPoints[newServerEntryPointName] serverEntryPoint.httpServer = newSrv serverEntryPoint.listener = listener @@ -467,7 +473,7 @@ func (s *Server) loadConfiguration(configMsg types.ConfigMessage) { s.metricsRegistry.LastConfigReloadSuccessGauge().Set(float64(time.Now().Unix())) for newServerEntryPointName, newServerEntryPoint := range newServerEntryPoints { s.serverEntryPoints[newServerEntryPointName].httpRouter.UpdateHandler(newServerEntryPoint.httpRouter.GetHandler()) - if s.globalConfiguration.EntryPoints[newServerEntryPointName].TLS == nil { + if s.entryPoints[newServerEntryPointName].Configuration.TLS == nil { if newServerEntryPoint.certs.Get() != nil { log.Debugf("Certificates not added to non-TLS entryPoint %s.", newServerEntryPointName) } @@ -550,7 +556,7 @@ func (s *Server) postLoadConfiguration() { // and is configured with ACME acmeEnabled := false for _, entryPoint := range frontend.EntryPoints { - if s.globalConfiguration.ACME.EntryPoint == entryPoint && s.globalConfiguration.EntryPoints[entryPoint].TLS != nil { + if s.globalConfiguration.ACME.EntryPoint == entryPoint && s.entryPoints[entryPoint].Configuration.TLS != nil { acmeEnabled = true break } @@ -665,8 +671,7 @@ func (s *Server) createTLSConfig(entryPointName string, tlsOption *traefiktls.TL if entryPointName == s.globalConfiguration.ACME.EntryPoint { checkOnDemandDomain := func(domain string) bool { routeMatch := &mux.RouteMatch{} - router := router.GetHandler() - match := router.Match(&http.Request{URL: &url.URL{}, Host: domain}, routeMatch) + match := router.GetHandler().Match(&http.Request{URL: &url.URL{}, Host: domain}, routeMatch) if match && routeMatch.Route != nil { return true } @@ -699,15 +704,16 @@ func (s *Server) createTLSConfig(entryPointName string, tlsOption *traefiktls.TL } // Set the minimum TLS version if set in the config TOML - if minConst, exists := traefiktls.MinVersion[s.globalConfiguration.EntryPoints[entryPointName].TLS.MinVersion]; exists { + if minConst, exists := traefiktls.MinVersion[s.entryPoints[entryPointName].Configuration.TLS.MinVersion]; exists { config.PreferServerCipherSuites = true config.MinVersion = minConst } + // Set the list of CipherSuites if set in the config TOML - if s.globalConfiguration.EntryPoints[entryPointName].TLS.CipherSuites != nil { + if s.entryPoints[entryPointName].Configuration.TLS.CipherSuites != nil { // if our list of CipherSuites is defined in the entrypoint config, we can re-initilize the suites list as empty config.CipherSuites = make([]uint16, 0) - for _, cipher := range s.globalConfiguration.EntryPoints[entryPointName].TLS.CipherSuites { + for _, cipher := range s.entryPoints[entryPointName].Configuration.TLS.CipherSuites { if cipherConst, exists := traefiktls.CipherSuites[cipher]; exists { config.CipherSuites = append(config.CipherSuites, cipherConst) } else { @@ -719,7 +725,7 @@ func (s *Server) createTLSConfig(entryPointName string, tlsOption *traefiktls.TL return config, nil } -func (s *Server) startServer(serverEntryPoint *serverEntryPoint, globalConfiguration configuration.GlobalConfiguration) { +func (s *Server) startServer(serverEntryPoint *serverEntryPoint) { log.Infof("Starting server on %s", serverEntryPoint.httpServer.Addr) var err error if serverEntryPoint.httpServer.TLSConfig != nil { @@ -732,39 +738,7 @@ func (s *Server) startServer(serverEntryPoint *serverEntryPoint, globalConfigura } } -func (s *Server) addInternalRoutes(entryPointName string, router *mux.Router) { - if s.globalConfiguration.Metrics != nil && s.globalConfiguration.Metrics.Prometheus != nil && s.globalConfiguration.Metrics.Prometheus.EntryPoint == entryPointName { - metrics.PrometheusHandler{}.AddRoutes(router) - } - - if s.globalConfiguration.Rest != nil && s.globalConfiguration.Rest.EntryPoint == entryPointName { - s.globalConfiguration.Rest.AddRoutes(router) - } - - if s.globalConfiguration.API != nil && s.globalConfiguration.API.EntryPoint == entryPointName { - s.globalConfiguration.API.AddRoutes(router) - } -} - -func (s *Server) addInternalPublicRoutes(entryPointName string, router *mux.Router) { - if s.globalConfiguration.Ping != nil && s.globalConfiguration.Ping.EntryPoint != "" && s.globalConfiguration.Ping.EntryPoint == entryPointName { - s.globalConfiguration.Ping.AddRoutes(router) - } - - if s.globalConfiguration.API != nil && s.globalConfiguration.API.EntryPoint == entryPointName && s.leadership != nil { - s.leadership.AddRoutes(router) - } -} - -func (s *Server) addACMERoutes(entryPointName string, router *mux.Router) { - if s.globalConfiguration.ACME != nil && s.globalConfiguration.ACME.HTTPChallenge != nil && s.globalConfiguration.ACME.HTTPChallenge.EntryPoint == entryPointName { - s.globalConfiguration.ACME.AddRoutes(router) - } else if acme.IsEnabled() && acme.Get().HTTPChallenge != nil && acme.Get().HTTPChallenge.EntryPoint == entryPointName { - acme.Get().AddRoutes(router) - } -} - -func (s *Server) prepareServer(entryPointName string, entryPoint *configuration.EntryPoint, router *middlewares.HandlerSwitcher, middlewares []negroni.Handler, internalMiddlewares []negroni.Handler) (*http.Server, net.Listener, error) { +func (s *Server) prepareServer(entryPointName string, entryPoint *configuration.EntryPoint, router *middlewares.HandlerSwitcher, middlewares []negroni.Handler) (*http.Server, net.Listener, error) { readTimeout, writeTimeout, idleTimeout := buildServerTimeouts(s.globalConfiguration) log.Infof("Preparing server %s %+v with readTimeout=%s writeTimeout=%s idleTimeout=%s", entryPointName, entryPoint, readTimeout, writeTimeout, idleTimeout) @@ -775,12 +749,7 @@ func (s *Server) prepareServer(entryPointName string, entryPoint *configuration. } n.UseHandler(router) - path := "/" - if s.globalConfiguration.Web != nil && s.globalConfiguration.Web.Path != "" { - path = s.globalConfiguration.Web.Path - } - - internalMuxRouter := s.buildInternalRouter(entryPointName, path, internalMiddlewares) + internalMuxRouter := s.buildInternalRouter(entryPointName) internalMuxRouter.NotFoundHandler = n tlsConfig, err := s.createTLSConfig(entryPointName, entryPoint.TLS, router) @@ -826,34 +795,27 @@ func (s *Server) prepareServer(entryPointName string, entryPoint *configuration. nil } -func (s *Server) buildInternalRouter(entryPointName, path string, internalMiddlewares []negroni.Handler) *mux.Router { +func (s *Server) buildInternalRouter(entryPointName string) *mux.Router { internalMuxRouter := mux.NewRouter() internalMuxRouter.StrictSlash(true) internalMuxRouter.SkipClean(true) - internalMuxSubrouter := internalMuxRouter.PathPrefix(path).Subrouter() - internalMuxSubrouter.StrictSlash(true) - internalMuxSubrouter.SkipClean(true) + if entryPoint, ok := s.entryPoints[entryPointName]; ok && entryPoint.InternalRouter != nil { + entryPoint.InternalRouter.AddRoutes(internalMuxRouter) - s.addInternalRoutes(entryPointName, internalMuxSubrouter) - internalMuxRouter.Walk(wrapRoute(internalMiddlewares)) - - s.addInternalPublicRoutes(entryPointName, internalMuxSubrouter) - - s.addACMERoutes(entryPointName, internalMuxRouter) + if s.globalConfiguration.API != nil && s.globalConfiguration.API.EntryPoint == entryPointName && s.leadership != nil { + if s.globalConfiguration.Web != nil && s.globalConfiguration.Web.Path != "" { + rt := router.WithPrefix{Router: s.leadership, PathPrefix: s.globalConfiguration.Web.Path} + rt.AddRoutes(internalMuxRouter) + } else { + s.leadership.AddRoutes(internalMuxRouter) + } + } + } return internalMuxRouter } -// wrapRoute with middlewares -func wrapRoute(middlewares []negroni.Handler) func(*mux.Route, *mux.Router, []*mux.Route) error { - return func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error { - middles := append(middlewares, negroni.Wrap(route.GetHandler())) - route.Handler(negroni.New(middles...)) - return nil - } -} - func buildServerTimeouts(globalConfig configuration.GlobalConfiguration) (readTimeout, writeTimeout, idleTimeout time.Duration) { readTimeout = time.Duration(0) writeTimeout = time.Duration(0) @@ -875,12 +837,11 @@ func buildServerTimeouts(globalConfig configuration.GlobalConfiguration) (readTi return readTimeout, writeTimeout, idleTimeout } -func (s *Server) buildEntryPoints(globalConfiguration configuration.GlobalConfiguration) map[string]*serverEntryPoint { +func (s *Server) buildEntryPoints() map[string]*serverEntryPoint { serverEntryPoints := make(map[string]*serverEntryPoint) - for entryPointName := range globalConfiguration.EntryPoints { - router := s.buildDefaultHTTPRouter() + for entryPointName := range s.entryPoints { serverEntryPoints[entryPointName] = &serverEntryPoint{ - httpRouter: middlewares.NewHandlerSwitcher(router), + httpRouter: middlewares.NewHandlerSwitcher(s.buildDefaultHTTPRouter()), } } return serverEntryPoints @@ -907,7 +868,7 @@ func (s *Server) getRoundTripper(entryPointName string, globalConfiguration conf // loadConfig returns a new gorilla.mux Route from the specified global configuration and the dynamic // provider configurations. func (s *Server) loadConfig(configurations types.Configurations, globalConfiguration configuration.GlobalConfiguration) (map[string]*serverEntryPoint, error) { - serverEntryPoints := s.buildEntryPoints(globalConfiguration) + serverEntryPoints := s.buildEntryPoints() redirectHandlers := make(map[string]negroni.Handler) backends := map[string]http.Handler{} backendsHealthCheck := map[string]*healthcheck.BackendHealthCheck{} @@ -952,7 +913,7 @@ func (s *Server) loadConfig(configurations types.Configurations, globalConfigura log.Debugf("Creating route %s %s", routeName, route.Rule) } - entryPoint := globalConfiguration.EntryPoints[entryPointName] + entryPoint := s.entryPoints[entryPointName].Configuration n := negroni.New() if entryPoint.Redirect != nil && entryPointName != entryPoint.Redirect.EntryPoint { if redirectHandlers[entryPointName] != nil { @@ -1341,7 +1302,7 @@ func (s *Server) wireFrontendBackend(serverRoute *types.ServerRoute, handler htt func (s *Server) buildRedirectHandler(srcEntryPointName string, opt *types.Redirect) (negroni.Handler, error) { // entry point redirect if len(opt.EntryPoint) > 0 { - entryPoint := s.globalConfiguration.EntryPoints[opt.EntryPoint] + entryPoint := s.entryPoints[opt.EntryPoint].Configuration if entryPoint == nil { return nil, fmt.Errorf("unknown target entrypoint %q", srcEntryPointName) } @@ -1360,11 +1321,11 @@ func (s *Server) buildRedirectHandler(srcEntryPointName string, opt *types.Redir } func (s *Server) buildDefaultHTTPRouter() *mux.Router { - router := mux.NewRouter() - router.NotFoundHandler = s.wrapHTTPHandlerWithAccessLog(http.HandlerFunc(notFoundHandler), "backend not found") - router.StrictSlash(true) - router.SkipClean(true) - return router + rt := mux.NewRouter() + rt.NotFoundHandler = s.wrapHTTPHandlerWithAccessLog(http.HandlerFunc(notFoundHandler), "backend not found") + rt.StrictSlash(true) + rt.SkipClean(true) + return rt } func parseHealthCheckOptions(lb healthcheck.LoadBalancer, backend string, hc *types.HealthCheck, hcConfig *configuration.HealthCheckConfig) *healthcheck.Options { diff --git a/server/server_test.go b/server/server_test.go index 71c8cbca3..ddd8070ce 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -21,7 +21,6 @@ import ( "github.com/containous/traefik/testhelpers" "github.com/containous/traefik/tls" "github.com/containous/traefik/types" - "github.com/davecgh/go-spew/spew" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/unrolled/secure" @@ -82,12 +81,12 @@ func (lb *testLoadBalancer) Servers() []*url.URL { } func TestPrepareServerTimeouts(t *testing.T) { - tests := []struct { - desc string - globalConfig configuration.GlobalConfiguration - wantIdleTimeout time.Duration - wantReadTimeout time.Duration - wantWriteTimeout time.Duration + testCases := []struct { + desc string + globalConfig configuration.GlobalConfiguration + expectedIdleTimeout time.Duration + expectedReadTimeout time.Duration + expectedWriteTimeout time.Duration }{ { desc: "full configuration", @@ -98,25 +97,25 @@ func TestPrepareServerTimeouts(t *testing.T) { WriteTimeout: flaeg.Duration(14 * time.Second), }, }, - wantIdleTimeout: time.Duration(10 * time.Second), - wantReadTimeout: time.Duration(12 * time.Second), - wantWriteTimeout: time.Duration(14 * time.Second), + expectedIdleTimeout: time.Duration(10 * time.Second), + expectedReadTimeout: time.Duration(12 * time.Second), + expectedWriteTimeout: time.Duration(14 * time.Second), }, { - desc: "using defaults", - globalConfig: configuration.GlobalConfiguration{}, - wantIdleTimeout: time.Duration(180 * time.Second), - wantReadTimeout: time.Duration(0 * time.Second), - wantWriteTimeout: time.Duration(0 * time.Second), + desc: "using defaults", + globalConfig: configuration.GlobalConfiguration{}, + expectedIdleTimeout: time.Duration(180 * time.Second), + expectedReadTimeout: time.Duration(0 * time.Second), + expectedWriteTimeout: time.Duration(0 * time.Second), }, { desc: "deprecated IdleTimeout configured", globalConfig: configuration.GlobalConfiguration{ IdleTimeout: flaeg.Duration(45 * time.Second), }, - wantIdleTimeout: time.Duration(45 * time.Second), - wantReadTimeout: time.Duration(0 * time.Second), - wantWriteTimeout: time.Duration(0 * time.Second), + expectedIdleTimeout: time.Duration(45 * time.Second), + expectedReadTimeout: time.Duration(0 * time.Second), + expectedWriteTimeout: time.Duration(0 * time.Second), }, { desc: "deprecated and new IdleTimeout configured", @@ -126,13 +125,13 @@ func TestPrepareServerTimeouts(t *testing.T) { IdleTimeout: flaeg.Duration(80 * time.Second), }, }, - wantIdleTimeout: time.Duration(45 * time.Second), - wantReadTimeout: time.Duration(0 * time.Second), - wantWriteTimeout: time.Duration(0 * time.Second), + expectedIdleTimeout: time.Duration(45 * time.Second), + expectedReadTimeout: time.Duration(0 * time.Second), + expectedWriteTimeout: time.Duration(0 * time.Second), }, } - for _, test := range tests { + for _, test := range testCases { test := test t.Run(test.desc, func(t *testing.T) { @@ -145,21 +144,13 @@ func TestPrepareServerTimeouts(t *testing.T) { } router := middlewares.NewHandlerSwitcher(mux.NewRouter()) - srv := NewServer(test.globalConfig, nil) - httpServer, _, err := srv.prepareServer(entryPointName, entryPoint, router, nil, nil) - if err != nil { - t.Fatalf("Unexpected error when preparing srv: %s", err) - } + srv := NewServer(test.globalConfig, nil, nil) + httpServer, _, err := srv.prepareServer(entryPointName, entryPoint, router, nil) + require.NoError(t, err, "Unexpected error when preparing srv") - if httpServer.IdleTimeout != test.wantIdleTimeout { - t.Errorf("Got %s as IdleTimeout, want %s", httpServer.IdleTimeout, test.wantIdleTimeout) - } - if httpServer.ReadTimeout != test.wantReadTimeout { - t.Errorf("Got %s as ReadTimeout, want %s", httpServer.ReadTimeout, test.wantReadTimeout) - } - if httpServer.WriteTimeout != test.wantWriteTimeout { - t.Errorf("Got %s as WriteTimeout, want %s", httpServer.WriteTimeout, test.wantWriteTimeout) - } + assert.Equal(t, test.expectedIdleTimeout, httpServer.IdleTimeout, "IdleTimeout") + assert.Equal(t, test.expectedReadTimeout, httpServer.ReadTimeout, "ReadTimeout") + assert.Equal(t, test.expectedWriteTimeout, httpServer.WriteTimeout, "WriteTimeout") }) } } @@ -286,7 +277,7 @@ func setupListenProvider(throttleDuration time.Duration) (server *Server, stop c ProvidersThrottleDuration: flaeg.Duration(throttleDuration), } - server = NewServer(globalConfig, nil) + server = NewServer(globalConfig, nil, nil) go server.listenProviders(stop) return server, stop, invokeStopChan @@ -302,7 +293,7 @@ func TestThrottleProviderConfigReload(t *testing.T) { }() globalConfig := configuration.GlobalConfiguration{} - server := NewServer(globalConfig, nil) + server := NewServer(globalConfig, nil, nil) go server.throttleProviderConfigReload(throttleDuration, publishConfig, providerConfig, stop) @@ -329,10 +320,7 @@ func TestThrottleProviderConfigReload(t *testing.T) { // after 50 milliseconds 5 new configs were published // with a throttle duration of 30 milliseconds this means, we should have received 2 new configs - wantPublishedConfigCount := 2 - if publishedConfigCount != wantPublishedConfigCount { - t.Errorf("%d times configs were published, want %d times", publishedConfigCount, wantPublishedConfigCount) - } + assert.Equal(t, 2, publishedConfigCount, "times configs were published") stopConsumeConfigs <- true @@ -351,7 +339,7 @@ func TestThrottleProviderConfigReload(t *testing.T) { } func TestServerMultipleFrontendRules(t *testing.T) { - cases := []struct { + testCases := []struct { expression string requestURL string expectedURL string @@ -393,7 +381,7 @@ func TestServerMultipleFrontendRules(t *testing.T) { }, } - for _, test := range cases { + for _, test := range testCases { test := test t.Run(test.expression, func(t *testing.T) { t.Parallel() @@ -420,9 +408,7 @@ func TestServerMultipleFrontendRules(t *testing.T) { server := new(Server) server.wireFrontendBackend(serverRoute, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.String() != test.expectedURL { - t.Fatalf("got URL %s, expected %s", r.URL.String(), test.expectedURL) - } + require.Equal(t, test.expectedURL, r.URL.String(), "URL") })) serverRoute.Route.GetHandler().ServeHTTP(nil, request) }) @@ -441,12 +427,14 @@ func TestServerLoadConfigHealthCheckOptions(t *testing.T) { for _, healthCheck := range healthChecks { t.Run(fmt.Sprintf("%s/hc=%t", lbMethod, healthCheck != nil), func(t *testing.T) { globalConfig := configuration.GlobalConfiguration{ - EntryPoints: configuration.EntryPoints{ - "http": &configuration.EntryPoint{ + HealthCheck: &configuration.HealthCheckConfig{Interval: flaeg.Duration(5 * time.Second)}, + } + entryPoints := map[string]EntryPoint{ + "http": { + Configuration: &configuration.EntryPoint{ ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true}, }, }, - HealthCheck: &configuration.HealthCheckConfig{Interval: flaeg.Duration(5 * time.Second)}, } dynamicConfigs := types.Configurations{ @@ -482,19 +470,16 @@ func TestServerLoadConfigHealthCheckOptions(t *testing.T) { }, } - srv := NewServer(globalConfig, nil) - if _, err := srv.loadConfig(dynamicConfigs, globalConfig); err != nil { - t.Fatalf("got error: %s", err) - } + srv := NewServer(globalConfig, nil, entryPoints) - wantNumHealthCheckBackends := 0 + _, err := srv.loadConfig(dynamicConfigs, globalConfig) + require.NoError(t, err) + + expectedNumHealthCheckBackends := 0 if healthCheck != nil { - wantNumHealthCheckBackends = 1 - } - gotNumHealthCheckBackends := len(healthcheck.GetHealthCheck(testhelpers.NewCollectingHealthCheckMetrics()).Backends) - if gotNumHealthCheckBackends != wantNumHealthCheckBackends { - t.Errorf("got %d health check backends, want %d", gotNumHealthCheckBackends, wantNumHealthCheckBackends) + expectedNumHealthCheckBackends = 1 } + assert.Len(t, healthcheck.GetHealthCheck(testhelpers.NewCollectingHealthCheckMetrics()).Backends, expectedNumHealthCheckBackends, "health check backends") }) } } @@ -504,22 +489,22 @@ func TestServerParseHealthCheckOptions(t *testing.T) { lb := &testLoadBalancer{} globalInterval := 15 * time.Second - tests := []struct { - desc string - hc *types.HealthCheck - wantOpts *healthcheck.Options + testCases := []struct { + desc string + hc *types.HealthCheck + expectedOpts *healthcheck.Options }{ { - desc: "nil health check", - hc: nil, - wantOpts: nil, + desc: "nil health check", + hc: nil, + expectedOpts: nil, }, { desc: "empty path", hc: &types.HealthCheck{ Path: "", }, - wantOpts: nil, + expectedOpts: nil, }, { desc: "unparseable interval", @@ -527,7 +512,7 @@ func TestServerParseHealthCheckOptions(t *testing.T) { Path: "/path", Interval: "unparseable", }, - wantOpts: &healthcheck.Options{ + expectedOpts: &healthcheck.Options{ Path: "/path", Interval: globalInterval, LB: lb, @@ -539,7 +524,7 @@ func TestServerParseHealthCheckOptions(t *testing.T) { Path: "/path", Interval: "-42s", }, - wantOpts: &healthcheck.Options{ + expectedOpts: &healthcheck.Options{ Path: "/path", Interval: globalInterval, LB: lb, @@ -551,7 +536,7 @@ func TestServerParseHealthCheckOptions(t *testing.T) { Path: "/path", Interval: "5m", }, - wantOpts: &healthcheck.Options{ + expectedOpts: &healthcheck.Options{ Path: "/path", Interval: 5 * time.Minute, LB: lb, @@ -559,15 +544,13 @@ func TestServerParseHealthCheckOptions(t *testing.T) { }, } - for _, test := range tests { + for _, test := range testCases { test := test t.Run(test.desc, func(t *testing.T) { t.Parallel() - gotOpts := parseHealthCheckOptions(lb, "backend", test.hc, &configuration.HealthCheckConfig{Interval: flaeg.Duration(globalInterval)}) - if !reflect.DeepEqual(gotOpts, test.wantOpts) { - t.Errorf("got health check options %+v, want %+v", gotOpts, test.wantOpts) - } + opts := parseHealthCheckOptions(lb, "backend", test.hc, &configuration.HealthCheckConfig{Interval: flaeg.Duration(globalInterval)}) + assert.Equal(t, test.expectedOpts, opts, "health check options") }) } } @@ -681,20 +664,19 @@ func TestServerLoadConfigEmptyBasicAuth(t *testing.T) { }, } - srv := NewServer(globalConfig, nil) - if _, err := srv.loadConfig(dynamicConfigs, globalConfig); err != nil { - t.Fatalf("got error: %s", err) - } + srv := NewServer(globalConfig, nil, nil) + _, err := srv.loadConfig(dynamicConfigs, globalConfig) + require.NoError(t, err) } func TestServerLoadCertificateWithDefaultEntryPoint(t *testing.T) { globalConfig := configuration.GlobalConfiguration{ - EntryPoints: configuration.EntryPoints{ - "https": &configuration.EntryPoint{TLS: &tls.TLS{}}, - "http": &configuration.EntryPoint{}, - }, DefaultEntryPoints: []string{"http", "https"}, } + entryPoints := map[string]EntryPoint{ + "https": {Configuration: &configuration.EntryPoint{TLS: &tls.TLS{}}}, + "http": {Configuration: &configuration.EntryPoint{}}, + } dynamicConfigs := types.Configurations{ "config": &types.Configuration{ @@ -709,7 +691,7 @@ func TestServerLoadCertificateWithDefaultEntryPoint(t *testing.T) { }, } - srv := NewServer(globalConfig, nil) + srv := NewServer(globalConfig, nil, entryPoints) if mapEntryPoints, err := srv.loadConfig(dynamicConfigs, globalConfig); err != nil { t.Fatalf("got error: %s", err) } else if mapEntryPoints["https"].certs.Get() == nil { @@ -721,11 +703,11 @@ func TestConfigureBackends(t *testing.T) { validMethod := "Drr" defaultMethod := "wrr" - tests := []struct { - desc string - lb *types.LoadBalancer - wantMethod string - wantStickiness *types.Stickiness + testCases := []struct { + desc string + lb *types.LoadBalancer + expectedMethod string + expectedStickiness *types.Stickiness }{ { desc: "valid load balancer method with sticky enabled", @@ -733,8 +715,8 @@ func TestConfigureBackends(t *testing.T) { Method: validMethod, Stickiness: &types.Stickiness{}, }, - wantMethod: validMethod, - wantStickiness: &types.Stickiness{}, + expectedMethod: validMethod, + expectedStickiness: &types.Stickiness{}, }, { desc: "valid load balancer method with sticky disabled", @@ -742,7 +724,7 @@ func TestConfigureBackends(t *testing.T) { Method: validMethod, Stickiness: nil, }, - wantMethod: validMethod, + expectedMethod: validMethod, }, { desc: "invalid load balancer method with sticky enabled", @@ -750,8 +732,8 @@ func TestConfigureBackends(t *testing.T) { Method: "Invalid", Stickiness: &types.Stickiness{}, }, - wantMethod: defaultMethod, - wantStickiness: &types.Stickiness{}, + expectedMethod: defaultMethod, + expectedStickiness: &types.Stickiness{}, }, { desc: "invalid load balancer method with sticky disabled", @@ -759,16 +741,16 @@ func TestConfigureBackends(t *testing.T) { Method: "Invalid", Stickiness: nil, }, - wantMethod: defaultMethod, + expectedMethod: defaultMethod, }, { - desc: "missing load balancer", - lb: nil, - wantMethod: defaultMethod, + desc: "missing load balancer", + lb: nil, + expectedMethod: defaultMethod, }, } - for _, test := range tests { + for _, test := range testCases { test := test t.Run(test.desc, func(t *testing.T) { t.Parallel() @@ -780,22 +762,21 @@ func TestConfigureBackends(t *testing.T) { "backend": backend, }) - wantLB := types.LoadBalancer{ - Method: test.wantMethod, - Stickiness: test.wantStickiness, - } - if !reflect.DeepEqual(*backend.LoadBalancer, wantLB) { - t.Errorf("got backend load-balancer\n%v\nwant\n%v\n", spew.Sdump(backend.LoadBalancer), spew.Sdump(wantLB)) + expected := types.LoadBalancer{ + Method: test.expectedMethod, + Stickiness: test.expectedStickiness, } + + assert.Equal(t, expected, *backend.LoadBalancer) }) } } func TestServerEntryPointWhitelistConfig(t *testing.T) { - tests := []struct { - desc string - entrypoint *configuration.EntryPoint - wantMiddleware bool + testCases := []struct { + desc string + entrypoint *configuration.EntryPoint + expectMiddleware bool }{ { desc: "no whitelist middleware if no config on entrypoint", @@ -803,7 +784,7 @@ func TestServerEntryPointWhitelistConfig(t *testing.T) { Address: ":0", ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true}, }, - wantMiddleware: false, + expectMiddleware: false, }, { desc: "whitelist middleware should be added if configured on entrypoint", @@ -814,27 +795,29 @@ func TestServerEntryPointWhitelistConfig(t *testing.T) { }, ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true}, }, - wantMiddleware: true, + expectMiddleware: true, }, } - for _, test := range tests { + for _, test := range testCases { test := test t.Run(test.desc, func(t *testing.T) { t.Parallel() srv := Server{ - globalConfiguration: configuration.GlobalConfiguration{ - EntryPoints: map[string]*configuration.EntryPoint{ - "test": test.entrypoint, + globalConfiguration: configuration.GlobalConfiguration{}, + metricsRegistry: metrics.NewVoidRegistry(), + entryPoints: map[string]EntryPoint{ + "test": { + Configuration: test.entrypoint, }, }, - metricsRegistry: metrics.NewVoidRegistry(), } - srv.serverEntryPoints = srv.buildEntryPoints(srv.globalConfiguration) + srv.serverEntryPoints = srv.buildEntryPoints() srvEntryPoint := srv.setupServerEntryPoint("test", srv.serverEntryPoints["test"]) handler := srvEntryPoint.httpServer.Handler.(*mux.Router).NotFoundHandler.(*negroni.Negroni) + found := false for _, handler := range handler.Handlers() { if reflect.TypeOf(handler) == reflect.TypeOf((*middlewares.IPWhiteLister)(nil)) { @@ -842,12 +825,12 @@ func TestServerEntryPointWhitelistConfig(t *testing.T) { } } - if found && !test.wantMiddleware { - t.Errorf("ip whitelist middleware was installed even though it should not") + if found && !test.expectMiddleware { + t.Error("ip whitelist middleware was installed even though it should not") } - if !found && test.wantMiddleware { - t.Errorf("ip whitelist middleware was not installed even though it should have") + if !found && test.expectMiddleware { + t.Error("ip whitelist middleware was not installed even though it should have") } }) } @@ -858,9 +841,9 @@ func TestServerResponseEmptyBackend(t *testing.T) { const routeRule = "Path:" + requestPath testCases := []struct { - desc string - dynamicConfig func(testServerURL string) *types.Configuration - wantStatusCode int + desc string + dynamicConfig func(testServerURL string) *types.Configuration + expectedStatusCode int }{ { desc: "Ok", @@ -870,14 +853,14 @@ func TestServerResponseEmptyBackend(t *testing.T) { withBackend("backend", buildBackend(withServer("testServer", testServerURL))), ) }, - wantStatusCode: http.StatusOK, + expectedStatusCode: http.StatusOK, }, { desc: "No Frontend", dynamicConfig: func(testServerURL string) *types.Configuration { return buildDynamicConfig() }, - wantStatusCode: http.StatusNotFound, + expectedStatusCode: http.StatusNotFound, }, { desc: "Empty Backend LB-Drr", @@ -887,7 +870,7 @@ func TestServerResponseEmptyBackend(t *testing.T) { withBackend("backend", buildBackend(withLoadBalancer("Drr", false))), ) }, - wantStatusCode: http.StatusServiceUnavailable, + expectedStatusCode: http.StatusServiceUnavailable, }, { desc: "Empty Backend LB-Drr Sticky", @@ -897,7 +880,7 @@ func TestServerResponseEmptyBackend(t *testing.T) { withBackend("backend", buildBackend(withLoadBalancer("Drr", true))), ) }, - wantStatusCode: http.StatusServiceUnavailable, + expectedStatusCode: http.StatusServiceUnavailable, }, { desc: "Empty Backend LB-Wrr", @@ -907,7 +890,7 @@ func TestServerResponseEmptyBackend(t *testing.T) { withBackend("backend", buildBackend(withLoadBalancer("Wrr", false))), ) }, - wantStatusCode: http.StatusServiceUnavailable, + expectedStatusCode: http.StatusServiceUnavailable, }, { desc: "Empty Backend LB-Wrr Sticky", @@ -917,7 +900,7 @@ func TestServerResponseEmptyBackend(t *testing.T) { withBackend("backend", buildBackend(withLoadBalancer("Wrr", true))), ) }, - wantStatusCode: http.StatusServiceUnavailable, + expectedStatusCode: http.StatusServiceUnavailable, }, } @@ -932,14 +915,13 @@ func TestServerResponseEmptyBackend(t *testing.T) { })) defer testServer.Close() - globalConfig := configuration.GlobalConfiguration{ - EntryPoints: configuration.EntryPoints{ - "http": &configuration.EntryPoint{ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true}}, - }, + globalConfig := configuration.GlobalConfiguration{} + entryPointsConfig := map[string]EntryPoint{ + "http": {Configuration: &configuration.EntryPoint{ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true}}}, } dynamicConfigs := types.Configurations{"config": test.dynamicConfig(testServer.URL)} - srv := NewServer(globalConfig, nil) + srv := NewServer(globalConfig, nil, entryPointsConfig) entryPoints, err := srv.loadConfig(dynamicConfigs, globalConfig) if err != nil { t.Fatalf("error loading config: %s", err) @@ -950,20 +932,17 @@ func TestServerResponseEmptyBackend(t *testing.T) { entryPoints["http"].httpRouter.ServeHTTP(responseRecorder, request) - if responseRecorder.Result().StatusCode != test.wantStatusCode { - t.Errorf("got status code %d, want %d", responseRecorder.Result().StatusCode, test.wantStatusCode) - } + assert.Equal(t, test.expectedStatusCode, responseRecorder.Result().StatusCode, "status code") }) } } func TestBuildRedirectHandler(t *testing.T) { srv := Server{ - globalConfiguration: configuration.GlobalConfiguration{ - EntryPoints: configuration.EntryPoints{ - "http": &configuration.EntryPoint{Address: ":80"}, - "https": &configuration.EntryPoint{Address: ":443", TLS: &tls.TLS{}}, - }, + globalConfiguration: configuration.GlobalConfiguration{}, + entryPoints: map[string]EntryPoint{ + "http": {Configuration: &configuration.EntryPoint{Address: ":80"}}, + "https": {Configuration: &configuration.EntryPoint{Address: ":443", TLS: &tls.TLS{}}}, }, } @@ -1073,7 +1052,7 @@ func (c mockContext) Value(key interface{}) interface{} { } func TestNewServerWithResponseModifiers(t *testing.T) { - cases := []struct { + testCases := []struct { desc string headerMiddleware *middlewares.HeaderStruct secureMiddleware *secure.Secure @@ -1138,7 +1117,7 @@ func TestNewServerWithResponseModifiers(t *testing.T) { }, } - for _, test := range cases { + for _, test := range testCases { test := test t.Run(test.desc, func(t *testing.T) { t.Parallel() diff --git a/types/internal_router.go b/types/internal_router.go new file mode 100644 index 000000000..62bc1f51c --- /dev/null +++ b/types/internal_router.go @@ -0,0 +1,10 @@ +package types + +import ( + "github.com/containous/mux" +) + +// InternalRouter router used by server to register internal routes (/api, /ping ...) +type InternalRouter interface { + AddRoutes(systemRouter *mux.Router) +}