diff --git a/healthcheck/healthcheck.go b/healthcheck/healthcheck.go index f87e4a2ca..f24ba8766 100644 --- a/healthcheck/healthcheck.go +++ b/healthcheck/healthcheck.go @@ -19,12 +19,18 @@ import ( var singleton *HealthCheck var once sync.Once -// GetHealthCheck returns the health check which is guaranteed to be a singleton. -func GetHealthCheck(metrics metricsRegistry) *HealthCheck { - once.Do(func() { - singleton = newHealthCheck(metrics) - }) - return singleton +// BalancerHandler includes functionality for load-balancing management. +type BalancerHandler interface { + ServeHTTP(w http.ResponseWriter, req *http.Request) + Servers() []*url.URL + RemoveServer(u *url.URL) error + UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error +} + +// metricsRegistry is a local interface in the health check package, exposing only the required metrics +// necessary for the health check package. This makes it easier for the tests. +type metricsRegistry interface { + BackendServerUpGauge() metrics.Gauge } // Options are the public health check options. @@ -36,59 +42,59 @@ type Options struct { Port int Transport http.RoundTripper Interval time.Duration - LB LoadBalancer + LB BalancerHandler } func (opt Options) String() string { return fmt.Sprintf("[Hostname: %s Headers: %v Path: %s Port: %d Interval: %s]", opt.Hostname, opt.Headers, opt.Path, opt.Port, opt.Interval) } -// BackendHealthCheck HealthCheck configuration for a backend -type BackendHealthCheck struct { +// BackendConfig HealthCheck configuration for a backend +type BackendConfig struct { Options name string disabledURLs []*url.URL requestTimeout time.Duration } +func (b *BackendConfig) newRequest(serverURL *url.URL) (*http.Request, error) { + u := &url.URL{} + *u = *serverURL + + if len(b.Scheme) > 0 { + u.Scheme = b.Scheme + } + + if b.Port != 0 { + u.Host = net.JoinHostPort(u.Hostname(), strconv.Itoa(b.Port)) + } + + u.Path += b.Path + + return http.NewRequest(http.MethodGet, u.String(), nil) +} + +// this function adds additional http headers and hostname to http.request +func (b *BackendConfig) addHeadersAndHost(req *http.Request) *http.Request { + if b.Options.Hostname != "" { + req.Host = b.Options.Hostname + } + + for k, v := range b.Options.Headers { + req.Header.Set(k, v) + } + return req +} + // HealthCheck struct type HealthCheck struct { - Backends map[string]*BackendHealthCheck + Backends map[string]*BackendConfig metrics metricsRegistry cancel context.CancelFunc } -// LoadBalancer includes functionality for load-balancing management. -type LoadBalancer interface { - RemoveServer(u *url.URL) error - UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error - Servers() []*url.URL -} - -func newHealthCheck(metrics metricsRegistry) *HealthCheck { - return &HealthCheck{ - Backends: make(map[string]*BackendHealthCheck), - metrics: metrics, - } -} - -// metricsRegistry is a local interface in the health check package, exposing only the required metrics -// necessary for the health check package. This makes it easier for the tests. -type metricsRegistry interface { - BackendServerUpGauge() metrics.Gauge -} - -// NewBackendHealthCheck Instantiate a new BackendHealthCheck -func NewBackendHealthCheck(options Options, backendName string) *BackendHealthCheck { - return &BackendHealthCheck{ - Options: options, - name: backendName, - requestTimeout: 5 * time.Second, - } -} - // SetBackendsConfiguration set backends configuration -func (hc *HealthCheck) SetBackendsConfiguration(parentCtx context.Context, backends map[string]*BackendHealthCheck) { +func (hc *HealthCheck) SetBackendsConfiguration(parentCtx context.Context, backends map[string]*BackendConfig) { hc.Backends = backends if hc.cancel != nil { hc.cancel() @@ -104,7 +110,7 @@ func (hc *HealthCheck) SetBackendsConfiguration(parentCtx context.Context, backe } } -func (hc *HealthCheck) execute(ctx context.Context, backend *BackendHealthCheck) { +func (hc *HealthCheck) execute(ctx context.Context, backend *BackendConfig) { log.Debugf("Initial health check for backend: %q", backend.name) hc.checkBackend(backend) ticker := time.NewTicker(backend.Interval) @@ -121,7 +127,7 @@ func (hc *HealthCheck) execute(ctx context.Context, backend *BackendHealthCheck) } } -func (hc *HealthCheck) checkBackend(backend *BackendHealthCheck) { +func (hc *HealthCheck) checkBackend(backend *BackendConfig) { enabledURLs := backend.LB.Servers() var newDisabledURLs []*url.URL for _, url := range backend.disabledURLs { @@ -152,38 +158,33 @@ func (hc *HealthCheck) checkBackend(backend *BackendHealthCheck) { } } -func (b *BackendHealthCheck) newRequest(serverURL *url.URL) (*http.Request, error) { - u := &url.URL{} - *u = *serverURL - - if len(b.Scheme) > 0 { - u.Scheme = b.Scheme - } - - if b.Port != 0 { - u.Host = net.JoinHostPort(u.Hostname(), strconv.Itoa(b.Port)) - } - - u.Path += b.Path - - return http.NewRequest(http.MethodGet, u.String(), nil) +// GetHealthCheck returns the health check which is guaranteed to be a singleton. +func GetHealthCheck(metrics metricsRegistry) *HealthCheck { + once.Do(func() { + singleton = newHealthCheck(metrics) + }) + return singleton } -// this function adds additional http headers and hostname to http.request -func (b *BackendHealthCheck) addHeadersAndHost(req *http.Request) *http.Request { - if b.Options.Hostname != "" { - req.Host = b.Options.Hostname +func newHealthCheck(metrics metricsRegistry) *HealthCheck { + return &HealthCheck{ + Backends: make(map[string]*BackendConfig), + metrics: metrics, } +} - for k, v := range b.Options.Headers { - req.Header.Set(k, v) +// NewBackendConfig Instantiate a new BackendConfig +func NewBackendConfig(options Options, backendName string) *BackendConfig { + return &BackendConfig{ + Options: options, + name: backendName, + requestTimeout: 5 * time.Second, } - return req } // checkHealth returns a nil error in case it was successful and otherwise // a non-nil error with a meaningful description why the health check failed. -func checkHealth(serverURL *url.URL, backend *BackendHealthCheck) error { +func checkHealth(serverURL *url.URL, backend *BackendConfig) error { req, err := backend.newRequest(serverURL) if err != nil { return fmt.Errorf("failed to create HTTP request: %s", err) diff --git a/healthcheck/healthcheck_test.go b/healthcheck/healthcheck_test.go index 004ba0c7b..131070e6d 100644 --- a/healthcheck/healthcheck_test.go +++ b/healthcheck/healthcheck_test.go @@ -102,7 +102,7 @@ func TestSetBackendsConfiguration(t *testing.T) { defer ts.Close() lb := &testLoadBalancer{RWMutex: &sync.RWMutex{}} - backend := NewBackendHealthCheck(Options{ + backend := NewBackendConfig(Options{ Path: "/path", Interval: healthCheckInterval, LB: lb, @@ -117,7 +117,7 @@ func TestSetBackendsConfiguration(t *testing.T) { collectingMetrics := testhelpers.NewCollectingHealthCheckMetrics() check := HealthCheck{ - Backends: make(map[string]*BackendHealthCheck), + Backends: make(map[string]*BackendConfig), metrics: collectingMetrics, } @@ -209,7 +209,7 @@ func TestNewRequest(t *testing.T) { t.Run(test.desc, func(t *testing.T) { t.Parallel() - backend := NewBackendHealthCheck(test.options, "backendName") + backend := NewBackendConfig(test.options, "backendName") u, err := url.Parse(test.serverURL) require.NoError(t, err) @@ -279,7 +279,7 @@ func TestAddHeadersAndHost(t *testing.T) { t.Run(test.desc, func(t *testing.T) { t.Parallel() - backend := NewBackendHealthCheck(test.options, "backendName") + backend := NewBackendConfig(test.options, "backendName") u, err := url.Parse(test.serverURL) require.NoError(t, err) @@ -305,6 +305,10 @@ type testLoadBalancer struct { servers []*url.URL } +func (lb *testLoadBalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // noop +} + func (lb *testLoadBalancer) RemoveServer(u *url.URL) error { lb.Lock() defer lb.Unlock() diff --git a/integration/access_log_test.go b/integration/access_log_test.go index 9e950e526..eaaca7ede 100644 --- a/integration/access_log_test.go +++ b/integration/access_log_test.go @@ -102,7 +102,7 @@ func (s *AccessLogSuite) TestAccessLogAuthFrontend(c *check.C) { formatOnly: false, code: "401", user: "-", - frontendName: "Auth for frontend-Host-frontend-auth-docker-local", + frontendName: "Basic Auth for frontend-Host-frontend-auth-docker-local", backendURL: "/", }, } @@ -354,7 +354,7 @@ func (s *AccessLogSuite) TestAccessLogEntrypointRedirect(c *check.C) { formatOnly: false, code: "302", user: "-", - frontendName: "entrypoint redirect for frontend-", + frontendName: "entrypoint redirect for httpRedirect", backendURL: "/", }, { diff --git a/middlewares/auth/authenticator.go b/middlewares/auth/authenticator.go index f627f53fe..20e29635c 100644 --- a/middlewares/auth/authenticator.go +++ b/middlewares/auth/authenticator.go @@ -31,38 +31,43 @@ func NewAuthenticator(authConfig *types.Auth, tracingMiddleware *tracing.Tracing if authConfig == nil { return nil, fmt.Errorf("error creating Authenticator: auth is nil") } + var err error - authenticator := Authenticator{} - tracingAuthenticator := tracingAuthenticator{} + authenticator := &Authenticator{} + tracingAuth := tracingAuthenticator{} + if authConfig.Basic != nil { authenticator.users, err = parserBasicUsers(authConfig.Basic) if err != nil { return nil, err } + basicAuth := goauth.NewBasicAuthenticator("traefik", authenticator.secretBasic) - tracingAuthenticator.handler = createAuthBasicHandler(basicAuth, authConfig) - tracingAuthenticator.name = "Auth Basic" - tracingAuthenticator.clientSpanKind = false + tracingAuth.handler = createAuthBasicHandler(basicAuth, authConfig) + tracingAuth.name = "Auth Basic" + tracingAuth.clientSpanKind = false } else if authConfig.Digest != nil { authenticator.users, err = parserDigestUsers(authConfig.Digest) if err != nil { return nil, err } + digestAuth := goauth.NewDigestAuthenticator("traefik", authenticator.secretDigest) - tracingAuthenticator.handler = createAuthDigestHandler(digestAuth, authConfig) - tracingAuthenticator.name = "Auth Digest" - tracingAuthenticator.clientSpanKind = false + tracingAuth.handler = createAuthDigestHandler(digestAuth, authConfig) + tracingAuth.name = "Auth Digest" + tracingAuth.clientSpanKind = false } else if authConfig.Forward != nil { - tracingAuthenticator.handler = createAuthForwardHandler(authConfig) - tracingAuthenticator.name = "Auth Forward" - tracingAuthenticator.clientSpanKind = true + tracingAuth.handler = createAuthForwardHandler(authConfig) + tracingAuth.name = "Auth Forward" + tracingAuth.clientSpanKind = true } + if tracingMiddleware != nil { - authenticator.handler = tracingMiddleware.NewNegroniHandlerWrapper(tracingAuthenticator.name, tracingAuthenticator.handler, tracingAuthenticator.clientSpanKind) + authenticator.handler = tracingMiddleware.NewNegroniHandlerWrapper(tracingAuth.name, tracingAuth.handler, tracingAuth.clientSpanKind) } else { - authenticator.handler = tracingAuthenticator.handler + authenticator.handler = tracingAuth.handler } - return &authenticator, nil + return authenticator, nil } func createAuthForwardHandler(authConfig *types.Auth) negroni.HandlerFunc { diff --git a/middlewares/cbreaker.go b/middlewares/cbreaker.go index 178dae855..e06783456 100644 --- a/middlewares/cbreaker.go +++ b/middlewares/cbreaker.go @@ -23,14 +23,14 @@ func NewCircuitBreaker(next http.Handler, expression string, options ...cbreaker // NewCircuitBreakerOptions returns a new CircuitBreakerOption func NewCircuitBreakerOptions(expression string) cbreaker.CircuitBreakerOption { - return cbreaker.Fallback( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - tracing.LogEventf(r, "blocked by circuitbreaker (%q)", expression) - w.WriteHeader(http.StatusServiceUnavailable) - w.Write([]byte(http.StatusText(http.StatusServiceUnavailable))) - })) + return cbreaker.Fallback(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tracing.LogEventf(r, "blocked by circuit-breaker (%q)", expression) + + w.WriteHeader(http.StatusServiceUnavailable) + w.Write([]byte(http.StatusText(http.StatusServiceUnavailable))) + })) } -func (cb *CircuitBreaker) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) { +func (cb *CircuitBreaker) ServeHTTP(rw http.ResponseWriter, r *http.Request) { cb.circuitBreaker.ServeHTTP(rw, r) } diff --git a/middlewares/empty_backend_handler.go b/middlewares/empty_backend_handler.go index dfdd216e3..f775b4663 100644 --- a/middlewares/empty_backend_handler.go +++ b/middlewares/empty_backend_handler.go @@ -10,19 +10,18 @@ import ( // has at least one active Server in respect to the healthchecks and if this // is not the case, it will stop the middleware chain and respond with 503. type EmptyBackendHandler struct { - lb healthcheck.LoadBalancer - next http.Handler + next healthcheck.BalancerHandler } // NewEmptyBackendHandler creates a new EmptyBackendHandler instance. -func NewEmptyBackendHandler(lb healthcheck.LoadBalancer, next http.Handler) *EmptyBackendHandler { - return &EmptyBackendHandler{lb: lb, next: next} +func NewEmptyBackendHandler(lb healthcheck.BalancerHandler) *EmptyBackendHandler { + return &EmptyBackendHandler{next: lb} } // ServeHTTP responds with 503 when there is no active Server and otherwise // invokes the next handler in the middleware chain. func (h *EmptyBackendHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { - if len(h.lb.Servers()) == 0 { + if len(h.next.Servers()) == 0 { rw.WriteHeader(http.StatusServiceUnavailable) rw.Write([]byte(http.StatusText(http.StatusServiceUnavailable))) } else { diff --git a/middlewares/empty_backend_handler_test.go b/middlewares/empty_backend_handler_test.go index b77232d8c..9e2e36c38 100644 --- a/middlewares/empty_backend_handler_test.go +++ b/middlewares/empty_backend_handler_test.go @@ -32,10 +32,7 @@ func TestEmptyBackendHandler(t *testing.T) { t.Run(fmt.Sprintf("amount servers %d", test.amountServer), func(t *testing.T) { t.Parallel() - nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - }) - handler := NewEmptyBackendHandler(&healthCheckLoadBalancer{test.amountServer}, nextHandler) + handler := NewEmptyBackendHandler(&healthCheckLoadBalancer{test.amountServer}) recorder := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) @@ -53,12 +50,8 @@ type healthCheckLoadBalancer struct { amountServer int } -func (lb *healthCheckLoadBalancer) RemoveServer(u *url.URL) error { - return nil -} - -func (lb *healthCheckLoadBalancer) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error { - return nil +func (lb *healthCheckLoadBalancer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) } func (lb *healthCheckLoadBalancer) Servers() []*url.URL { @@ -68,3 +61,23 @@ func (lb *healthCheckLoadBalancer) Servers() []*url.URL { } return servers } + +func (lb *healthCheckLoadBalancer) RemoveServer(u *url.URL) error { + return nil +} + +func (lb *healthCheckLoadBalancer) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error { + return nil +} + +func (lb *healthCheckLoadBalancer) ServerWeight(u *url.URL) (int, bool) { + return 0, false +} + +func (lb *healthCheckLoadBalancer) NextServer() (*url.URL, error) { + return nil, nil +} + +func (lb *healthCheckLoadBalancer) Next() http.Handler { + return nil +} diff --git a/middlewares/tracing/wrapper.go b/middlewares/tracing/wrapper.go index 11f3d6a5f..8e9c566c1 100644 --- a/middlewares/tracing/wrapper.go +++ b/middlewares/tracing/wrapper.go @@ -6,20 +6,6 @@ import ( "github.com/urfave/negroni" ) -// NegroniHandlerWrapper is used to wrap negroni handler middleware -type NegroniHandlerWrapper struct { - name string - next negroni.Handler - clientSpanKind bool -} - -// HTTPHandlerWrapper is used to wrap http handler middleware -type HTTPHandlerWrapper struct { - name string - handler http.Handler - clientSpanKind bool -} - // NewNegroniHandlerWrapper return a negroni.Handler struct func (t *Tracing) NewNegroniHandlerWrapper(name string, handler negroni.Handler, clientSpanKind bool) negroni.Handler { if t.IsEnabled() && handler != nil { @@ -44,6 +30,13 @@ func (t *Tracing) NewHTTPHandlerWrapper(name string, handler http.Handler, clien return handler } +// NegroniHandlerWrapper is used to wrap negroni handler middleware +type NegroniHandlerWrapper struct { + name string + next negroni.Handler + clientSpanKind bool +} + func (t *NegroniHandlerWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) { var finish func() _, r, finish = StartSpan(r, t.name, t.clientSpanKind) @@ -54,6 +47,13 @@ func (t *NegroniHandlerWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Reques } } +// HTTPHandlerWrapper is used to wrap http handler middleware +type HTTPHandlerWrapper struct { + name string + handler http.Handler + clientSpanKind bool +} + func (t *HTTPHandlerWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request) { var finish func() _, r, finish = StartSpan(r, t.name, t.clientSpanKind) diff --git a/server/adapters.go b/server/adapters.go deleted file mode 100644 index 453b8187c..000000000 --- a/server/adapters.go +++ /dev/null @@ -1,9 +0,0 @@ -package server - -import ( - "net/http" -) - -func notFoundHandler(w http.ResponseWriter, r *http.Request) { - http.NotFound(w, r) -} diff --git a/server/bufferpool.go b/server/bufferpool.go index 157ea2ad7..6cd194830 100644 --- a/server/bufferpool.go +++ b/server/bufferpool.go @@ -2,7 +2,7 @@ package server import "sync" -const bufferPoolSize int = 32 * 1024 +const bufferPoolSize = 32 * 1024 func newBufferPool() *bufferPool { return &bufferPool{ diff --git a/server/server.go b/server/server.go index e80c33d1f..340ee650d 100644 --- a/server/server.go +++ b/server/server.go @@ -5,7 +5,6 @@ import ( "crypto/tls" "crypto/x509" "encoding/json" - "errors" "fmt" "io/ioutil" stdlog "log" @@ -16,7 +15,6 @@ import ( "os" "os/signal" "reflect" - "sort" "strings" "sync" "time" @@ -27,34 +25,18 @@ import ( "github.com/containous/traefik/configuration" "github.com/containous/traefik/configuration/router" "github.com/containous/traefik/h2c" - "github.com/containous/traefik/healthcheck" "github.com/containous/traefik/log" "github.com/containous/traefik/metrics" "github.com/containous/traefik/middlewares" "github.com/containous/traefik/middlewares/accesslog" - mauth "github.com/containous/traefik/middlewares/auth" - "github.com/containous/traefik/middlewares/errorpages" - "github.com/containous/traefik/middlewares/redirect" "github.com/containous/traefik/middlewares/tracing" "github.com/containous/traefik/provider" - "github.com/containous/traefik/rules" "github.com/containous/traefik/safe" - "github.com/containous/traefik/server/cookie" traefiktls "github.com/containous/traefik/tls" "github.com/containous/traefik/types" "github.com/containous/traefik/whitelist" - "github.com/eapache/channels" "github.com/sirupsen/logrus" - thoas_stats "github.com/thoas/stats" - "github.com/unrolled/secure" "github.com/urfave/negroni" - "github.com/vulcand/oxy/buffer" - "github.com/vulcand/oxy/connlimit" - "github.com/vulcand/oxy/forward" - "github.com/vulcand/oxy/ratelimit" - "github.com/vulcand/oxy/roundrobin" - "github.com/vulcand/oxy/utils" - "golang.org/x/net/http2" ) var httpServerLogger = stdlog.New(log.WriterLevel(logrus.DebugLevel), "", 0) @@ -101,10 +83,11 @@ type serverEntryPoint struct { // NewServer returns an initialized Server. func NewServer(globalConfiguration configuration.GlobalConfiguration, provider provider.Provider, entrypoints map[string]EntryPoint) *Server { - server := new(Server) + server := &Server{} server.entryPoints = entrypoints server.provider = provider + server.globalConfiguration = globalConfiguration server.serverEntryPoints = make(map[string]*serverEntryPoint) server.configurationChan = make(chan types.ConfigMessage, 100) server.configurationValidatedChan = make(chan types.ConfigMessage, 100) @@ -114,7 +97,7 @@ func NewServer(globalConfiguration configuration.GlobalConfiguration, provider p currentConfigurations := make(types.Configurations) server.currentConfigurations.Set(currentConfigurations) server.providerConfigUpdateMap = make(map[string]chan types.ConfigMessage) - server.globalConfiguration = globalConfiguration + if server.globalConfiguration.API != nil { server.globalConfiguration.API.CurrentConfigurations = &server.currentConfigurations } @@ -122,10 +105,16 @@ func NewServer(globalConfiguration configuration.GlobalConfiguration, provider p server.bufferPool = newBufferPool() server.routinesPool = safe.NewPool(context.Background()) - server.defaultForwardingRoundTripper = createHTTPTransport(globalConfiguration) + + transport, err := createHTTPTransport(globalConfiguration) + if err != nil { + log.Errorf("failed to create HTTP transport: %v", err) + } + + server.defaultForwardingRoundTripper = transport server.tracingMiddleware = globalConfiguration.Tracing - if globalConfiguration.Tracing != nil && globalConfiguration.Tracing.Backend != "" { + if server.tracingMiddleware != nil && server.tracingMiddleware.Backend != "" { server.tracingMiddleware.Setup() } @@ -150,79 +139,6 @@ func NewServer(globalConfiguration configuration.GlobalConfiguration, provider p return server } -type h2cTransportWrapper struct { - *http2.Transport -} - -func (t *h2cTransportWrapper) RoundTrip(req *http.Request) (*http.Response, error) { - req.URL.Scheme = "http" - return t.Transport.RoundTrip(req) -} - -// createHTTPTransport creates an http.Transport configured with the GlobalConfiguration settings. -// For the settings that can't be configured in Traefik it uses the default http.Transport settings. -// An exception to this is the MaxIdleConns setting as we only provide the option MaxIdleConnsPerHost -// in Traefik at this point in time. Setting this value to the default of 100 could lead to confusing -// behaviour and backwards compatibility issues. -func createHTTPTransport(globalConfiguration configuration.GlobalConfiguration) *http.Transport { - dialer := &net.Dialer{ - Timeout: configuration.DefaultDialTimeout, - KeepAlive: 30 * time.Second, - DualStack: true, - } - if globalConfiguration.ForwardingTimeouts != nil { - dialer.Timeout = time.Duration(globalConfiguration.ForwardingTimeouts.DialTimeout) - } - - transport := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: dialer.DialContext, - MaxIdleConnsPerHost: globalConfiguration.MaxIdleConnsPerHost, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - } - - transport.RegisterProtocol("h2c", &h2cTransportWrapper{ - Transport: &http2.Transport{ - DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) { - return net.Dial(netw, addr) - }, - AllowHTTP: true, - }, - }) - - if globalConfiguration.ForwardingTimeouts != nil { - transport.ResponseHeaderTimeout = time.Duration(globalConfiguration.ForwardingTimeouts.ResponseHeaderTimeout) - } - if globalConfiguration.InsecureSkipVerify { - transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - } - if len(globalConfiguration.RootCAs) > 0 { - transport.TLSClientConfig = &tls.Config{ - RootCAs: createRootCACertPool(globalConfiguration.RootCAs), - } - } - http2.ConfigureTransport(transport) - - return transport -} - -func createRootCACertPool(rootCAs traefiktls.RootCAs) *x509.CertPool { - roots := x509.NewCertPool() - - for _, cert := range rootCAs { - certContent, err := cert.Read() - if err != nil { - log.Error("Error while read RootCAs", err) - continue - } - roots.AppendCertsFromPEM(certContent) - } - - return roots -} - // Start starts the server. func (s *Server) Start() { s.startHTTPServers() @@ -272,7 +188,10 @@ func (s *Server) Stop() { log.Debugf("Waiting %s seconds before killing connections on entrypoint %s...", graceTimeOut, serverEntryPointName) if err := serverEntryPoint.httpServer.Shutdown(ctx); err != nil { log.Debugf("Wait is over due to: %s", err) - serverEntryPoint.httpServer.Close() + err = serverEntryPoint.httpServer.Close() + if err != nil { + log.Error(err) + } } cancel() log.Debugf("Entrypoint %s closed", serverEntryPointName) @@ -322,7 +241,7 @@ func (s *Server) stopLeadership() { } func (s *Server) startHTTPServers() { - s.serverEntryPoints = s.buildEntryPoints() + s.serverEntryPoints = s.buildServerEntryPoints() for newServerEntryPointName, newServerEntryPoint := range s.serverEntryPoints { serverEntryPoint := s.setupServerEntryPoint(newServerEntryPointName, newServerEntryPoint) @@ -330,68 +249,6 @@ func (s *Server) startHTTPServers() { } } -func (s *Server) setupServerEntryPoint(newServerEntryPointName string, newServerEntryPoint *serverEntryPoint) *serverEntryPoint { - serverMiddlewares := []negroni.Handler{middlewares.NegroniRecoverHandler()} - - if s.tracingMiddleware.IsEnabled() { - serverMiddlewares = append(serverMiddlewares, s.tracingMiddleware.NewEntryPoint(newServerEntryPointName)) - } - - if s.accessLoggerMiddleware != nil { - serverMiddlewares = append(serverMiddlewares, s.accessLoggerMiddleware) - } - - if s.metricsRegistry.IsEnabled() { - serverMiddlewares = append(serverMiddlewares, middlewares.NewEntryPointMetricsMiddleware(s.metricsRegistry, newServerEntryPointName)) - } - - if s.globalConfiguration.API != nil { - if s.globalConfiguration.API.Stats == nil { - s.globalConfiguration.API.Stats = thoas_stats.New() - } - serverMiddlewares = append(serverMiddlewares, s.globalConfiguration.API.Stats) - if s.globalConfiguration.API.Statistics != nil { - if s.globalConfiguration.API.StatsRecorder == nil { - s.globalConfiguration.API.StatsRecorder = middlewares.NewStatsRecorder(s.globalConfiguration.API.Statistics.RecentErrors) - } - serverMiddlewares = append(serverMiddlewares, s.globalConfiguration.API.StatsRecorder) - } - } - - if s.entryPoints[newServerEntryPointName].Configuration.Auth != nil { - authMiddleware, err := mauth.NewAuthenticator(s.entryPoints[newServerEntryPointName].Configuration.Auth, s.tracingMiddleware) - if err != nil { - log.Fatal("Error starting server: ", err) - } - serverMiddlewares = append(serverMiddlewares, s.wrapNegroniHandlerWithAccessLog(authMiddleware, fmt.Sprintf("Auth for entrypoint %s", newServerEntryPointName))) - } - - if s.entryPoints[newServerEntryPointName].Configuration.Compress { - serverMiddlewares = append(serverMiddlewares, &middlewares.Compress{}) - } - - ipWhitelistMiddleware, err := buildIPWhiteLister( - s.entryPoints[newServerEntryPointName].Configuration.WhiteList, - s.entryPoints[newServerEntryPointName].Configuration.WhitelistSourceRange) - if err != nil { - log.Fatal("Error starting server: ", err) - } - if ipWhitelistMiddleware != nil { - serverMiddlewares = append(serverMiddlewares, s.wrapNegroniHandlerWithAccessLog(ipWhitelistMiddleware, fmt.Sprintf("ipwhitelister for entrypoint %s", newServerEntryPointName))) - } - - newSrv, listener, err := s.prepareServer(newServerEntryPointName, s.entryPoints[newServerEntryPointName].Configuration, newServerEntryPoint.httpRouter, serverMiddlewares) - if err != nil { - log.Fatal("Error preparing server: ", err) - } - - serverEntryPoint := s.serverEntryPoints[newServerEntryPointName] - serverEntryPoint.httpServer = newSrv - serverEntryPoint.listener = listener - - return serverEntryPoint -} - func (s *Server) listenProviders(stop chan bool) { for { select { @@ -406,119 +263,6 @@ func (s *Server) listenProviders(stop chan bool) { } } -func (s *Server) preLoadConfiguration(configMsg types.ConfigMessage) { - providersThrottleDuration := time.Duration(s.globalConfiguration.ProvidersThrottleDuration) - s.defaultConfigurationValues(configMsg.Configuration) - currentConfigurations := s.currentConfigurations.Get().(types.Configurations) - jsonConf, _ := json.Marshal(configMsg.Configuration) - log.Debugf("Configuration received from provider %s: %s", configMsg.ProviderName, string(jsonConf)) - if configMsg.Configuration == nil || configMsg.Configuration.Backends == nil && configMsg.Configuration.Frontends == nil && configMsg.Configuration.TLS == nil { - log.Infof("Skipping empty Configuration for provider %s", configMsg.ProviderName) - } else if reflect.DeepEqual(currentConfigurations[configMsg.ProviderName], configMsg.Configuration) { - log.Infof("Skipping same configuration for provider %s", configMsg.ProviderName) - } else { - providerConfigUpdateCh, ok := s.providerConfigUpdateMap[configMsg.ProviderName] - if !ok { - providerConfigUpdateCh = make(chan types.ConfigMessage) - s.providerConfigUpdateMap[configMsg.ProviderName] = providerConfigUpdateCh - s.routinesPool.Go(func(stop chan bool) { - s.throttleProviderConfigReload(providersThrottleDuration, s.configurationValidatedChan, providerConfigUpdateCh, stop) - }) - } - providerConfigUpdateCh <- configMsg - } -} - -// throttleProviderConfigReload throttles the configuration reload speed for a single provider. -// It will immediately publish a new configuration and then only publish the next configuration after the throttle duration. -// Note that in the case it receives N new configs in the timeframe of the throttle duration after publishing, -// it will publish the last of the newly received configurations. -func (s *Server) throttleProviderConfigReload(throttle time.Duration, publish chan<- types.ConfigMessage, in <-chan types.ConfigMessage, stop chan bool) { - ring := channels.NewRingChannel(1) - defer ring.Close() - - s.routinesPool.Go(func(stop chan bool) { - for { - select { - case <-stop: - return - case nextConfig := <-ring.Out(): - publish <- nextConfig.(types.ConfigMessage) - time.Sleep(throttle) - } - } - }) - - for { - select { - case <-stop: - return - case nextConfig := <-in: - ring.In() <- nextConfig - } - } -} - -func (s *Server) defaultConfigurationValues(configuration *types.Configuration) { - if configuration == nil || configuration.Frontends == nil { - return - } - configureFrontends(configuration.Frontends, s.globalConfiguration.DefaultEntryPoints) - configureBackends(configuration.Backends) -} - -func (s *Server) listenConfigurations(stop chan bool) { - for { - select { - case <-stop: - return - case configMsg, ok := <-s.configurationValidatedChan: - if !ok || configMsg.Configuration == nil { - return - } - s.loadConfiguration(configMsg) - } - } -} - -// loadConfiguration manages dynamically frontends, backends and TLS configurations -func (s *Server) loadConfiguration(configMsg types.ConfigMessage) { - currentConfigurations := s.currentConfigurations.Get().(types.Configurations) - - // Copy configurations to new map so we don't change current if LoadConfig fails - newConfigurations := make(types.Configurations) - for k, v := range currentConfigurations { - newConfigurations[k] = v - } - newConfigurations[configMsg.ProviderName] = configMsg.Configuration - - s.metricsRegistry.ConfigReloadsCounter().Add(1) - newServerEntryPoints, err := s.loadConfig(newConfigurations, s.globalConfiguration) - if err == nil { - s.metricsRegistry.LastConfigReloadSuccessGauge().Set(float64(time.Now().Unix())) - for newServerEntryPointName, newServerEntryPoint := range newServerEntryPoints { - s.serverEntryPoints[newServerEntryPointName].httpRouter.UpdateHandler(newServerEntryPoint.httpRouter.GetHandler()) - if s.entryPoints[newServerEntryPointName].Configuration.TLS == nil { - if newServerEntryPoint.certs.Get() != nil { - log.Debugf("Certificates not added to non-TLS entryPoint %s.", newServerEntryPointName) - } - } else { - s.serverEntryPoints[newServerEntryPointName].certs.Set(newServerEntryPoint.certs.Get()) - } - log.Infof("Server configuration reloaded on %s", s.serverEntryPoints[newServerEntryPointName].httpServer.Addr) - } - s.currentConfigurations.Set(newConfigurations) - for _, listener := range s.configurationListeners { - listener(*configMsg.Configuration) - } - s.postLoadConfiguration() - } else { - s.metricsRegistry.ConfigReloadsFailureCounter().Add(1) - s.metricsRegistry.LastConfigReloadFailureGauge().Set(float64(time.Now().Unix())) - log.Error("Error loading new configuration, aborted ", err) - } -} - // AddListener adds a new listener function used when new configuration is provided func (s *Server) AddListener(listener func(types.Configuration)) { if s.configurationListeners == nil { @@ -527,20 +271,6 @@ func (s *Server) AddListener(listener func(types.Configuration)) { s.configurationListeners = append(s.configurationListeners, listener) } -// loadHTTPSConfiguration add/delete HTTPS certificate managed dynamically -func (s *Server) loadHTTPSConfiguration(configurations types.Configurations, defaultEntryPoints configuration.DefaultEntryPoints) (map[string]map[string]*tls.Certificate, error) { - newEPCertificates := make(map[string]map[string]*tls.Certificate) - // Get all certificates - for _, configuration := range configurations { - if configuration.TLS != nil && len(configuration.TLS) > 0 { - if err := traefiktls.SortTLSPerEntryPoints(configuration.TLS, newEPCertificates, defaultEntryPoints); err != nil { - return nil, err - } - } - } - return newEPCertificates, nil -} - // getCertificate allows to customize tlsConfig.GetCertificate behaviour to get the certificates inserted dynamically func (s *serverEntryPoint) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { domainToCheck := types.CanonicalDomain(clientHello.ServerName) @@ -560,47 +290,6 @@ func (s *serverEntryPoint) getCertificate(clientHello *tls.ClientHelloInfo) (*tl return nil, nil } -func (s *Server) postLoadConfiguration() { - if s.metricsRegistry.IsEnabled() { - activeConfig := s.currentConfigurations.Get().(types.Configurations) - metrics.OnConfigurationUpdate(activeConfig) - } - - if s.globalConfiguration.ACME == nil || s.leadership == nil || !s.leadership.IsLeader() { - return - } - - if s.globalConfiguration.ACME.OnHostRule { - currentConfigurations := s.currentConfigurations.Get().(types.Configurations) - for _, config := range currentConfigurations { - for _, frontend := range config.Frontends { - - // check if one of the frontend entrypoints is configured with TLS - // and is configured with ACME - acmeEnabled := false - for _, entryPoint := range frontend.EntryPoints { - if s.globalConfiguration.ACME.EntryPoint == entryPoint && s.entryPoints[entryPoint].Configuration.TLS != nil { - acmeEnabled = true - break - } - } - - if acmeEnabled { - for _, route := range frontend.Routes { - rules := rules.Rules{} - domains, err := rules.ParseDomains(route.Rule) - if err != nil { - log.Errorf("Error parsing domains: %v", err) - } else { - s.globalConfiguration.ACME.LoadCertificateForDomains(domains) - } - } - } - } - } - } -} - func (s *Server) startProvider() { // start providers providerType := reflect.TypeOf(s.provider) @@ -618,38 +307,6 @@ func (s *Server) startProvider() { }) } -func createClientTLSConfig(entryPointName string, tlsOption *traefiktls.TLS) (*tls.Config, error) { - if tlsOption == nil { - return nil, errors.New("no TLS provided") - } - - config, err := tlsOption.Certificates.CreateTLSConfig(entryPointName) - if err != nil { - return nil, err - } - - if len(tlsOption.ClientCAFiles) > 0 { - log.Warnf("Deprecated configuration found during client TLS configuration creation: %s. Please use %s (which allows to make the CA Files optional).", "tls.ClientCAFiles", "tls.ClientCA.files") - tlsOption.ClientCA.Files = tlsOption.ClientCAFiles - tlsOption.ClientCA.Optional = false - } - if len(tlsOption.ClientCA.Files) > 0 { - pool := x509.NewCertPool() - for _, caFile := range tlsOption.ClientCA.Files { - data, err := ioutil.ReadFile(caFile) - if err != nil { - return nil, err - } - if !pool.AppendCertsFromPEM(data) { - return nil, errors.New("invalid certificate(s) in " + caFile) - } - } - config.RootCAs = pool - } - config.BuildNameToCertificate() - return config, nil -} - // creates a TLS config that allows terminating HTTPS for multiple domains using SNI func (s *Server) createTLSConfig(entryPointName string, tlsOption *traefiktls.TLS, router *middlewares.HandlerSwitcher) (*tls.Config, error) { if tlsOption == nil { @@ -679,7 +336,7 @@ func (s *Server) createTLSConfig(entryPointName string, tlsOption *traefiktls.TL } ok := pool.AppendCertsFromPEM(data) if !ok { - return nil, errors.New("invalid certificate(s) in " + caFile) + return nil, fmt.Errorf("invalid certificate(s) in %s", caFile) } } config.ClientCAs = pool @@ -709,9 +366,11 @@ func (s *Server) createTLSConfig(entryPointName string, tlsOption *traefiktls.TL } else { config.GetCertificate = s.serverEntryPoints[entryPointName].getCertificate } + if len(config.Certificates) == 0 { - return nil, errors.New("No certificates found for TLS entrypoint " + entryPointName) + return nil, fmt.Errorf("no certificates found for TLS entrypoint %s", entryPointName) } + // BuildNameToCertificate parses the CommonName and SubjectAlternateName fields // in each certificate and populates the config.NameToCertificate map. config.BuildNameToCertificate() @@ -735,26 +394,47 @@ func (s *Server) createTLSConfig(entryPointName string, tlsOption *traefiktls.TL config.CipherSuites = append(config.CipherSuites, cipherConst) } else { // CipherSuite listed in the toml does not exist in our listed - return nil, errors.New("Invalid CipherSuite: " + cipher) + return nil, fmt.Errorf("invalid CipherSuite: %s", cipher) } } } + return config, nil } func (s *Server) startServer(serverEntryPoint *serverEntryPoint) { log.Infof("Starting server on %s", serverEntryPoint.httpServer.Addr) + var err error if serverEntryPoint.httpServer.TLSConfig != nil { err = serverEntryPoint.httpServer.ServeTLS(serverEntryPoint.listener, "", "") } else { err = serverEntryPoint.httpServer.Serve(serverEntryPoint.listener) } + if err != http.ErrServerClosed { log.Error("Error creating server: ", err) } } +func (s *Server) setupServerEntryPoint(newServerEntryPointName string, newServerEntryPoint *serverEntryPoint) *serverEntryPoint { + serverMiddlewares, err := s.buildServerEntryPointMiddlewares(newServerEntryPointName, newServerEntryPoint) + if err != nil { + log.Fatal("Error preparing server: ", err) + } + + newSrv, listener, err := s.prepareServer(newServerEntryPointName, s.entryPoints[newServerEntryPointName].Configuration, newServerEntryPoint.httpRouter, serverMiddlewares) + if err != nil { + log.Fatal("Error preparing server: ", err) + } + + serverEntryPoint := s.serverEntryPoints[newServerEntryPointName] + serverEntryPoint.httpServer = newSrv + serverEntryPoint.listener = listener + + return serverEntryPoint +} + func (s *Server) prepareServer(entryPointName string, entryPoint *configuration.EntryPoint, router *middlewares.HandlerSwitcher, middlewares []negroni.Handler) (*h2c.Server, net.Listener, error) { readTimeout, writeTimeout, idleTimeout := buildServerTimeouts(s.globalConfiguration) log.Infof("Preparing server %s %+v with readTimeout=%s writeTimeout=%s idleTimeout=%s", entryPointName, entryPoint, readTimeout, writeTimeout, idleTimeout) @@ -771,32 +451,18 @@ func (s *Server) prepareServer(entryPointName string, entryPoint *configuration. tlsConfig, err := s.createTLSConfig(entryPointName, entryPoint.TLS, router) if err != nil { - log.Errorf("Error creating TLS config: %s", err) - return nil, nil, err + return nil, nil, fmt.Errorf("error creating TLS config: %v", err) } listener, err := net.Listen("tcp", entryPoint.Address) if err != nil { - log.Error("Error opening listener ", err) - return nil, nil, err + return nil, nil, fmt.Errorf("error opening listener: %v", err) } if entryPoint.ProxyProtocol != nil { - IPs, err := whitelist.NewIP(entryPoint.ProxyProtocol.TrustedIPs, entryPoint.ProxyProtocol.Insecure, false) + listener, err = buildProxyProtocolListener(entryPoint, listener) if err != nil { - return nil, nil, fmt.Errorf("error creating whitelist: %s", err) - } - log.Infof("Enabling ProxyProtocol for trusted IPs %v", entryPoint.ProxyProtocol.TrustedIPs) - listener = &proxyproto.Listener{ - Listener: listener, - SourceCheck: func(addr net.Addr) (bool, error) { - ip, ok := addr.(*net.TCPAddr) - if !ok { - return false, fmt.Errorf("type error %v", addr) - } - - return IPs.ContainsIP(ip.IP), nil - }, + return nil, nil, err } } @@ -815,6 +481,27 @@ func (s *Server) prepareServer(entryPointName string, entryPoint *configuration. nil } +func buildProxyProtocolListener(entryPoint *configuration.EntryPoint, listener net.Listener) (net.Listener, error) { + IPs, err := whitelist.NewIP(entryPoint.ProxyProtocol.TrustedIPs, entryPoint.ProxyProtocol.Insecure, false) + if err != nil { + return nil, fmt.Errorf("error creating whitelist: %s", err) + } + + log.Infof("Enabling ProxyProtocol for trusted IPs %v", entryPoint.ProxyProtocol.TrustedIPs) + + return &proxyproto.Listener{ + Listener: listener, + SourceCheck: func(addr net.Addr) (bool, error) { + ip, ok := addr.(*net.TCPAddr) + if !ok { + return false, fmt.Errorf("type error %v", addr) + } + + return IPs.ContainsIP(ip.IP), nil + }, + }, nil +} + func (s *Server) buildInternalRouter(entryPointName string) *mux.Router { internalMuxRouter := mux.NewRouter() internalMuxRouter.StrictSlash(true) @@ -857,607 +544,6 @@ func buildServerTimeouts(globalConfig configuration.GlobalConfiguration) (readTi return readTimeout, writeTimeout, idleTimeout } -func (s *Server) buildEntryPoints() map[string]*serverEntryPoint { - serverEntryPoints := make(map[string]*serverEntryPoint) - for entryPointName, entryPoint := range s.entryPoints { - serverEntryPoints[entryPointName] = &serverEntryPoint{ - httpRouter: middlewares.NewHandlerSwitcher(s.buildDefaultHTTPRouter()), - onDemandListener: entryPoint.OnDemandListener, - } - if entryPoint.CertificateStore != nil { - serverEntryPoints[entryPointName].certs = entryPoint.CertificateStore.DynamicCerts - } else { - serverEntryPoints[entryPointName].certs = &safe.Safe{} - } - } - return serverEntryPoints -} - -// getRoundTripper will either use server.defaultForwardingRoundTripper or create a new one -// given a custom TLS configuration is passed and the passTLSCert option is set to true. -func (s *Server) getRoundTripper(entryPointName string, globalConfiguration configuration.GlobalConfiguration, passTLSCert bool, tls *traefiktls.TLS) (http.RoundTripper, error) { - if passTLSCert { - tlsConfig, err := createClientTLSConfig(entryPointName, tls) - if err != nil { - log.Errorf("Failed to create TLSClientConfig: %s", err) - return nil, err - } - - transport := createHTTPTransport(globalConfiguration) - transport.TLSClientConfig = tlsConfig - return transport, nil - } - - return s.defaultForwardingRoundTripper, nil -} - -// loadConfig returns a new gorilla.mux Route from the specified global configuration and the dynamic -// provider configurations. -func (s *Server) loadConfig(configurations types.Configurations, globalConfiguration configuration.GlobalConfiguration) (map[string]*serverEntryPoint, error) { - serverEntryPoints := s.buildEntryPoints() - redirectHandlers := make(map[string]negroni.Handler) - backends := map[string]http.Handler{} - backendsHealthCheck := map[string]*healthcheck.BackendHealthCheck{} - var errorPageHandlers []*errorpages.Handler - - errorHandler := NewRecordingErrorHandler(middlewares.DefaultNetErrorRecorder{}) - - for providerName, config := range configurations { - frontendNames := sortedFrontendNamesForConfig(config) - frontend: - for _, frontendName := range frontendNames { - frontend := config.Frontends[frontendName] - - log.Debugf("Creating frontend %s", frontendName) - - var frontendEntryPoints []string - for _, entryPointName := range frontend.EntryPoints { - if _, ok := serverEntryPoints[entryPointName]; !ok { - log.Errorf("Undefined entrypoint '%s' for frontend %s", entryPointName, frontendName) - } else { - frontendEntryPoints = append(frontendEntryPoints, entryPointName) - } - } - frontend.EntryPoints = frontendEntryPoints - - if len(frontend.EntryPoints) == 0 { - log.Errorf("No entrypoint defined for frontend %s", frontendName) - log.Errorf("Skipping frontend %s...", frontendName) - continue frontend - } - for _, entryPointName := range frontend.EntryPoints { - log.Debugf("Wiring frontend %s to entryPoint %s", frontendName, entryPointName) - - newServerRoute := &types.ServerRoute{Route: serverEntryPoints[entryPointName].httpRouter.GetHandler().NewRoute().Name(frontendName)} - for routeName, route := range frontend.Routes { - err := getRoute(newServerRoute, &route) - if err != nil { - log.Errorf("Error creating route for frontend %s: %v", frontendName, err) - log.Errorf("Skipping frontend %s...", frontendName) - continue frontend - } - log.Debugf("Creating route %s %s", routeName, route.Rule) - } - - entryPoint := s.entryPoints[entryPointName].Configuration - n := negroni.New() - if entryPoint.Redirect != nil && entryPointName != entryPoint.Redirect.EntryPoint { - if redirectHandlers[entryPointName] != nil { - n.Use(redirectHandlers[entryPointName]) - } else if handler, err := s.buildRedirectHandler(entryPointName, entryPoint.Redirect); err != nil { - log.Errorf("Error loading entrypoint configuration for frontend %s: %v", frontendName, err) - log.Errorf("Skipping frontend %s...", frontendName) - continue frontend - } else { - handlerToUse := s.wrapNegroniHandlerWithAccessLog(handler, fmt.Sprintf("entrypoint redirect for %s", frontendName)) - n.Use(handlerToUse) - redirectHandlers[entryPointName] = handlerToUse - } - } - - frontendHash, err := frontend.Hash() - if err != nil { - log.Errorf("Error calculating hash value for frontend %s: %v", frontendName, err) - log.Errorf("Skipping frontend %s...", frontendName) - continue frontend - } - backendCacheKey := entryPointName + providerName + frontendHash - if backends[backendCacheKey] == nil { - log.Debugf("Creating backend %s", frontend.Backend) - - roundTripper, err := s.getRoundTripper(entryPointName, globalConfiguration, frontend.PassTLSCert, entryPoint.TLS) - if err != nil { - log.Errorf("Failed to create RoundTripper for frontend %s: %v", frontendName, err) - log.Errorf("Skipping frontend %s...", frontendName) - continue frontend - } - - rewriter, err := NewHeaderRewriter(entryPoint.ForwardedHeaders.TrustedIPs, entryPoint.ForwardedHeaders.Insecure) - if err != nil { - log.Errorf("Error creating rewriter for frontend %s: %v", frontendName, err) - log.Errorf("Skipping frontend %s...", frontendName) - continue frontend - } - - headerMiddleware := middlewares.NewHeaderFromStruct(frontend.Headers) - secureMiddleware := middlewares.NewSecure(frontend.Headers) - - var responseModifier = buildModifyResponse(secureMiddleware, headerMiddleware) - var fwd http.Handler - - fwd, err = forward.New( - forward.Stream(true), - forward.PassHostHeader(frontend.PassHostHeader), - forward.RoundTripper(roundTripper), - forward.ErrorHandler(errorHandler), - forward.Rewriter(rewriter), - forward.ResponseModifier(responseModifier), - forward.BufferPool(s.bufferPool), - ) - - if err != nil { - log.Errorf("Error creating forwarder for frontend %s: %v", frontendName, err) - log.Errorf("Skipping frontend %s...", frontendName) - continue frontend - } - - if s.tracingMiddleware.IsEnabled() { - tm := s.tracingMiddleware.NewForwarderMiddleware(frontendName, frontend.Backend) - - next := fwd - fwd = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - tm.ServeHTTP(w, r, next.ServeHTTP) - }) - } - - var rr *roundrobin.RoundRobin - var saveFrontend http.Handler - if s.accessLoggerMiddleware != nil { - saveBackend := accesslog.NewSaveBackend(fwd, frontend.Backend) - saveFrontend = accesslog.NewSaveFrontend(saveBackend, frontendName) - rr, _ = roundrobin.New(saveFrontend) - } else { - rr, _ = roundrobin.New(fwd) - } - - if config.Backends[frontend.Backend] == nil { - log.Errorf("Undefined backend '%s' for frontend %s", frontend.Backend, frontendName) - log.Errorf("Skipping frontend %s...", frontendName) - continue frontend - } - - lbMethod, err := types.NewLoadBalancerMethod(config.Backends[frontend.Backend].LoadBalancer) - if err != nil { - log.Errorf("Error loading load balancer method '%+v' for frontend %s: %v", config.Backends[frontend.Backend].LoadBalancer, frontendName, err) - log.Errorf("Skipping frontend %s...", frontendName) - continue frontend - } - - var sticky *roundrobin.StickySession - var cookieName string - if stickiness := config.Backends[frontend.Backend].LoadBalancer.Stickiness; stickiness != nil { - cookieName = cookie.GetName(stickiness.CookieName, frontend.Backend) - sticky = roundrobin.NewStickySession(cookieName) - } - - var lb http.Handler - switch lbMethod { - case types.Drr: - log.Debugf("Creating load-balancer drr") - rebalancer, _ := roundrobin.NewRebalancer(rr) - if sticky != nil { - log.Debugf("Sticky session with cookie %v", cookieName) - rebalancer, _ = roundrobin.NewRebalancer(rr, roundrobin.RebalancerStickySession(sticky)) - } - lb = rebalancer - if err := s.configureLBServers(rebalancer, config, frontend); err != nil { - log.Errorf("Skipping frontend %s...", frontendName) - continue frontend - } - hcOpts := parseHealthCheckOptions(rebalancer, frontend.Backend, config.Backends[frontend.Backend].HealthCheck, globalConfiguration.HealthCheck) - if hcOpts != nil { - log.Debugf("Setting up backend health check %s", *hcOpts) - hcOpts.Transport = s.defaultForwardingRoundTripper - backendsHealthCheck[backendCacheKey] = healthcheck.NewBackendHealthCheck(*hcOpts, frontend.Backend) - } - lb = middlewares.NewEmptyBackendHandler(rebalancer, lb) - case types.Wrr: - log.Debugf("Creating load-balancer wrr") - if sticky != nil { - log.Debugf("Sticky session with cookie %v", cookieName) - if s.accessLoggerMiddleware != nil { - rr, _ = roundrobin.New(saveFrontend, roundrobin.EnableStickySession(sticky)) - } else { - rr, _ = roundrobin.New(fwd, roundrobin.EnableStickySession(sticky)) - } - } - lb = rr - if err := s.configureLBServers(rr, config, frontend); err != nil { - log.Errorf("Skipping frontend %s...", frontendName) - continue frontend - } - hcOpts := parseHealthCheckOptions(rr, frontend.Backend, config.Backends[frontend.Backend].HealthCheck, globalConfiguration.HealthCheck) - if hcOpts != nil { - log.Debugf("Setting up backend health check %s", *hcOpts) - hcOpts.Transport = s.defaultForwardingRoundTripper - backendsHealthCheck[backendCacheKey] = healthcheck.NewBackendHealthCheck(*hcOpts, frontend.Backend) - } - lb = middlewares.NewEmptyBackendHandler(rr, lb) - } - - if len(frontend.Errors) > 0 { - for errorPageName, errorPage := range frontend.Errors { - if frontend.Backend == errorPage.Backend { - log.Errorf("Error when creating error page %q for frontend %q: error pages backend %q is the same as backend for the frontend (infinite call risk).", - errorPageName, frontendName, errorPage.Backend) - } else if config.Backends[errorPage.Backend] == nil { - log.Errorf("Error when creating error page %q for frontend %q: the backend %q doesn't exist.", - errorPageName, frontendName, errorPage.Backend) - } else { - errorPagesHandler, err := errorpages.NewHandler(errorPage, entryPointName+providerName+errorPage.Backend) - if err != nil { - log.Errorf("Error creating error pages: %v", err) - } else { - if errorPageServer, ok := config.Backends[errorPage.Backend].Servers["error"]; ok { - errorPagesHandler.FallbackURL = errorPageServer.URL - } - - errorPageHandlers = append(errorPageHandlers, errorPagesHandler) - n.Use(errorPagesHandler) - } - } - } - } - - if frontend.RateLimit != nil && len(frontend.RateLimit.RateSet) > 0 { - lb, err = s.buildRateLimiter(lb, frontend.RateLimit) - if err != nil { - log.Errorf("Error creating rate limiter: %v", err) - log.Errorf("Skipping frontend %s...", frontendName) - continue frontend - } - lb = s.wrapHTTPHandlerWithAccessLog(lb, fmt.Sprintf("rate limit for %s", frontendName)) - } - - maxConns := config.Backends[frontend.Backend].MaxConn - if maxConns != nil && maxConns.Amount != 0 { - extractFunc, err := utils.NewExtractor(maxConns.ExtractorFunc) - if err != nil { - log.Errorf("Error creating connection limit: %v", err) - log.Errorf("Skipping frontend %s...", frontendName) - continue frontend - } - - log.Debugf("Creating load-balancer connection limit") - - lb, err = connlimit.New(lb, extractFunc, maxConns.Amount) - if err != nil { - log.Errorf("Error creating connection limit: %v", err) - log.Errorf("Skipping frontend %s...", frontendName) - continue frontend - } - lb = s.wrapHTTPHandlerWithAccessLog(lb, fmt.Sprintf("connection limit for %s", frontendName)) - } - - if globalConfiguration.Retry != nil { - countServers := len(config.Backends[frontend.Backend].Servers) - lb = s.buildRetryMiddleware(lb, globalConfiguration, countServers, frontend.Backend) - } - - if s.metricsRegistry.IsEnabled() { - n.Use(middlewares.NewBackendMetricsMiddleware(s.metricsRegistry, frontend.Backend)) - } - - ipWhitelistMiddleware, err := buildIPWhiteLister(frontend.WhiteList, frontend.WhitelistSourceRange) - if err != nil { - log.Errorf("Error creating IP Whitelister: %s", err) - } else if ipWhitelistMiddleware != nil { - n.Use( - s.tracingMiddleware.NewNegroniHandlerWrapper( - "IP whitelist", - s.wrapNegroniHandlerWithAccessLog(ipWhitelistMiddleware, fmt.Sprintf("ipwhitelister for %s", frontendName)), - false)) - log.Debugf("Configured IP Whitelists: %s", frontend.WhitelistSourceRange) - } - - if frontend.Redirect != nil && entryPointName != frontend.Redirect.EntryPoint { - rewrite, err := s.buildRedirectHandler(entryPointName, frontend.Redirect) - if err != nil { - log.Errorf("Error creating Frontend Redirect: %v", err) - } else { - n.Use(s.wrapNegroniHandlerWithAccessLog(rewrite, fmt.Sprintf("frontend redirect for %s", frontendName))) - log.Debugf("Frontend %s redirect created", frontendName) - } - } - - if headerMiddleware != nil { - log.Debugf("Adding header middleware for frontend %s", frontendName) - n.Use(s.tracingMiddleware.NewNegroniHandlerWrapper("Header", headerMiddleware, false)) - } - - if secureMiddleware != nil { - log.Debugf("Adding secure middleware for frontend %s", frontendName) - n.UseFunc(secureMiddleware.HandlerFuncWithNextForRequestOnly) - } - - if len(frontend.BasicAuth) > 0 { - users := types.Users{} - for _, user := range frontend.BasicAuth { - users = append(users, user) - } - - auth := &types.Auth{} - auth.Basic = &types.Basic{ - Users: users, - } - authMiddleware, err := mauth.NewAuthenticator(auth, s.tracingMiddleware) - if err != nil { - log.Errorf("Error creating Auth: %s", err) - } else { - n.Use(s.wrapNegroniHandlerWithAccessLog(authMiddleware, fmt.Sprintf("Auth for %s", frontendName))) - } - } - - if config.Backends[frontend.Backend].Buffering != nil { - bufferedLb, err := s.buildBufferingMiddleware(lb, config.Backends[frontend.Backend].Buffering) - - if err != nil { - log.Errorf("Error setting up buffering middleware: %s", err) - } else { - lb = bufferedLb - } - } - - if config.Backends[frontend.Backend].CircuitBreaker != nil { - log.Debugf("Creating circuit breaker %s", config.Backends[frontend.Backend].CircuitBreaker.Expression) - expression := config.Backends[frontend.Backend].CircuitBreaker.Expression - circuitBreaker, err := middlewares.NewCircuitBreaker(lb, expression, middlewares.NewCircuitBreakerOptions(expression)) - if err != nil { - log.Errorf("Error creating circuit breaker: %v", err) - log.Errorf("Skipping frontend %s...", frontendName) - continue frontend - } - n.Use(s.tracingMiddleware.NewNegroniHandlerWrapper("Circuit breaker", circuitBreaker, false)) - } else { - n.UseHandler(lb) - } - backends[backendCacheKey] = n - } else { - log.Debugf("Reusing backend %s", frontend.Backend) - } - if frontend.Priority > 0 { - newServerRoute.Route.Priority(frontend.Priority) - } - s.wireFrontendBackend(newServerRoute, backends[backendCacheKey]) - - err = newServerRoute.Route.GetError() - if err != nil { - log.Errorf("Error building route: %s", err) - } - } - } - } - - for _, errorPageHandler := range errorPageHandlers { - if handler, ok := backends[errorPageHandler.BackendName]; ok { - errorPageHandler.PostLoad(handler) - } else { - errorPageHandler.PostLoad(nil) - } - } - - healthcheck.GetHealthCheck(s.metricsRegistry).SetBackendsConfiguration(s.routinesPool.Ctx(), backendsHealthCheck) - - // Get new certificates list sorted per entrypoints - // Update certificates - entryPointsCertificates, err := s.loadHTTPSConfiguration(configurations, globalConfiguration.DefaultEntryPoints) - - // Sort routes and update certificates - for serverEntryPointName, serverEntryPoint := range serverEntryPoints { - serverEntryPoint.httpRouter.GetHandler().SortRoutes() - if _, exists := entryPointsCertificates[serverEntryPointName]; exists { - serverEntryPoint.certs.Set(entryPointsCertificates[serverEntryPointName]) - } - } - - return serverEntryPoints, err -} - -func (s *Server) configureLBServers(lb healthcheck.LoadBalancer, config *types.Configuration, frontend *types.Frontend) error { - for name, srv := range config.Backends[frontend.Backend].Servers { - u, err := url.Parse(srv.URL) - if err != nil { - log.Errorf("Error parsing server URL %s: %v", srv.URL, err) - return err - } - log.Debugf("Creating server %s at %s with weight %d", name, u, srv.Weight) - if err := lb.UpsertServer(u, roundrobin.Weight(srv.Weight)); err != nil { - log.Errorf("Error adding server %s to load balancer: %v", srv.URL, err) - return err - } - s.metricsRegistry.BackendServerUpGauge().With("backend", frontend.Backend, "url", srv.URL).Set(1) - } - return nil -} - -func buildIPWhiteLister(whiteList *types.WhiteList, wlRange []string) (*middlewares.IPWhiteLister, error) { - if whiteList != nil && - len(whiteList.SourceRange) > 0 { - return middlewares.NewIPWhiteLister(whiteList.SourceRange, whiteList.UseXForwardedFor) - } else if len(wlRange) > 0 { - return middlewares.NewIPWhiteLister(wlRange, false) - } - return nil, nil -} - -func (s *Server) wireFrontendBackend(serverRoute *types.ServerRoute, handler http.Handler) { - // path replace - This needs to always be the very last on the handler chain (first in the order in this function) - // -- Replacing Path should happen at the very end of the Modifier chain, after all the Matcher+Modifiers ran - if len(serverRoute.ReplacePath) > 0 { - handler = &middlewares.ReplacePath{ - Path: serverRoute.ReplacePath, - Handler: handler, - } - } - - if len(serverRoute.ReplacePathRegex) > 0 { - sp := strings.Split(serverRoute.ReplacePathRegex, " ") - if len(sp) == 2 { - handler = middlewares.NewReplacePathRegexHandler(sp[0], sp[1], handler) - } else { - log.Warnf("Invalid syntax for ReplacePathRegex: %s. Separate the regular expression and the replacement by a space.", serverRoute.ReplacePathRegex) - } - } - - // add prefix - This needs to always be right before ReplacePath on the chain (second in order in this function) - // -- Adding Path Prefix should happen after all *Strip Matcher+Modifiers ran, but before Replace (in case it's configured) - if len(serverRoute.AddPrefix) > 0 { - handler = &middlewares.AddPrefix{ - Prefix: serverRoute.AddPrefix, - Handler: handler, - } - } - - // strip prefix - if len(serverRoute.StripPrefixes) > 0 { - handler = &middlewares.StripPrefix{ - Prefixes: serverRoute.StripPrefixes, - Handler: handler, - } - } - - // strip prefix with regex - if len(serverRoute.StripPrefixesRegex) > 0 { - handler = middlewares.NewStripPrefixRegex(handler, serverRoute.StripPrefixesRegex) - } - - serverRoute.Route.Handler(handler) -} - -func (s *Server) buildRedirectHandler(srcEntryPointName string, opt *types.Redirect) (negroni.Handler, error) { - // entry point redirect - if len(opt.EntryPoint) > 0 { - entryPoint := s.entryPoints[opt.EntryPoint].Configuration - if entryPoint == nil { - return nil, fmt.Errorf("unknown target entrypoint %q", srcEntryPointName) - } - log.Debugf("Creating entry point redirect %s -> %s", srcEntryPointName, opt.EntryPoint) - return redirect.NewEntryPointHandler(entryPoint, opt.Permanent) - } - - // regex redirect - redirection, err := redirect.NewRegexHandler(opt.Regex, opt.Replacement, opt.Permanent) - if err != nil { - return nil, err - } - log.Debugf("Creating regex redirect %s -> %s -> %s", srcEntryPointName, opt.Regex, opt.Replacement) - - return redirection, nil -} - -func (s *Server) buildDefaultHTTPRouter() *mux.Router { - rt := mux.NewRouter() - rt.NotFoundHandler = s.wrapHTTPHandlerWithAccessLog(http.HandlerFunc(notFoundHandler), "backend not found") - rt.StrictSlash(true) - rt.SkipClean(true) - return rt -} - -func parseHealthCheckOptions(lb healthcheck.LoadBalancer, backend string, hc *types.HealthCheck, hcConfig *configuration.HealthCheckConfig) *healthcheck.Options { - if hc == nil || hc.Path == "" || hcConfig == nil { - return nil - } - - interval := time.Duration(hcConfig.Interval) - if hc.Interval != "" { - intervalOverride, err := time.ParseDuration(hc.Interval) - switch { - case err != nil: - log.Errorf("Illegal health check interval for backend '%s': %s", backend, err) - case intervalOverride <= 0: - log.Errorf("Health check interval smaller than zero for backend '%s', backend", backend) - default: - interval = intervalOverride - } - } - - return &healthcheck.Options{ - Scheme: hc.Scheme, - Path: hc.Path, - Port: hc.Port, - Interval: interval, - LB: lb, - Hostname: hc.Hostname, - Headers: hc.Headers, - } -} - -func getRoute(serverRoute *types.ServerRoute, route *types.Route) error { - rules := rules.Rules{Route: serverRoute} - newRoute, err := rules.Parse(route.Rule) - if err != nil { - return err - } - newRoute.Priority(serverRoute.Route.GetPriority() + len(route.Rule)) - serverRoute.Route = newRoute - return nil -} - -func sortedFrontendNamesForConfig(configuration *types.Configuration) []string { - var keys []string - for key := range configuration.Frontends { - keys = append(keys, key) - } - sort.Strings(keys) - return keys -} - -func configureFrontends(frontends map[string]*types.Frontend, defaultEntrypoints []string) { - for _, frontend := range frontends { - // default endpoints if not defined in frontends - if len(frontend.EntryPoints) == 0 { - frontend.EntryPoints = defaultEntrypoints - } - } -} - -func configureBackends(backends map[string]*types.Backend) { - for backendName := range backends { - backend := backends[backendName] - if backend.LoadBalancer != nil && backend.LoadBalancer.Sticky { - log.Warnf("Deprecated configuration found: %s. Please use %s.", "backend.LoadBalancer.Sticky", "backend.LoadBalancer.Stickiness") - } - - _, err := types.NewLoadBalancerMethod(backend.LoadBalancer) - if err == nil { - if backend.LoadBalancer != nil && backend.LoadBalancer.Stickiness == nil && backend.LoadBalancer.Sticky { - backend.LoadBalancer.Stickiness = &types.Stickiness{ - CookieName: "_TRAEFIK_BACKEND", - } - } - } else { - log.Debugf("Backend %s: %v", backendName, err) - - var stickiness *types.Stickiness - if backend.LoadBalancer != nil { - if backend.LoadBalancer.Stickiness == nil { - if backend.LoadBalancer.Sticky { - stickiness = &types.Stickiness{ - CookieName: "_TRAEFIK_BACKEND", - } - } - } else { - stickiness = backend.LoadBalancer.Stickiness - } - } - backend.LoadBalancer = &types.LoadBalancer{ - Method: "wrr", - Stickiness: stickiness, - } - } - } -} - func registerMetricClients(metricsConfig *types.Metrics) metrics.Registry { if metricsConfig == nil { return metrics.NewVoidRegistry() @@ -1489,89 +575,3 @@ func stopMetricsClients() { metrics.StopStatsd() metrics.StopInfluxDB() } - -func (s *Server) buildRateLimiter(handler http.Handler, rlConfig *types.RateLimit) (http.Handler, error) { - extractFunc, err := utils.NewExtractor(rlConfig.ExtractorFunc) - if err != nil { - return nil, err - } - log.Debugf("Creating load-balancer rate limiter") - rateSet := ratelimit.NewRateSet() - for _, rate := range rlConfig.RateSet { - if err := rateSet.Add(time.Duration(rate.Period), rate.Average, rate.Burst); err != nil { - return nil, err - } - } - rateLimiter, err := ratelimit.New(handler, extractFunc, rateSet) - return s.tracingMiddleware.NewHTTPHandlerWrapper("Rate limit", rateLimiter, false), err - -} - -func (s *Server) buildRetryMiddleware(handler http.Handler, globalConfig configuration.GlobalConfiguration, countServers int, backendName string) http.Handler { - retryListeners := middlewares.RetryListeners{} - if s.metricsRegistry.IsEnabled() { - retryListeners = append(retryListeners, middlewares.NewMetricsRetryListener(s.metricsRegistry, backendName)) - } - if s.accessLoggerMiddleware != nil { - retryListeners = append(retryListeners, &accesslog.SaveRetries{}) - } - - retryAttempts := countServers - if globalConfig.Retry.Attempts > 0 { - retryAttempts = globalConfig.Retry.Attempts - } - - log.Debugf("Creating retries max attempts %d", retryAttempts) - - return s.tracingMiddleware.NewHTTPHandlerWrapper("Retry", middlewares.NewRetry(retryAttempts, handler, retryListeners), false) -} -func (s *Server) wrapNegroniHandlerWithAccessLog(handler negroni.Handler, frontendName string) negroni.Handler { - if s.accessLoggerMiddleware != nil { - saveBackend := accesslog.NewSaveNegroniBackend(handler, "Træfik") - saveFrontend := accesslog.NewSaveNegroniFrontend(saveBackend, frontendName) - return saveFrontend - } - return handler -} - -func (s *Server) wrapHTTPHandlerWithAccessLog(handler http.Handler, frontendName string) http.Handler { - if s.accessLoggerMiddleware != nil { - saveBackend := accesslog.NewSaveBackend(handler, "Træfik") - saveFrontend := accesslog.NewSaveFrontend(saveBackend, frontendName) - return saveFrontend - } - return handler -} - -func (s *Server) buildBufferingMiddleware(handler http.Handler, config *types.Buffering) (http.Handler, error) { - log.Debugf("Setting up buffering: request limits: %d (mem), %d (max), response limits: %d (mem), %d (max) with retry: '%s'", - config.MemRequestBodyBytes, config.MaxRequestBodyBytes, config.MemResponseBodyBytes, - config.MaxResponseBodyBytes, config.RetryExpression) - - return buffer.New( - handler, - buffer.MemRequestBodyBytes(config.MemRequestBodyBytes), - buffer.MaxRequestBodyBytes(config.MaxRequestBodyBytes), - buffer.MemResponseBodyBytes(config.MemResponseBodyBytes), - buffer.MaxResponseBodyBytes(config.MaxResponseBodyBytes), - buffer.CondSetter(len(config.RetryExpression) > 0, buffer.Retry(config.RetryExpression)), - ) -} - -func buildModifyResponse(secure *secure.Secure, header *middlewares.HeaderStruct) func(res *http.Response) error { - return func(res *http.Response) error { - if secure != nil { - err := secure.ModifyResponseHeaders(res) - if err != nil { - return err - } - } - if header != nil { - err := header.ModifyResponseHeaders(res) - if err != nil { - return err - } - } - return nil - } -} diff --git a/server/server_configuration.go b/server/server_configuration.go new file mode 100644 index 000000000..1e3ede0bf --- /dev/null +++ b/server/server_configuration.go @@ -0,0 +1,581 @@ +package server + +import ( + "crypto/tls" + "encoding/json" + "fmt" + "net/http" + "reflect" + "sort" + "strings" + "time" + + "github.com/containous/mux" + "github.com/containous/traefik/configuration" + "github.com/containous/traefik/healthcheck" + "github.com/containous/traefik/log" + "github.com/containous/traefik/metrics" + "github.com/containous/traefik/middlewares" + "github.com/containous/traefik/rules" + "github.com/containous/traefik/safe" + traefiktls "github.com/containous/traefik/tls" + "github.com/containous/traefik/types" + "github.com/eapache/channels" + "github.com/urfave/negroni" + "github.com/vulcand/oxy/forward" + "github.com/vulcand/oxy/utils" +) + +// loadConfiguration manages dynamically frontends, backends and TLS configurations +func (s *Server) loadConfiguration(configMsg types.ConfigMessage) { + currentConfigurations := s.currentConfigurations.Get().(types.Configurations) + + // Copy configurations to new map so we don't change current if LoadConfig fails + newConfigurations := make(types.Configurations) + for k, v := range currentConfigurations { + newConfigurations[k] = v + } + newConfigurations[configMsg.ProviderName] = configMsg.Configuration + + s.metricsRegistry.ConfigReloadsCounter().Add(1) + + newServerEntryPoints, err := s.loadConfig(newConfigurations, s.globalConfiguration) + if err != nil { + s.metricsRegistry.ConfigReloadsFailureCounter().Add(1) + s.metricsRegistry.LastConfigReloadFailureGauge().Set(float64(time.Now().Unix())) + log.Error("Error loading new configuration, aborted ", err) + return + } + + s.metricsRegistry.LastConfigReloadSuccessGauge().Set(float64(time.Now().Unix())) + + for newServerEntryPointName, newServerEntryPoint := range newServerEntryPoints { + s.serverEntryPoints[newServerEntryPointName].httpRouter.UpdateHandler(newServerEntryPoint.httpRouter.GetHandler()) + + if s.entryPoints[newServerEntryPointName].Configuration.TLS == nil { + if newServerEntryPoint.certs.Get() != nil { + log.Debugf("Certificates not added to non-TLS entryPoint %s.", newServerEntryPointName) + } + } else { + s.serverEntryPoints[newServerEntryPointName].certs.Set(newServerEntryPoint.certs.Get()) + } + log.Infof("Server configuration reloaded on %s", s.serverEntryPoints[newServerEntryPointName].httpServer.Addr) + } + + s.currentConfigurations.Set(newConfigurations) + + for _, listener := range s.configurationListeners { + listener(*configMsg.Configuration) + } + + s.postLoadConfiguration() +} + +// loadConfig returns a new gorilla.mux Route from the specified global configuration and the dynamic +// provider configurations. +func (s *Server) loadConfig(configurations types.Configurations, globalConfiguration configuration.GlobalConfiguration) (map[string]*serverEntryPoint, error) { + redirectHandlers, err := s.buildEntryPointRedirect() + if err != nil { + return nil, err + } + + serverEntryPoints := s.buildServerEntryPoints() + errorHandler := NewRecordingErrorHandler(middlewares.DefaultNetErrorRecorder{}) + + backendsHandlers := map[string]http.Handler{} + backendsHealthCheck := map[string]*healthcheck.BackendConfig{} + + var postConfigs []handlerPostConfig + + for providerName, config := range configurations { + frontendNames := sortedFrontendNamesForConfig(config) + + for _, frontendName := range frontendNames { + frontendPostConfigs, err := s.loadFrontendConfig(providerName, frontendName, config, + redirectHandlers, serverEntryPoints, errorHandler, + backendsHandlers, backendsHealthCheck) + if err != nil { + log.Errorf("%v. Skipping frontend %s...", err, frontendName) + } + + if len(frontendPostConfigs) > 0 { + postConfigs = append(postConfigs, frontendPostConfigs...) + } + } + } + + for _, postConfig := range postConfigs { + err := postConfig(backendsHandlers) + if err != nil { + log.Errorf("middleware post configuration error: %v", err) + } + } + + healthcheck.GetHealthCheck(s.metricsRegistry).SetBackendsConfiguration(s.routinesPool.Ctx(), backendsHealthCheck) + + // Get new certificates list sorted per entrypoints + // Update certificates + entryPointsCertificates, err := s.loadHTTPSConfiguration(configurations, globalConfiguration.DefaultEntryPoints) + // FIXME error management + + // Sort routes and update certificates + for serverEntryPointName, serverEntryPoint := range serverEntryPoints { + serverEntryPoint.httpRouter.GetHandler().SortRoutes() + if _, exists := entryPointsCertificates[serverEntryPointName]; exists { + serverEntryPoint.certs.Set(entryPointsCertificates[serverEntryPointName]) + } + } + + return serverEntryPoints, err +} + +func (s *Server) loadFrontendConfig( + providerName string, frontendName string, config *types.Configuration, + redirectHandlers map[string]negroni.Handler, serverEntryPoints map[string]*serverEntryPoint, errorHandler *RecordingErrorHandler, + backendsHandlers map[string]http.Handler, backendsHealthCheck map[string]*healthcheck.BackendConfig, +) ([]handlerPostConfig, error) { + + frontend := config.Frontends[frontendName] + + if len(frontend.EntryPoints) == 0 { + return nil, fmt.Errorf("no entrypoint defined for frontend %s", frontendName) + } + + backend := config.Backends[frontend.Backend] + if backend == nil { + return nil, fmt.Errorf("undefined backend '%s' for frontend %s", frontend.Backend, frontendName) + } + + frontendHash, err := frontend.Hash() + if err != nil { + return nil, fmt.Errorf("error calculating hash value for frontend %s: %v", frontendName, err) + } + + var postConfigs []handlerPostConfig + + for _, entryPointName := range frontend.EntryPoints { + log.Debugf("Wiring frontend %s to entryPoint %s", frontendName, entryPointName) + + entryPoint := s.entryPoints[entryPointName].Configuration + + if backendsHandlers[entryPointName+providerName+frontendHash] == nil { + log.Debugf("Creating backend %s", frontend.Backend) + + handlers, responseModifier, postConfig, err := s.buildMiddlewares(frontendName, frontend, config.Backends, entryPointName, entryPoint, providerName) + if err != nil { + return nil, err + } + + if postConfig != nil { + postConfigs = append(postConfigs, postConfig) + } + + fwd, err := s.buildForwarder(entryPointName, entryPoint, frontendName, frontend, errorHandler, responseModifier) + if err != nil { + return nil, fmt.Errorf("failed to create the forwarder for frontend %s: %v", frontendName, err) + } + + lb, healthCheckConfig, err := s.buildBalancerMiddlewares(frontendName, frontend, backend, fwd) + if err != nil { + return nil, err + } + + if healthCheckConfig != nil { + backendsHealthCheck[entryPointName+providerName+frontendHash] = healthCheckConfig + } + + n := negroni.New() + + if _, exist := redirectHandlers[entryPointName]; exist { + n.Use(redirectHandlers[entryPointName]) + } + + for _, handler := range handlers { + n.Use(handler) + } + + n.UseHandler(lb) + + backendsHandlers[entryPointName+providerName+frontendHash] = n + } else { + log.Debugf("Reusing backend %s [%s - %s - %s - %s]", + frontend.Backend, entryPointName, providerName, frontendName, frontendHash) + } + + serverRoute, err := buildServerRoute(serverEntryPoints[entryPointName], frontendName, frontend) + if err != nil { + return nil, err + } + + handler := buildMatcherMiddlewares(serverRoute, backendsHandlers[entryPointName+providerName+frontendHash]) + serverRoute.Route.Handler(handler) + + err = serverRoute.Route.GetError() + if err != nil { + // FIXME error management + log.Errorf("Error building route: %s", err) + } + } + + return postConfigs, nil +} + +func (s *Server) buildForwarder(entryPointName string, entryPoint *configuration.EntryPoint, + frontendName string, frontend *types.Frontend, + errorHandler utils.ErrorHandler, responseModifier modifyResponse) (http.Handler, error) { + + roundTripper, err := s.getRoundTripper(entryPointName, frontend.PassTLSCert, entryPoint.TLS) + if err != nil { + return nil, fmt.Errorf("failed to create RoundTripper for frontend %s: %v", frontendName, err) + } + + rewriter, err := NewHeaderRewriter(entryPoint.ForwardedHeaders.TrustedIPs, entryPoint.ForwardedHeaders.Insecure) + if err != nil { + return nil, fmt.Errorf("error creating rewriter for frontend %s: %v", frontendName, err) + } + + var fwd http.Handler + fwd, err = forward.New( + forward.Stream(true), + forward.PassHostHeader(frontend.PassHostHeader), + forward.RoundTripper(roundTripper), + forward.ErrorHandler(errorHandler), + forward.Rewriter(rewriter), + forward.ResponseModifier(responseModifier), + forward.BufferPool(s.bufferPool), + ) + if err != nil { + return nil, fmt.Errorf("error creating forwarder for frontend %s: %v", frontendName, err) + } + + if s.tracingMiddleware.IsEnabled() { + tm := s.tracingMiddleware.NewForwarderMiddleware(frontendName, frontend.Backend) + + next := fwd + fwd = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tm.ServeHTTP(w, r, next.ServeHTTP) + }) + } + + return fwd, nil +} + +func buildServerRoute(serverEntryPoint *serverEntryPoint, frontendName string, frontend *types.Frontend) (*types.ServerRoute, error) { + serverRoute := &types.ServerRoute{Route: serverEntryPoint.httpRouter.GetHandler().NewRoute().Name(frontendName)} + + priority := 0 + for routeName, route := range frontend.Routes { + rls := rules.Rules{Route: serverRoute} + newRoute, err := rls.Parse(route.Rule) + if err != nil { + return nil, fmt.Errorf("error creating route for frontend %s: %v", frontendName, err) + } + + serverRoute.Route = newRoute + + priority += len(route.Rule) + log.Debugf("Creating route %s %s", routeName, route.Rule) + } + + if frontend.Priority > 0 { + serverRoute.Route.Priority(frontend.Priority) + } else { + serverRoute.Route.Priority(priority) + } + + return serverRoute, nil +} + +func (s *Server) preLoadConfiguration(configMsg types.ConfigMessage) { + providersThrottleDuration := time.Duration(s.globalConfiguration.ProvidersThrottleDuration) + s.defaultConfigurationValues(configMsg.Configuration) + currentConfigurations := s.currentConfigurations.Get().(types.Configurations) + + jsonConf, _ := json.Marshal(configMsg.Configuration) + + log.Debugf("Configuration received from provider %s: %s", configMsg.ProviderName, string(jsonConf)) + + if configMsg.Configuration == nil || configMsg.Configuration.Backends == nil && configMsg.Configuration.Frontends == nil && configMsg.Configuration.TLS == nil { + log.Infof("Skipping empty Configuration for provider %s", configMsg.ProviderName) + return + } + + if reflect.DeepEqual(currentConfigurations[configMsg.ProviderName], configMsg.Configuration) { + log.Infof("Skipping same configuration for provider %s", configMsg.ProviderName) + return + } + + providerConfigUpdateCh, ok := s.providerConfigUpdateMap[configMsg.ProviderName] + if !ok { + providerConfigUpdateCh = make(chan types.ConfigMessage) + s.providerConfigUpdateMap[configMsg.ProviderName] = providerConfigUpdateCh + s.routinesPool.Go(func(stop chan bool) { + s.throttleProviderConfigReload(providersThrottleDuration, s.configurationValidatedChan, providerConfigUpdateCh, stop) + }) + } + + providerConfigUpdateCh <- configMsg +} + +func (s *Server) defaultConfigurationValues(configuration *types.Configuration) { + if configuration == nil || configuration.Frontends == nil { + return + } + s.configureFrontends(configuration.Frontends) + configureBackends(configuration.Backends) +} + +func (s *Server) configureFrontends(frontends map[string]*types.Frontend) { + defaultEntrypoints := s.globalConfiguration.DefaultEntryPoints + + for frontendName, frontend := range frontends { + // default endpoints if not defined in frontends + if len(frontend.EntryPoints) == 0 { + frontend.EntryPoints = defaultEntrypoints + } + + frontendEntryPoints, undefinedEntryPoints := s.filterEntryPoints(frontend.EntryPoints) + if len(undefinedEntryPoints) > 0 { + log.Errorf("Undefined entry point(s) '%s' for frontend %s", strings.Join(undefinedEntryPoints, ","), frontendName) + } + + frontend.EntryPoints = frontendEntryPoints + } +} + +func (s *Server) filterEntryPoints(entryPoints []string) ([]string, []string) { + var frontendEntryPoints []string + var undefinedEntryPoints []string + + for _, fepName := range entryPoints { + var exist bool + + for epName := range s.entryPoints { + if epName == fepName { + exist = true + break + } + } + + if exist { + frontendEntryPoints = append(frontendEntryPoints, fepName) + } else { + undefinedEntryPoints = append(undefinedEntryPoints, fepName) + } + } + + return frontendEntryPoints, undefinedEntryPoints +} + +func configureBackends(backends map[string]*types.Backend) { + for backendName := range backends { + backend := backends[backendName] + if backend.LoadBalancer != nil && backend.LoadBalancer.Sticky { + log.Warnf("Deprecated configuration found: %s. Please use %s.", "backend.LoadBalancer.Sticky", "backend.LoadBalancer.Stickiness") + } + + _, err := types.NewLoadBalancerMethod(backend.LoadBalancer) + if err == nil { + if backend.LoadBalancer != nil && backend.LoadBalancer.Stickiness == nil && backend.LoadBalancer.Sticky { + backend.LoadBalancer.Stickiness = &types.Stickiness{ + CookieName: "_TRAEFIK_BACKEND", + } + } + } else { + log.Debugf("Backend %s: %v", backendName, err) + + var stickiness *types.Stickiness + if backend.LoadBalancer != nil { + if backend.LoadBalancer.Stickiness == nil { + if backend.LoadBalancer.Sticky { + stickiness = &types.Stickiness{ + CookieName: "_TRAEFIK_BACKEND", + } + } + } else { + stickiness = backend.LoadBalancer.Stickiness + } + } + backend.LoadBalancer = &types.LoadBalancer{ + Method: "wrr", + Stickiness: stickiness, + } + } + } +} + +func (s *Server) listenConfigurations(stop chan bool) { + for { + select { + case <-stop: + return + case configMsg, ok := <-s.configurationValidatedChan: + if !ok || configMsg.Configuration == nil { + return + } + s.loadConfiguration(configMsg) + } + } +} + +// throttleProviderConfigReload throttles the configuration reload speed for a single provider. +// It will immediately publish a new configuration and then only publish the next configuration after the throttle duration. +// Note that in the case it receives N new configs in the timeframe of the throttle duration after publishing, +// it will publish the last of the newly received configurations. +func (s *Server) throttleProviderConfigReload(throttle time.Duration, publish chan<- types.ConfigMessage, in <-chan types.ConfigMessage, stop chan bool) { + ring := channels.NewRingChannel(1) + defer ring.Close() + + s.routinesPool.Go(func(stop chan bool) { + for { + select { + case <-stop: + return + case nextConfig := <-ring.Out(): + publish <- nextConfig.(types.ConfigMessage) + time.Sleep(throttle) + } + } + }) + + for { + select { + case <-stop: + return + case nextConfig := <-in: + ring.In() <- nextConfig + } + } +} + +func buildMatcherMiddlewares(serverRoute *types.ServerRoute, handler http.Handler) http.Handler { + // path replace - This needs to always be the very last on the handler chain (first in the order in this function) + // -- Replacing Path should happen at the very end of the Modifier chain, after all the Matcher+Modifiers ran + if len(serverRoute.ReplacePath) > 0 { + handler = &middlewares.ReplacePath{ + Path: serverRoute.ReplacePath, + Handler: handler, + } + } + + if len(serverRoute.ReplacePathRegex) > 0 { + sp := strings.Split(serverRoute.ReplacePathRegex, " ") + if len(sp) == 2 { + handler = middlewares.NewReplacePathRegexHandler(sp[0], sp[1], handler) + } else { + log.Warnf("Invalid syntax for ReplacePathRegex: %s. Separate the regular expression and the replacement by a space.", serverRoute.ReplacePathRegex) + } + } + + // add prefix - This needs to always be right before ReplacePath on the chain (second in order in this function) + // -- Adding Path Prefix should happen after all *Strip Matcher+Modifiers ran, but before Replace (in case it's configured) + if len(serverRoute.AddPrefix) > 0 { + handler = &middlewares.AddPrefix{ + Prefix: serverRoute.AddPrefix, + Handler: handler, + } + } + + // strip prefix + if len(serverRoute.StripPrefixes) > 0 { + handler = &middlewares.StripPrefix{ + Prefixes: serverRoute.StripPrefixes, + Handler: handler, + } + } + + // strip prefix with regex + if len(serverRoute.StripPrefixesRegex) > 0 { + handler = middlewares.NewStripPrefixRegex(handler, serverRoute.StripPrefixesRegex) + } + + return handler +} + +func (s *Server) postLoadConfiguration() { + if s.metricsRegistry.IsEnabled() { + activeConfig := s.currentConfigurations.Get().(types.Configurations) + metrics.OnConfigurationUpdate(activeConfig) + } + + if s.globalConfiguration.ACME == nil || s.leadership == nil || !s.leadership.IsLeader() { + return + } + + if s.globalConfiguration.ACME.OnHostRule { + currentConfigurations := s.currentConfigurations.Get().(types.Configurations) + for _, config := range currentConfigurations { + for _, frontend := range config.Frontends { + + // check if one of the frontend entrypoints is configured with TLS + // and is configured with ACME + acmeEnabled := false + for _, entryPoint := range frontend.EntryPoints { + if s.globalConfiguration.ACME.EntryPoint == entryPoint && s.entryPoints[entryPoint].Configuration.TLS != nil { + acmeEnabled = true + break + } + } + + if acmeEnabled { + for _, route := range frontend.Routes { + rls := rules.Rules{} + domains, err := rls.ParseDomains(route.Rule) + if err != nil { + log.Errorf("Error parsing domains: %v", err) + } else { + s.globalConfiguration.ACME.LoadCertificateForDomains(domains) + } + } + } + } + } + } +} + +// loadHTTPSConfiguration add/delete HTTPS certificate managed dynamically +func (s *Server) loadHTTPSConfiguration(configurations types.Configurations, defaultEntryPoints configuration.DefaultEntryPoints) (map[string]map[string]*tls.Certificate, error) { + newEPCertificates := make(map[string]map[string]*tls.Certificate) + // Get all certificates + for _, config := range configurations { + if config.TLS != nil && len(config.TLS) > 0 { + if err := traefiktls.SortTLSPerEntryPoints(config.TLS, newEPCertificates, defaultEntryPoints); err != nil { + return nil, err + } + } + } + return newEPCertificates, nil +} + +func (s *Server) buildServerEntryPoints() map[string]*serverEntryPoint { + serverEntryPoints := make(map[string]*serverEntryPoint) + for entryPointName, entryPoint := range s.entryPoints { + serverEntryPoints[entryPointName] = &serverEntryPoint{ + httpRouter: middlewares.NewHandlerSwitcher(s.buildDefaultHTTPRouter()), + onDemandListener: entryPoint.OnDemandListener, + } + if entryPoint.CertificateStore != nil { + serverEntryPoints[entryPointName].certs = entryPoint.CertificateStore.DynamicCerts + } else { + serverEntryPoints[entryPointName].certs = &safe.Safe{} + } + } + return serverEntryPoints +} + +func (s *Server) buildDefaultHTTPRouter() *mux.Router { + rt := mux.NewRouter() + rt.NotFoundHandler = s.wrapHTTPHandlerWithAccessLog(http.HandlerFunc(http.NotFound), "backend not found") + rt.StrictSlash(true) + rt.SkipClean(true) + return rt +} + +func sortedFrontendNamesForConfig(configuration *types.Configuration) []string { + var keys []string + for key := range configuration.Frontends { + keys = append(keys, key) + } + sort.Strings(keys) + return keys +} diff --git a/server/server_configuration_test.go b/server/server_configuration_test.go new file mode 100644 index 000000000..abe22a1d9 --- /dev/null +++ b/server/server_configuration_test.go @@ -0,0 +1,484 @@ +package server + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/containous/flaeg" + "github.com/containous/mux" + "github.com/containous/traefik/configuration" + "github.com/containous/traefik/healthcheck" + "github.com/containous/traefik/rules" + th "github.com/containous/traefik/testhelpers" + "github.com/containous/traefik/tls" + "github.com/containous/traefik/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/vulcand/oxy/roundrobin" +) + +// LocalhostCert is a PEM-encoded TLS cert with SAN IPs +// "127.0.0.1" and "[::1]", expiring at Jan 29 16:00:00 2084 GMT. +// generated from src/crypto/tls: +// go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h +var ( + localhostCert = tls.FileOrContent(`-----BEGIN CERTIFICATE----- +MIICEzCCAXygAwIBAgIQMIMChMLGrR+QvmQvpwAU6zANBgkqhkiG9w0BAQsFADAS +MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw +MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB +iQKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9SjY1bIw4 +iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZBl2+XsDul +rKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQABo2gwZjAO +BgNVHQ8BAf8EBAMCAqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUw +AwEB/zAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAAAAAAAAAA +AAAAATANBgkqhkiG9w0BAQsFAAOBgQCEcetwO59EWk7WiJsG4x8SY+UIAA+flUI9 +tyC4lNhbcF2Idq9greZwbYCqTTTr2XiRNSMLCOjKyI7ukPoPjo16ocHj+P3vZGfs +h1fIw3cSS2OolhloGw/XM6RWPWtPAlGykKLciQrBru5NAPvCMsb/I1DAceTiotQM +fblo6RBxUQ== +-----END CERTIFICATE-----`) + + // LocalhostKey is the private key for localhostCert. + localhostKey = tls.FileOrContent(`-----BEGIN RSA PRIVATE KEY----- +MIICXgIBAAKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9 +SjY1bIw4iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZB +l2+XsDulrKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQAB +AoGAGRzwwir7XvBOAy5tM/uV6e+Zf6anZzus1s1Y1ClbjbE6HXbnWWF/wbZGOpet +3Zm4vD6MXc7jpTLryzTQIvVdfQbRc6+MUVeLKwZatTXtdZrhu+Jk7hx0nTPy8Jcb +uJqFk541aEw+mMogY/xEcfbWd6IOkp+4xqjlFLBEDytgbIECQQDvH/E6nk+hgN4H +qzzVtxxr397vWrjrIgPbJpQvBsafG7b0dA4AFjwVbFLmQcj2PprIMmPcQrooz8vp +jy4SHEg1AkEA/v13/5M47K9vCxmb8QeD/asydfsgS5TeuNi8DoUBEmiSJwma7FXY +fFUtxuvL7XvjwjN5B30pNEbc6Iuyt7y4MQJBAIt21su4b3sjXNueLKH85Q+phy2U +fQtuUE9txblTu14q3N7gHRZB4ZMhFYyDy8CKrN2cPg/Fvyt0Xlp/DoCzjA0CQQDU +y2ptGsuSmgUtWj3NM9xuwYPm+Z/F84K6+ARYiZ6PYj013sovGKUFfYAqVXVlxtIX +qyUBnu3X9ps8ZfjLZO7BAkEAlT4R5Yl6cGhaJQYZHOde3JEMhNRcVFMO8dJDaFeo +f9Oeos0UUothgiDktdQHxdNEwLjQf7lJJBzV+5OtwswCWA== +-----END RSA PRIVATE KEY-----`) +) + +type testLoadBalancer struct{} + +func (lb *testLoadBalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // noop +} + +func (lb *testLoadBalancer) RemoveServer(u *url.URL) error { + return nil +} + +func (lb *testLoadBalancer) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error { + return nil +} + +func (lb *testLoadBalancer) Servers() []*url.URL { + return []*url.URL{} +} + +func TestServerLoadConfigHealthCheckOptions(t *testing.T) { + healthChecks := []*types.HealthCheck{ + nil, + { + Path: "/path", + }, + } + + for _, lbMethod := range []string{"Wrr", "Drr"} { + for _, healthCheck := range healthChecks { + t.Run(fmt.Sprintf("%s/hc=%t", lbMethod, healthCheck != nil), func(t *testing.T) { + globalConfig := configuration.GlobalConfiguration{ + HealthCheck: &configuration.HealthCheckConfig{Interval: flaeg.Duration(5 * time.Second)}, + } + entryPoints := map[string]EntryPoint{ + "http": { + Configuration: &configuration.EntryPoint{ + ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true}, + }, + }, + } + + dynamicConfigs := types.Configurations{ + "config": &types.Configuration{ + Frontends: map[string]*types.Frontend{ + "frontend": { + EntryPoints: []string{"http"}, + Backend: "backend", + }, + }, + Backends: map[string]*types.Backend{ + "backend": { + Servers: map[string]types.Server{ + "server": { + URL: "http://localhost", + }, + }, + LoadBalancer: &types.LoadBalancer{ + Method: lbMethod, + }, + HealthCheck: healthCheck, + }, + }, + TLS: []*tls.Configuration{ + { + Certificate: &tls.Certificate{ + CertFile: localhostCert, + KeyFile: localhostKey, + }, + EntryPoints: []string{"http"}, + }, + }, + }, + } + + srv := NewServer(globalConfig, nil, entryPoints) + + _, err := srv.loadConfig(dynamicConfigs, globalConfig) + require.NoError(t, err) + + expectedNumHealthCheckBackends := 0 + if healthCheck != nil { + expectedNumHealthCheckBackends = 1 + } + assert.Len(t, healthcheck.GetHealthCheck(th.NewCollectingHealthCheckMetrics()).Backends, expectedNumHealthCheckBackends, "health check backends") + }) + } + } +} + +func TestServerLoadConfigEmptyBasicAuth(t *testing.T) { + globalConfig := configuration.GlobalConfiguration{ + EntryPoints: configuration.EntryPoints{ + "http": &configuration.EntryPoint{ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true}}, + }, + } + + dynamicConfigs := types.Configurations{ + "config": &types.Configuration{ + Frontends: map[string]*types.Frontend{ + "frontend": { + EntryPoints: []string{"http"}, + Backend: "backend", + BasicAuth: []string{""}, + }, + }, + Backends: map[string]*types.Backend{ + "backend": { + Servers: map[string]types.Server{ + "server": { + URL: "http://localhost", + }, + }, + LoadBalancer: &types.LoadBalancer{ + Method: "Wrr", + }, + }, + }, + }, + } + + entryPoints := map[string]EntryPoint{} + for key, value := range globalConfig.EntryPoints { + entryPoints[key] = EntryPoint{ + Configuration: value, + } + } + + srv := NewServer(globalConfig, nil, entryPoints) + _, err := srv.loadConfig(dynamicConfigs, globalConfig) + require.NoError(t, err) +} + +func TestServerLoadCertificateWithDefaultEntryPoint(t *testing.T) { + globalConfig := configuration.GlobalConfiguration{ + DefaultEntryPoints: []string{"http", "https"}, + } + entryPoints := map[string]EntryPoint{ + "https": {Configuration: &configuration.EntryPoint{TLS: &tls.TLS{}}}, + "http": {Configuration: &configuration.EntryPoint{}}, + } + + dynamicConfigs := types.Configurations{ + "config": &types.Configuration{ + TLS: []*tls.Configuration{ + { + Certificate: &tls.Certificate{ + CertFile: localhostCert, + KeyFile: localhostKey, + }, + }, + }, + }, + } + + srv := NewServer(globalConfig, nil, entryPoints) + if mapEntryPoints, err := srv.loadConfig(dynamicConfigs, globalConfig); err != nil { + t.Fatalf("got error: %s", err) + } else if mapEntryPoints["https"].certs.Get() == nil { + t.Fatal("got error: https entryPoint must have TLS certificates.") + } +} + +func TestReuseBackend(t *testing.T) { + testServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusOK) + })) + defer testServer.Close() + + globalConfig := configuration.GlobalConfiguration{ + DefaultEntryPoints: []string{"http"}, + } + + entryPoints := map[string]EntryPoint{ + "http": {Configuration: &configuration.EntryPoint{ + ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true}, + }}, + } + + dynamicConfigs := types.Configurations{ + "config": th.BuildConfiguration( + th.WithFrontends( + th.WithFrontend("backend", + th.WithFrontendName("frontend0"), + th.WithEntryPoints("http"), + th.WithRoutes(th.WithRoute("/ok", "Path: /ok"))), + th.WithFrontend("backend", + th.WithFrontendName("frontend1"), + th.WithEntryPoints("http"), + th.WithRoutes(th.WithRoute("/unauthorized", "Path: /unauthorized")), + th.WithBasicAuth("foo", "bar")), + ), + th.WithBackends(th.WithBackendNew("backend", + th.WithLBMethod("wrr"), + th.WithServersNew(th.WithServerNew(testServer.URL))), + ), + ), + } + + srv := NewServer(globalConfig, nil, entryPoints) + + serverEntryPoints, err := srv.loadConfig(dynamicConfigs, globalConfig) + if err != nil { + t.Fatalf("error loading config: %s", err) + } + + // Test that the /ok path returns a status 200. + responseRecorderOk := &httptest.ResponseRecorder{} + requestOk := httptest.NewRequest(http.MethodGet, testServer.URL+"/ok", nil) + serverEntryPoints["http"].httpRouter.ServeHTTP(responseRecorderOk, requestOk) + + assert.Equal(t, http.StatusOK, responseRecorderOk.Result().StatusCode, "status code") + + // Test that the /unauthorized path returns a 401 because of + // the basic authentication defined on the frontend. + responseRecorderUnauthorized := &httptest.ResponseRecorder{} + requestUnauthorized := httptest.NewRequest(http.MethodGet, testServer.URL+"/unauthorized", nil) + serverEntryPoints["http"].httpRouter.ServeHTTP(responseRecorderUnauthorized, requestUnauthorized) + + assert.Equal(t, http.StatusUnauthorized, responseRecorderUnauthorized.Result().StatusCode, "status code") +} + +func TestThrottleProviderConfigReload(t *testing.T) { + throttleDuration := 30 * time.Millisecond + publishConfig := make(chan types.ConfigMessage) + providerConfig := make(chan types.ConfigMessage) + stop := make(chan bool) + defer func() { + stop <- true + }() + + globalConfig := configuration.GlobalConfiguration{} + server := NewServer(globalConfig, nil, nil) + + go server.throttleProviderConfigReload(throttleDuration, publishConfig, providerConfig, stop) + + publishedConfigCount := 0 + stopConsumeConfigs := make(chan bool) + go func() { + for { + select { + case <-stop: + return + case <-stopConsumeConfigs: + return + case <-publishConfig: + publishedConfigCount++ + } + } + }() + + // publish 5 new configs, one new config each 10 milliseconds + for i := 0; i < 5; i++ { + providerConfig <- types.ConfigMessage{} + time.Sleep(10 * time.Millisecond) + } + + // after 50 milliseconds 5 new configs were published + // with a throttle duration of 30 milliseconds this means, we should have received 2 new configs + assert.Equal(t, 2, publishedConfigCount, "times configs were published") + + stopConsumeConfigs <- true + + select { + case <-publishConfig: + // There should be exactly one more message that we receive after ~60 milliseconds since the start of the test. + select { + case <-publishConfig: + t.Error("extra config publication found") + case <-time.After(100 * time.Millisecond): + return + } + case <-time.After(100 * time.Millisecond): + t.Error("Last config was not published in time") + } +} + +func TestServerMultipleFrontendRules(t *testing.T) { + testCases := []struct { + expression string + requestURL string + expectedURL string + }{ + { + expression: "Host:foo.bar", + requestURL: "http://foo.bar", + expectedURL: "http://foo.bar", + }, + { + expression: "PathPrefix:/management;ReplacePath:/health", + requestURL: "http://foo.bar/management", + expectedURL: "http://foo.bar/health", + }, + { + expression: "Host:foo.bar;AddPrefix:/blah", + requestURL: "http://foo.bar/baz", + expectedURL: "http://foo.bar/blah/baz", + }, + { + expression: "PathPrefixStripRegex:/one/{two}/{three:[0-9]+}", + requestURL: "http://foo.bar/one/some/12345/four", + expectedURL: "http://foo.bar/four", + }, + { + expression: "PathPrefixStripRegex:/one/{two}/{three:[0-9]+};AddPrefix:/zero", + requestURL: "http://foo.bar/one/some/12345/four", + expectedURL: "http://foo.bar/zero/four", + }, + { + expression: "AddPrefix:/blah;ReplacePath:/baz", + requestURL: "http://foo.bar/hello", + expectedURL: "http://foo.bar/baz", + }, + { + expression: "PathPrefixStrip:/management;ReplacePath:/health", + requestURL: "http://foo.bar/management", + expectedURL: "http://foo.bar/health", + }, + } + + for _, test := range testCases { + test := test + t.Run(test.expression, func(t *testing.T) { + t.Parallel() + + router := mux.NewRouter() + route := router.NewRoute() + serverRoute := &types.ServerRoute{Route: route} + rls := &rules.Rules{Route: serverRoute} + + expression := test.expression + routeResult, err := rls.Parse(expression) + + if err != nil { + t.Fatalf("Error while building route for %s: %+v", expression, err) + } + + request := th.MustNewRequest(http.MethodGet, test.requestURL, nil) + routeMatch := routeResult.Match(request, &mux.RouteMatch{Route: routeResult}) + + if !routeMatch { + t.Fatalf("Rule %s doesn't match", expression) + } + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, test.expectedURL, r.URL.String(), "URL") + }) + + hd := buildMatcherMiddlewares(serverRoute, handler) + serverRoute.Route.Handler(hd) + + serverRoute.Route.GetHandler().ServeHTTP(nil, request) + }) + } +} + +func TestServerBuildHealthCheckOptions(t *testing.T) { + lb := &testLoadBalancer{} + globalInterval := 15 * time.Second + + testCases := []struct { + desc string + hc *types.HealthCheck + expectedOpts *healthcheck.Options + }{ + { + desc: "nil health check", + hc: nil, + expectedOpts: nil, + }, + { + desc: "empty path", + hc: &types.HealthCheck{ + Path: "", + }, + expectedOpts: nil, + }, + { + desc: "unparseable interval", + hc: &types.HealthCheck{ + Path: "/path", + Interval: "unparseable", + }, + expectedOpts: &healthcheck.Options{ + Path: "/path", + Interval: globalInterval, + LB: lb, + }, + }, + { + desc: "sub-zero interval", + hc: &types.HealthCheck{ + Path: "/path", + Interval: "-42s", + }, + expectedOpts: &healthcheck.Options{ + Path: "/path", + Interval: globalInterval, + LB: lb, + }, + }, + { + desc: "parseable interval", + hc: &types.HealthCheck{ + Path: "/path", + Interval: "5m", + }, + expectedOpts: &healthcheck.Options{ + Path: "/path", + Interval: 5 * time.Minute, + LB: lb, + }, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + opts := buildHealthCheckOptions(lb, "backend", test.hc, &configuration.HealthCheckConfig{Interval: flaeg.Duration(globalInterval)}) + assert.Equal(t, test.expectedOpts, opts, "health check options") + }) + } +} diff --git a/server/server_loadbalancer.go b/server/server_loadbalancer.go new file mode 100644 index 000000000..861f8f010 --- /dev/null +++ b/server/server_loadbalancer.go @@ -0,0 +1,428 @@ +package server + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "io/ioutil" + "net" + "net/http" + "net/url" + "time" + + "github.com/containous/traefik/configuration" + "github.com/containous/traefik/healthcheck" + "github.com/containous/traefik/log" + "github.com/containous/traefik/middlewares" + "github.com/containous/traefik/middlewares/accesslog" + "github.com/containous/traefik/server/cookie" + traefiktls "github.com/containous/traefik/tls" + "github.com/containous/traefik/types" + "github.com/vulcand/oxy/buffer" + "github.com/vulcand/oxy/connlimit" + "github.com/vulcand/oxy/ratelimit" + "github.com/vulcand/oxy/roundrobin" + "github.com/vulcand/oxy/utils" + "golang.org/x/net/http2" +) + +type h2cTransportWrapper struct { + *http2.Transport +} + +func (t *h2cTransportWrapper) RoundTrip(req *http.Request) (*http.Response, error) { + req.URL.Scheme = "http" + return t.Transport.RoundTrip(req) +} + +func (s *Server) buildBalancerMiddlewares(frontendName string, frontend *types.Frontend, backend *types.Backend, fwd http.Handler) (http.Handler, *healthcheck.BackendConfig, error) { + balancer, err := s.buildLoadBalancer(frontendName, frontend.Backend, backend, fwd) + if err != nil { + return nil, nil, err + } + + // Health Check + var backendHealthCheck *healthcheck.BackendConfig + if hcOpts := buildHealthCheckOptions(balancer, frontend.Backend, backend.HealthCheck, s.globalConfiguration.HealthCheck); hcOpts != nil { + log.Debugf("Setting up backend health check %s", *hcOpts) + + hcOpts.Transport = s.defaultForwardingRoundTripper + backendHealthCheck = healthcheck.NewBackendConfig(*hcOpts, frontend.Backend) + } + + // Empty (backend with no servers) + var lb http.Handler = middlewares.NewEmptyBackendHandler(balancer) + + // Rate Limit + if frontend.RateLimit != nil && len(frontend.RateLimit.RateSet) > 0 { + handler, err := buildRateLimiter(lb, frontend.RateLimit) + if err != nil { + return nil, nil, fmt.Errorf("error creating rate limiter: %v", err) + } + + lb = s.wrapHTTPHandlerWithAccessLog( + s.tracingMiddleware.NewHTTPHandlerWrapper("Rate limit", handler, false), + fmt.Sprintf("rate limit for %s", frontendName), + ) + } + + // Max Connections + if backend.MaxConn != nil && backend.MaxConn.Amount != 0 { + log.Debugf("Creating load-balancer connection limit") + + handler, err := buildMaxConn(lb, backend.MaxConn) + if err != nil { + return nil, nil, err + } + lb = s.wrapHTTPHandlerWithAccessLog(handler, fmt.Sprintf("connection limit for %s", frontendName)) + } + + // Retry + if s.globalConfiguration.Retry != nil { + handler := s.buildRetryMiddleware(lb, s.globalConfiguration.Retry, len(backend.Servers), frontend.Backend) + lb = s.tracingMiddleware.NewHTTPHandlerWrapper("Retry", handler, false) + } + + // Buffering + if backend.Buffering != nil { + handler, err := buildBufferingMiddleware(lb, backend.Buffering) + if err != nil { + return nil, nil, fmt.Errorf("error setting up buffering middleware: %s", err) + } + + // TODO refactor ? + lb = handler + } + + // Circuit Breaker + if backend.CircuitBreaker != nil { + log.Debugf("Creating circuit breaker %s", backend.CircuitBreaker.Expression) + + expression := backend.CircuitBreaker.Expression + circuitBreaker, err := middlewares.NewCircuitBreaker(lb, expression, middlewares.NewCircuitBreakerOptions(expression)) + if err != nil { + return nil, nil, fmt.Errorf("error creating circuit breaker: %v", err) + } + + lb = s.tracingMiddleware.NewHTTPHandlerWrapper("Circuit breaker", circuitBreaker, false) + } + + return lb, backendHealthCheck, nil +} + +func (s *Server) buildLoadBalancer(frontendName string, backendName string, backend *types.Backend, fwd http.Handler) (healthcheck.BalancerHandler, error) { + var rr *roundrobin.RoundRobin + var saveFrontend http.Handler + + if s.accessLoggerMiddleware != nil { + saveBackend := accesslog.NewSaveBackend(fwd, backendName) + saveFrontend = accesslog.NewSaveFrontend(saveBackend, frontendName) + rr, _ = roundrobin.New(saveFrontend) + } else { + rr, _ = roundrobin.New(fwd) + } + + var stickySession *roundrobin.StickySession + var cookieName string + if stickiness := backend.LoadBalancer.Stickiness; stickiness != nil { + cookieName = cookie.GetName(stickiness.CookieName, backendName) + stickySession = roundrobin.NewStickySession(cookieName) + } + + lbMethod, err := types.NewLoadBalancerMethod(backend.LoadBalancer) + if err != nil { + return nil, fmt.Errorf("error loading load balancer method '%+v' for frontend %s: %v", backend.LoadBalancer, frontendName, err) + } + + var lb healthcheck.BalancerHandler + + switch lbMethod { + case types.Drr: + log.Debug("Creating load-balancer drr") + + if stickySession != nil { + log.Debugf("Sticky session with cookie %v", cookieName) + + lb, err = roundrobin.NewRebalancer(rr, roundrobin.RebalancerStickySession(stickySession)) + if err != nil { + return nil, err + } + } else { + lb, err = roundrobin.NewRebalancer(rr) + if err != nil { + return nil, err + } + } + case types.Wrr: + log.Debug("Creating load-balancer wrr") + + if stickySession != nil { + log.Debugf("Sticky session with cookie %v", cookieName) + + if s.accessLoggerMiddleware != nil { + lb, err = roundrobin.New(saveFrontend, roundrobin.EnableStickySession(stickySession)) + if err != nil { + return nil, err + } + } else { + lb, err = roundrobin.New(fwd, roundrobin.EnableStickySession(stickySession)) + if err != nil { + return nil, err + } + } + } else { + lb = rr + } + default: + return nil, fmt.Errorf("invalid load-balancing method %q", lbMethod) + } + + if err := s.configureLBServers(lb, backend, backendName); err != nil { + return nil, fmt.Errorf("error configuring load balancer for frontend %s: %v", frontendName, err) + } + + return lb, nil +} + +func (s *Server) configureLBServers(lb healthcheck.BalancerHandler, backend *types.Backend, backendName string) error { + for name, srv := range backend.Servers { + u, err := url.Parse(srv.URL) + if err != nil { + return fmt.Errorf("error parsing server URL %s: %v", srv.URL, err) + } + + log.Debugf("Creating server %s at %s with weight %d", name, u, srv.Weight) + + if err := lb.UpsertServer(u, roundrobin.Weight(srv.Weight)); err != nil { + return fmt.Errorf("error adding server %s to load balancer: %v", srv.URL, err) + } + + s.metricsRegistry.BackendServerUpGauge().With("backend", backendName, "url", srv.URL).Set(1) + } + return nil +} + +// getRoundTripper will either use server.defaultForwardingRoundTripper or create a new one +// given a custom TLS configuration is passed and the passTLSCert option is set to true. +func (s *Server) getRoundTripper(entryPointName string, passTLSCert bool, tls *traefiktls.TLS) (http.RoundTripper, error) { + if passTLSCert { + tlsConfig, err := createClientTLSConfig(entryPointName, tls) + if err != nil { + return nil, fmt.Errorf("failed to create TLSClientConfig: %v", err) + } + + transport, err := createHTTPTransport(s.globalConfiguration) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP transport: %v", err) + } + + transport.TLSClientConfig = tlsConfig + return transport, nil + } + + return s.defaultForwardingRoundTripper, nil +} + +// createHTTPTransport creates an http.Transport configured with the GlobalConfiguration settings. +// For the settings that can't be configured in Traefik it uses the default http.Transport settings. +// An exception to this is the MaxIdleConns setting as we only provide the option MaxIdleConnsPerHost +// in Traefik at this point in time. Setting this value to the default of 100 could lead to confusing +// behaviour and backwards compatibility issues. +func createHTTPTransport(globalConfiguration configuration.GlobalConfiguration) (*http.Transport, error) { + dialer := &net.Dialer{ + Timeout: configuration.DefaultDialTimeout, + KeepAlive: 30 * time.Second, + DualStack: true, + } + + if globalConfiguration.ForwardingTimeouts != nil { + dialer.Timeout = time.Duration(globalConfiguration.ForwardingTimeouts.DialTimeout) + } + + transport := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: dialer.DialContext, + MaxIdleConnsPerHost: globalConfiguration.MaxIdleConnsPerHost, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + + transport.RegisterProtocol("h2c", &h2cTransportWrapper{ + Transport: &http2.Transport{ + DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) { + return net.Dial(netw, addr) + }, + AllowHTTP: true, + }, + }) + + if globalConfiguration.ForwardingTimeouts != nil { + transport.ResponseHeaderTimeout = time.Duration(globalConfiguration.ForwardingTimeouts.ResponseHeaderTimeout) + } + + if globalConfiguration.InsecureSkipVerify { + transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + } + + if len(globalConfiguration.RootCAs) > 0 { + transport.TLSClientConfig = &tls.Config{ + RootCAs: createRootCACertPool(globalConfiguration.RootCAs), + } + } + + err := http2.ConfigureTransport(transport) + if err != nil { + return nil, err + } + + return transport, nil +} + +func createRootCACertPool(rootCAs traefiktls.RootCAs) *x509.CertPool { + roots := x509.NewCertPool() + + for _, cert := range rootCAs { + certContent, err := cert.Read() + if err != nil { + log.Error("Error while read RootCAs", err) + continue + } + roots.AppendCertsFromPEM(certContent) + } + + return roots +} + +func createClientTLSConfig(entryPointName string, tlsOption *traefiktls.TLS) (*tls.Config, error) { + if tlsOption == nil { + return nil, errors.New("no TLS provided") + } + + config, err := tlsOption.Certificates.CreateTLSConfig(entryPointName) + if err != nil { + return nil, err + } + + if len(tlsOption.ClientCAFiles) > 0 { + log.Warnf("Deprecated configuration found during client TLS configuration creation: %s. Please use %s (which allows to make the CA Files optional).", "tls.ClientCAFiles", "tls.ClientCA.files") + tlsOption.ClientCA.Files = tlsOption.ClientCAFiles + tlsOption.ClientCA.Optional = false + } + + if len(tlsOption.ClientCA.Files) > 0 { + pool := x509.NewCertPool() + for _, caFile := range tlsOption.ClientCA.Files { + data, err := ioutil.ReadFile(caFile) + if err != nil { + return nil, err + } + + if !pool.AppendCertsFromPEM(data) { + return nil, fmt.Errorf("invalid certificate(s) in %s", caFile) + } + } + config.RootCAs = pool + } + + config.BuildNameToCertificate() + + return config, nil +} + +func (s *Server) buildRetryMiddleware(handler http.Handler, retry *configuration.Retry, countServers int, backendName string) http.Handler { + retryListeners := middlewares.RetryListeners{} + if s.metricsRegistry.IsEnabled() { + retryListeners = append(retryListeners, middlewares.NewMetricsRetryListener(s.metricsRegistry, backendName)) + } + if s.accessLoggerMiddleware != nil { + retryListeners = append(retryListeners, &accesslog.SaveRetries{}) + } + + retryAttempts := countServers + if retry.Attempts > 0 { + retryAttempts = retry.Attempts + } + + log.Debugf("Creating retries max attempts %d", retryAttempts) + + return middlewares.NewRetry(retryAttempts, handler, retryListeners) +} + +func buildRateLimiter(handler http.Handler, rlConfig *types.RateLimit) (http.Handler, error) { + extractFunc, err := utils.NewExtractor(rlConfig.ExtractorFunc) + if err != nil { + return nil, err + } + + log.Debugf("Creating load-balancer rate limiter") + + rateSet := ratelimit.NewRateSet() + for _, rate := range rlConfig.RateSet { + if err := rateSet.Add(time.Duration(rate.Period), rate.Average, rate.Burst); err != nil { + return nil, err + } + } + + return ratelimit.New(handler, extractFunc, rateSet) +} + +func buildBufferingMiddleware(handler http.Handler, config *types.Buffering) (http.Handler, error) { + log.Debugf("Setting up buffering: request limits: %d (mem), %d (max), response limits: %d (mem), %d (max) with retry: '%s'", + config.MemRequestBodyBytes, config.MaxRequestBodyBytes, config.MemResponseBodyBytes, + config.MaxResponseBodyBytes, config.RetryExpression) + + return buffer.New( + handler, + buffer.MemRequestBodyBytes(config.MemRequestBodyBytes), + buffer.MaxRequestBodyBytes(config.MaxRequestBodyBytes), + buffer.MemResponseBodyBytes(config.MemResponseBodyBytes), + buffer.MaxResponseBodyBytes(config.MaxResponseBodyBytes), + buffer.CondSetter(len(config.RetryExpression) > 0, buffer.Retry(config.RetryExpression)), + ) +} + +func buildMaxConn(lb http.Handler, maxConns *types.MaxConn) (http.Handler, error) { + extractFunc, err := utils.NewExtractor(maxConns.ExtractorFunc) + if err != nil { + return nil, fmt.Errorf("error creating connection limit: %v", err) + } + + log.Debugf("Creating load-balancer connection limit") + + handler, err := connlimit.New(lb, extractFunc, maxConns.Amount) + if err != nil { + return nil, fmt.Errorf("error creating connection limit: %v", err) + } + + return handler, nil +} + +func buildHealthCheckOptions(lb healthcheck.BalancerHandler, backend string, hc *types.HealthCheck, hcConfig *configuration.HealthCheckConfig) *healthcheck.Options { + if hc == nil || hc.Path == "" || hcConfig == nil { + return nil + } + + interval := time.Duration(hcConfig.Interval) + if hc.Interval != "" { + intervalOverride, err := time.ParseDuration(hc.Interval) + if err != nil { + log.Errorf("Illegal health check interval for backend '%s': %s", backend, err) + } else if intervalOverride <= 0 { + log.Errorf("Health check interval smaller than zero for backend '%s', backend", backend) + } else { + interval = intervalOverride + } + } + + return &healthcheck.Options{ + Scheme: hc.Scheme, + Path: hc.Path, + Port: hc.Port, + Interval: interval, + LB: lb, + Hostname: hc.Hostname, + Headers: hc.Headers, + } +} diff --git a/server/server_loadbalancer_test.go b/server/server_loadbalancer_test.go new file mode 100644 index 000000000..9af963a5d --- /dev/null +++ b/server/server_loadbalancer_test.go @@ -0,0 +1,81 @@ +package server + +import ( + "testing" + + "github.com/containous/traefik/types" + "github.com/stretchr/testify/assert" +) + +func TestConfigureBackends(t *testing.T) { + validMethod := "Drr" + defaultMethod := "wrr" + + testCases := []struct { + desc string + lb *types.LoadBalancer + expectedMethod string + expectedStickiness *types.Stickiness + }{ + { + desc: "valid load balancer method with sticky enabled", + lb: &types.LoadBalancer{ + Method: validMethod, + Stickiness: &types.Stickiness{}, + }, + expectedMethod: validMethod, + expectedStickiness: &types.Stickiness{}, + }, + { + desc: "valid load balancer method with sticky disabled", + lb: &types.LoadBalancer{ + Method: validMethod, + Stickiness: nil, + }, + expectedMethod: validMethod, + }, + { + desc: "invalid load balancer method with sticky enabled", + lb: &types.LoadBalancer{ + Method: "Invalid", + Stickiness: &types.Stickiness{}, + }, + expectedMethod: defaultMethod, + expectedStickiness: &types.Stickiness{}, + }, + { + desc: "invalid load balancer method with sticky disabled", + lb: &types.LoadBalancer{ + Method: "Invalid", + Stickiness: nil, + }, + expectedMethod: defaultMethod, + }, + { + desc: "missing load balancer", + lb: nil, + expectedMethod: defaultMethod, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + backend := &types.Backend{ + LoadBalancer: test.lb, + } + + configureBackends(map[string]*types.Backend{ + "backend": backend, + }) + + expected := types.LoadBalancer{ + Method: test.expectedMethod, + Stickiness: test.expectedStickiness, + } + + assert.Equal(t, expected, *backend.LoadBalancer) + }) + } +} diff --git a/server/server_middlewares.go b/server/server_middlewares.go new file mode 100644 index 000000000..58f4c9034 --- /dev/null +++ b/server/server_middlewares.go @@ -0,0 +1,316 @@ +package server + +import ( + "fmt" + "net/http" + + "github.com/containous/traefik/configuration" + "github.com/containous/traefik/log" + "github.com/containous/traefik/middlewares" + "github.com/containous/traefik/middlewares/accesslog" + mauth "github.com/containous/traefik/middlewares/auth" + "github.com/containous/traefik/middlewares/errorpages" + "github.com/containous/traefik/middlewares/redirect" + "github.com/containous/traefik/types" + thoas_stats "github.com/thoas/stats" + "github.com/unrolled/secure" + "github.com/urfave/negroni" +) + +type handlerPostConfig func(backendsHandlers map[string]http.Handler) error + +type modifyResponse func(*http.Response) error + +func (s *Server) buildMiddlewares(frontendName string, frontend *types.Frontend, + backends map[string]*types.Backend, + entryPointName string, entryPoint *configuration.EntryPoint, + providerName string) ([]negroni.Handler, modifyResponse, handlerPostConfig, error) { + + var middle []negroni.Handler + var postConfig handlerPostConfig + + // Error pages + if len(frontend.Errors) > 0 { + handlers, err := buildErrorPagesMiddleware(frontendName, frontend, backends, entryPointName, providerName) + if err != nil { + return nil, nil, nil, err + } + + postConfig = errorPagesPostConfig(handlers) + + for _, handler := range handlers { + middle = append(middle, handler) + } + } + + // Metrics + if s.metricsRegistry.IsEnabled() { + handler := middlewares.NewBackendMetricsMiddleware(s.metricsRegistry, frontend.Backend) + middle = append(middle, handler) + } + + // Whitelist + ipWhitelistMiddleware, err := buildIPWhiteLister(frontend.WhiteList, frontend.WhitelistSourceRange) + if err != nil { + return nil, nil, nil, fmt.Errorf("error creating IP Whitelister: %s", err) + } + if ipWhitelistMiddleware != nil { + log.Debugf("Configured IP Whitelists: %v", frontend.WhiteList.SourceRange) + + handler := s.tracingMiddleware.NewNegroniHandlerWrapper( + "IP whitelist", + s.wrapNegroniHandlerWithAccessLog(ipWhitelistMiddleware, fmt.Sprintf("ipwhitelister for %s", frontendName)), + false) + middle = append(middle, handler) + } + + // Redirect + if frontend.Redirect != nil && entryPointName != frontend.Redirect.EntryPoint { + rewrite, err := s.buildRedirectHandler(entryPointName, frontend.Redirect) + if err != nil { + return nil, nil, nil, fmt.Errorf("error creating Frontend Redirect: %v", err) + } + + handler := s.wrapNegroniHandlerWithAccessLog(rewrite, fmt.Sprintf("frontend redirect for %s", frontendName)) + middle = append(middle, handler) + + log.Debugf("Frontend %s redirect created", frontendName) + } + + // Header + headerMiddleware := middlewares.NewHeaderFromStruct(frontend.Headers) + if headerMiddleware != nil { + log.Debugf("Adding header middleware for frontend %s", frontendName) + + handler := s.tracingMiddleware.NewNegroniHandlerWrapper("Header", headerMiddleware, false) + middle = append(middle, handler) + } + + // Secure + secureMiddleware := middlewares.NewSecure(frontend.Headers) + if secureMiddleware != nil { + log.Debugf("Adding secure middleware for frontend %s", frontendName) + + handler := negroni.HandlerFunc(secureMiddleware.HandlerFuncWithNextForRequestOnly) + middle = append(middle, handler) + } + + // Basic auth + if len(frontend.BasicAuth) > 0 { + log.Debugf("Adding basic authentication for frontend %s", frontendName) + + authMiddleware, err := s.buildBasicAuthMiddleware(frontend.BasicAuth) + if err != nil { + return nil, nil, nil, err + } + + handler := s.wrapNegroniHandlerWithAccessLog(authMiddleware, fmt.Sprintf("Basic Auth for %s", frontendName)) + middle = append(middle, handler) + } + + return middle, buildModifyResponse(secureMiddleware, headerMiddleware), postConfig, nil +} + +func (s *Server) buildServerEntryPointMiddlewares(serverEntryPointName string, serverEntryPoint *serverEntryPoint) ([]negroni.Handler, error) { + serverMiddlewares := []negroni.Handler{middlewares.NegroniRecoverHandler()} + + if s.tracingMiddleware.IsEnabled() { + serverMiddlewares = append(serverMiddlewares, s.tracingMiddleware.NewEntryPoint(serverEntryPointName)) + } + + if s.accessLoggerMiddleware != nil { + serverMiddlewares = append(serverMiddlewares, s.accessLoggerMiddleware) + } + + if s.metricsRegistry.IsEnabled() { + serverMiddlewares = append(serverMiddlewares, middlewares.NewEntryPointMetricsMiddleware(s.metricsRegistry, serverEntryPointName)) + } + + if s.globalConfiguration.API != nil { + if s.globalConfiguration.API.Stats == nil { + s.globalConfiguration.API.Stats = thoas_stats.New() + } + serverMiddlewares = append(serverMiddlewares, s.globalConfiguration.API.Stats) + if s.globalConfiguration.API.Statistics != nil { + if s.globalConfiguration.API.StatsRecorder == nil { + s.globalConfiguration.API.StatsRecorder = middlewares.NewStatsRecorder(s.globalConfiguration.API.Statistics.RecentErrors) + } + serverMiddlewares = append(serverMiddlewares, s.globalConfiguration.API.StatsRecorder) + } + } + + if s.entryPoints[serverEntryPointName].Configuration.Auth != nil { + authMiddleware, err := mauth.NewAuthenticator(s.entryPoints[serverEntryPointName].Configuration.Auth, s.tracingMiddleware) + if err != nil { + return nil, fmt.Errorf("failed to create authentication middleware: %v", err) + } + serverMiddlewares = append(serverMiddlewares, s.wrapNegroniHandlerWithAccessLog(authMiddleware, fmt.Sprintf("Auth for entrypoint %s", serverEntryPointName))) + } + + if s.entryPoints[serverEntryPointName].Configuration.Compress { + serverMiddlewares = append(serverMiddlewares, &middlewares.Compress{}) + } + + ipWhitelistMiddleware, err := buildIPWhiteLister( + s.entryPoints[serverEntryPointName].Configuration.WhiteList, + s.entryPoints[serverEntryPointName].Configuration.WhitelistSourceRange) + if err != nil { + return nil, fmt.Errorf("failed to create ip whitelist middleware: %v", err) + } + if ipWhitelistMiddleware != nil { + serverMiddlewares = append(serverMiddlewares, s.wrapNegroniHandlerWithAccessLog(ipWhitelistMiddleware, fmt.Sprintf("ipwhitelister for entrypoint %s", serverEntryPointName))) + } + + return serverMiddlewares, nil +} + +func errorPagesPostConfig(epHandlers []*errorpages.Handler) handlerPostConfig { + return func(backendsHandlers map[string]http.Handler) error { + for _, errorPageHandler := range epHandlers { + if handler, ok := backendsHandlers[errorPageHandler.BackendName]; ok { + err := errorPageHandler.PostLoad(handler) + if err != nil { + return fmt.Errorf("failed to configure error pages for backend %s: %v", errorPageHandler.BackendName, err) + } + } else { + err := errorPageHandler.PostLoad(nil) + if err != nil { + return fmt.Errorf("failed to configure error pages for %s: %v", errorPageHandler.FallbackURL, err) + } + } + } + return nil + } +} + +func buildErrorPagesMiddleware(frontendName string, frontend *types.Frontend, backends map[string]*types.Backend, entryPointName string, providerName string) ([]*errorpages.Handler, error) { + var errorPageHandlers []*errorpages.Handler + + for errorPageName, errorPage := range frontend.Errors { + if frontend.Backend == errorPage.Backend { + log.Errorf("Error when creating error page %q for frontend %q: error pages backend %q is the same as backend for the frontend (infinite call risk).", + errorPageName, frontendName, errorPage.Backend) + } else if backends[errorPage.Backend] == nil { + log.Errorf("Error when creating error page %q for frontend %q: the backend %q doesn't exist.", + errorPageName, frontendName, errorPage.Backend) + } else { + errorPagesHandler, err := errorpages.NewHandler(errorPage, entryPointName+providerName+errorPage.Backend) + if err != nil { + return nil, fmt.Errorf("error creating error pages: %v", err) + } + + if errorPageServer, ok := backends[errorPage.Backend].Servers["error"]; ok { + errorPagesHandler.FallbackURL = errorPageServer.URL + } + + errorPageHandlers = append(errorPageHandlers, errorPagesHandler) + } + } + + return errorPageHandlers, nil +} + +func (s *Server) buildBasicAuthMiddleware(authData []string) (*mauth.Authenticator, error) { + users := types.Users{} + for _, user := range authData { + users = append(users, user) + } + + auth := &types.Auth{} + auth.Basic = &types.Basic{ + Users: users, + } + + authMiddleware, err := mauth.NewAuthenticator(auth, s.tracingMiddleware) + if err != nil { + return nil, fmt.Errorf("error creating Basic Auth: %v", err) + } + + return authMiddleware, nil +} + +func (s *Server) buildEntryPointRedirect() (map[string]negroni.Handler, error) { + redirectHandlers := map[string]negroni.Handler{} + + for entryPointName, ep := range s.entryPoints { + entryPoint := ep.Configuration + + if entryPoint.Redirect != nil && entryPointName != entryPoint.Redirect.EntryPoint { + handler, err := s.buildRedirectHandler(entryPointName, entryPoint.Redirect) + if err != nil { + return nil, fmt.Errorf("error loading configuration for entrypoint %s: %v", entryPointName, err) + } + + handlerToUse := s.wrapNegroniHandlerWithAccessLog(handler, fmt.Sprintf("entrypoint redirect for %s", entryPointName)) + redirectHandlers[entryPointName] = handlerToUse + } + } + + return redirectHandlers, nil +} + +func (s *Server) buildRedirectHandler(srcEntryPointName string, opt *types.Redirect) (negroni.Handler, error) { + // entry point redirect + if len(opt.EntryPoint) > 0 { + entryPoint := s.entryPoints[opt.EntryPoint].Configuration + if entryPoint == nil { + return nil, fmt.Errorf("unknown target entrypoint %q", srcEntryPointName) + } + log.Debugf("Creating entry point redirect %s -> %s", srcEntryPointName, opt.EntryPoint) + return redirect.NewEntryPointHandler(entryPoint, opt.Permanent) + } + + // regex redirect + redirection, err := redirect.NewRegexHandler(opt.Regex, opt.Replacement, opt.Permanent) + if err != nil { + return nil, err + } + log.Debugf("Creating regex redirect %s -> %s -> %s", srcEntryPointName, opt.Regex, opt.Replacement) + + return redirection, nil +} + +func buildIPWhiteLister(whiteList *types.WhiteList, wlRange []string) (*middlewares.IPWhiteLister, error) { + if whiteList != nil && + len(whiteList.SourceRange) > 0 { + return middlewares.NewIPWhiteLister(whiteList.SourceRange, whiteList.UseXForwardedFor) + } else if len(wlRange) > 0 { + return middlewares.NewIPWhiteLister(wlRange, false) + } + return nil, nil +} + +func (s *Server) wrapNegroniHandlerWithAccessLog(handler negroni.Handler, frontendName string) negroni.Handler { + if s.accessLoggerMiddleware != nil { + saveBackend := accesslog.NewSaveNegroniBackend(handler, "Træfik") + saveFrontend := accesslog.NewSaveNegroniFrontend(saveBackend, frontendName) + return saveFrontend + } + return handler +} + +func (s *Server) wrapHTTPHandlerWithAccessLog(handler http.Handler, frontendName string) http.Handler { + if s.accessLoggerMiddleware != nil { + saveBackend := accesslog.NewSaveBackend(handler, "Træfik") + saveFrontend := accesslog.NewSaveFrontend(saveBackend, frontendName) + return saveFrontend + } + return handler +} + +func buildModifyResponse(secure *secure.Secure, header *middlewares.HeaderStruct) func(res *http.Response) error { + return func(res *http.Response) error { + if secure != nil { + if err := secure.ModifyResponseHeaders(res); err != nil { + return err + } + } + + if header != nil { + if err := header.ModifyResponseHeaders(res); err != nil { + return err + } + } + return nil + } +} diff --git a/server/server_middlewares_test.go b/server/server_middlewares_test.go new file mode 100644 index 000000000..8b9651268 --- /dev/null +++ b/server/server_middlewares_test.go @@ -0,0 +1,253 @@ +package server + +import ( + "net/http" + "net/http/httptest" + "reflect" + "testing" + + "github.com/containous/mux" + "github.com/containous/traefik/configuration" + "github.com/containous/traefik/metrics" + "github.com/containous/traefik/middlewares" + th "github.com/containous/traefik/testhelpers" + "github.com/containous/traefik/tls" + "github.com/containous/traefik/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/urfave/negroni" +) + +func TestServerEntryPointWhitelistConfig(t *testing.T) { + testCases := []struct { + desc string + entrypoint *configuration.EntryPoint + expectMiddleware bool + }{ + { + desc: "no whitelist middleware if no config on entrypoint", + entrypoint: &configuration.EntryPoint{ + Address: ":0", + ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true}, + }, + expectMiddleware: false, + }, + { + desc: "whitelist middleware should be added if configured on entrypoint", + entrypoint: &configuration.EntryPoint{ + Address: ":0", + WhitelistSourceRange: []string{ + "127.0.0.1/32", + }, + ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true}, + }, + expectMiddleware: true, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + srv := Server{ + globalConfiguration: configuration.GlobalConfiguration{}, + metricsRegistry: metrics.NewVoidRegistry(), + entryPoints: map[string]EntryPoint{ + "test": { + Configuration: test.entrypoint, + }, + }, + } + + srv.serverEntryPoints = srv.buildServerEntryPoints() + srvEntryPoint := srv.setupServerEntryPoint("test", srv.serverEntryPoints["test"]) + handler := srvEntryPoint.httpServer.Handler.(*mux.Router).NotFoundHandler.(*negroni.Negroni) + + found := false + for _, handler := range handler.Handlers() { + if reflect.TypeOf(handler) == reflect.TypeOf((*middlewares.IPWhiteLister)(nil)) { + found = true + } + } + + if found && !test.expectMiddleware { + t.Error("ip whitelist middleware was installed even though it should not") + } + + if !found && test.expectMiddleware { + t.Error("ip whitelist middleware was not installed even though it should have") + } + }) + } +} + +func TestBuildIPWhiteLister(t *testing.T) { + testCases := []struct { + desc string + whitelistSourceRange []string + whiteList *types.WhiteList + middlewareConfigured bool + errMessage string + }{ + { + desc: "no whitelists configured", + whitelistSourceRange: nil, + middlewareConfigured: false, + errMessage: "", + }, + { + desc: "whitelists configured (deprecated)", + whitelistSourceRange: []string{ + "1.2.3.4/24", + "fe80::/16", + }, + middlewareConfigured: true, + errMessage: "", + }, + { + desc: "invalid whitelists configured (deprecated)", + whitelistSourceRange: []string{ + "foo", + }, + middlewareConfigured: false, + errMessage: "parsing CIDR whitelist [foo]: parsing CIDR white list : invalid CIDR address: foo", + }, + { + desc: "whitelists configured", + whiteList: &types.WhiteList{ + SourceRange: []string{ + "1.2.3.4/24", + "fe80::/16", + }, + UseXForwardedFor: false, + }, + middlewareConfigured: true, + errMessage: "", + }, + { + desc: "invalid whitelists configured (deprecated)", + whiteList: &types.WhiteList{ + SourceRange: []string{ + "foo", + }, + UseXForwardedFor: false, + }, + middlewareConfigured: false, + errMessage: "parsing CIDR whitelist [foo]: parsing CIDR white list : invalid CIDR address: foo", + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + middleware, err := buildIPWhiteLister(test.whiteList, test.whitelistSourceRange) + + if test.errMessage != "" { + require.EqualError(t, err, test.errMessage) + } else { + assert.NoError(t, err) + + if test.middlewareConfigured { + require.NotNil(t, middleware, "not expected middleware to be configured") + } else { + require.Nil(t, middleware, "expected middleware to be configured") + } + } + }) + } +} + +func TestBuildRedirectHandler(t *testing.T) { + srv := Server{ + globalConfiguration: configuration.GlobalConfiguration{}, + entryPoints: map[string]EntryPoint{ + "http": {Configuration: &configuration.EntryPoint{Address: ":80"}}, + "https": {Configuration: &configuration.EntryPoint{Address: ":443", TLS: &tls.TLS{}}}, + }, + } + + testCases := []struct { + desc string + srcEntryPointName string + url string + entryPoint *configuration.EntryPoint + redirect *types.Redirect + expectedURL string + }{ + { + desc: "redirect regex", + srcEntryPointName: "http", + url: "http://foo.com", + redirect: &types.Redirect{ + Regex: `^(?:http?:\/\/)(foo)(\.com)$`, + Replacement: "https://$1{{\"bar\"}}$2", + }, + entryPoint: &configuration.EntryPoint{ + Address: ":80", + Redirect: &types.Redirect{ + Regex: `^(?:http?:\/\/)(foo)(\.com)$`, + Replacement: "https://$1{{\"bar\"}}$2", + }, + }, + expectedURL: "https://foobar.com", + }, + { + desc: "redirect entry point", + srcEntryPointName: "http", + url: "http://foo:80", + redirect: &types.Redirect{ + EntryPoint: "https", + }, + entryPoint: &configuration.EntryPoint{ + Address: ":80", + Redirect: &types.Redirect{ + EntryPoint: "https", + }, + }, + expectedURL: "https://foo:443", + }, + { + desc: "redirect entry point with regex (ignored)", + srcEntryPointName: "http", + url: "http://foo.com:80", + redirect: &types.Redirect{ + EntryPoint: "https", + Regex: `^(?:http?:\/\/)(foo)(\.com)$`, + Replacement: "https://$1{{\"bar\"}}$2", + }, + entryPoint: &configuration.EntryPoint{ + Address: ":80", + Redirect: &types.Redirect{ + EntryPoint: "https", + Regex: `^(?:http?:\/\/)(foo)(\.com)$`, + Replacement: "https://$1{{\"bar\"}}$2", + }, + }, + expectedURL: "https://foo.com:443", + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + rewrite, err := srv.buildRedirectHandler(test.srcEntryPointName, test.redirect) + require.NoError(t, err) + + req := th.MustNewRequest(http.MethodGet, test.url, nil) + recorder := httptest.NewRecorder() + + rewrite.ServeHTTP(recorder, req, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Location", "fail") + })) + + location, err := recorder.Result().Location() + require.NoError(t, err) + assert.Equal(t, test.expectedURL, location.String()) + }) + } +} diff --git a/server/server_signals.go b/server/server_signals.go index d024a627a..fb8514c9a 100644 --- a/server/server_signals.go +++ b/server/server_signals.go @@ -22,12 +22,12 @@ func (s *Server) listenSignals() { if s.accessLoggerMiddleware != nil { if err := s.accessLoggerMiddleware.Rotate(); err != nil { - log.Errorf("Error rotating access log: %s", err) + log.Errorf("Error rotating access log: %v", err) } } if err := log.RotateFile(); err != nil { - log.Errorf("Error rotating traefik log: %s", err) + log.Errorf("Error rotating traefik log: %v", err) } } } diff --git a/server/server_test.go b/server/server_test.go index 4353687ad..f1572002f 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -2,83 +2,22 @@ package server import ( "context" - "fmt" "net/http" "net/http/httptest" - "net/url" - "reflect" "testing" "time" "github.com/containous/flaeg" "github.com/containous/mux" "github.com/containous/traefik/configuration" - "github.com/containous/traefik/healthcheck" - "github.com/containous/traefik/metrics" "github.com/containous/traefik/middlewares" - "github.com/containous/traefik/rules" th "github.com/containous/traefik/testhelpers" - "github.com/containous/traefik/tls" "github.com/containous/traefik/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/unrolled/secure" - "github.com/urfave/negroni" - "github.com/vulcand/oxy/roundrobin" ) -// LocalhostCert is a PEM-encoded TLS cert with SAN IPs -// "127.0.0.1" and "[::1]", expiring at Jan 29 16:00:00 2084 GMT. -// generated from src/crypto/tls: -// go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h -var ( - localhostCert = tls.FileOrContent(`-----BEGIN CERTIFICATE----- -MIICEzCCAXygAwIBAgIQMIMChMLGrR+QvmQvpwAU6zANBgkqhkiG9w0BAQsFADAS -MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw -MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB -iQKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9SjY1bIw4 -iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZBl2+XsDul -rKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQABo2gwZjAO -BgNVHQ8BAf8EBAMCAqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUw -AwEB/zAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAAAAAAAAAA -AAAAATANBgkqhkiG9w0BAQsFAAOBgQCEcetwO59EWk7WiJsG4x8SY+UIAA+flUI9 -tyC4lNhbcF2Idq9greZwbYCqTTTr2XiRNSMLCOjKyI7ukPoPjo16ocHj+P3vZGfs -h1fIw3cSS2OolhloGw/XM6RWPWtPAlGykKLciQrBru5NAPvCMsb/I1DAceTiotQM -fblo6RBxUQ== ------END CERTIFICATE-----`) - - // LocalhostKey is the private key for localhostCert. - localhostKey = tls.FileOrContent(`-----BEGIN RSA PRIVATE KEY----- -MIICXgIBAAKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9 -SjY1bIw4iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZB -l2+XsDulrKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQAB -AoGAGRzwwir7XvBOAy5tM/uV6e+Zf6anZzus1s1Y1ClbjbE6HXbnWWF/wbZGOpet -3Zm4vD6MXc7jpTLryzTQIvVdfQbRc6+MUVeLKwZatTXtdZrhu+Jk7hx0nTPy8Jcb -uJqFk541aEw+mMogY/xEcfbWd6IOkp+4xqjlFLBEDytgbIECQQDvH/E6nk+hgN4H -qzzVtxxr397vWrjrIgPbJpQvBsafG7b0dA4AFjwVbFLmQcj2PprIMmPcQrooz8vp -jy4SHEg1AkEA/v13/5M47K9vCxmb8QeD/asydfsgS5TeuNi8DoUBEmiSJwma7FXY -fFUtxuvL7XvjwjN5B30pNEbc6Iuyt7y4MQJBAIt21su4b3sjXNueLKH85Q+phy2U -fQtuUE9txblTu14q3N7gHRZB4ZMhFYyDy8CKrN2cPg/Fvyt0Xlp/DoCzjA0CQQDU -y2ptGsuSmgUtWj3NM9xuwYPm+Z/F84K6+ARYiZ6PYj013sovGKUFfYAqVXVlxtIX -qyUBnu3X9ps8ZfjLZO7BAkEAlT4R5Yl6cGhaJQYZHOde3JEMhNRcVFMO8dJDaFeo -f9Oeos0UUothgiDktdQHxdNEwLjQf7lJJBzV+5OtwswCWA== ------END RSA PRIVATE KEY-----`) -) - -type testLoadBalancer struct{} - -func (lb *testLoadBalancer) RemoveServer(u *url.URL) error { - return nil -} - -func (lb *testLoadBalancer) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error { - return nil -} - -func (lb *testLoadBalancer) Servers() []*url.URL { - return []*url.URL{} -} - func TestPrepareServerTimeouts(t *testing.T) { testCases := []struct { desc string @@ -282,559 +221,6 @@ func setupListenProvider(throttleDuration time.Duration) (server *Server, stop c return server, stop, invokeStopChan } -func TestThrottleProviderConfigReload(t *testing.T) { - throttleDuration := 30 * time.Millisecond - publishConfig := make(chan types.ConfigMessage) - providerConfig := make(chan types.ConfigMessage) - stop := make(chan bool) - defer func() { - stop <- true - }() - - globalConfig := configuration.GlobalConfiguration{} - server := NewServer(globalConfig, nil, nil) - - go server.throttleProviderConfigReload(throttleDuration, publishConfig, providerConfig, stop) - - publishedConfigCount := 0 - stopConsumeConfigs := make(chan bool) - go func() { - for { - select { - case <-stop: - return - case <-stopConsumeConfigs: - return - case <-publishConfig: - publishedConfigCount++ - } - } - }() - - // publish 5 new configs, one new config each 10 milliseconds - for i := 0; i < 5; i++ { - providerConfig <- types.ConfigMessage{} - time.Sleep(10 * time.Millisecond) - } - - // after 50 milliseconds 5 new configs were published - // with a throttle duration of 30 milliseconds this means, we should have received 2 new configs - assert.Equal(t, 2, publishedConfigCount, "times configs were published") - - stopConsumeConfigs <- true - - select { - case <-publishConfig: - // There should be exactly one more message that we receive after ~60 milliseconds since the start of the test. - select { - case <-publishConfig: - t.Error("extra config publication found") - case <-time.After(100 * time.Millisecond): - return - } - case <-time.After(100 * time.Millisecond): - t.Error("Last config was not published in time") - } -} - -func TestServerMultipleFrontendRules(t *testing.T) { - testCases := []struct { - expression string - requestURL string - expectedURL string - }{ - { - expression: "Host:foo.bar", - requestURL: "http://foo.bar", - expectedURL: "http://foo.bar", - }, - { - expression: "PathPrefix:/management;ReplacePath:/health", - requestURL: "http://foo.bar/management", - expectedURL: "http://foo.bar/health", - }, - { - expression: "Host:foo.bar;AddPrefix:/blah", - requestURL: "http://foo.bar/baz", - expectedURL: "http://foo.bar/blah/baz", - }, - { - expression: "PathPrefixStripRegex:/one/{two}/{three:[0-9]+}", - requestURL: "http://foo.bar/one/some/12345/four", - expectedURL: "http://foo.bar/four", - }, - { - expression: "PathPrefixStripRegex:/one/{two}/{three:[0-9]+};AddPrefix:/zero", - requestURL: "http://foo.bar/one/some/12345/four", - expectedURL: "http://foo.bar/zero/four", - }, - { - expression: "AddPrefix:/blah;ReplacePath:/baz", - requestURL: "http://foo.bar/hello", - expectedURL: "http://foo.bar/baz", - }, - { - expression: "PathPrefixStrip:/management;ReplacePath:/health", - requestURL: "http://foo.bar/management", - expectedURL: "http://foo.bar/health", - }, - } - - for _, test := range testCases { - test := test - t.Run(test.expression, func(t *testing.T) { - t.Parallel() - - router := mux.NewRouter() - route := router.NewRoute() - serverRoute := &types.ServerRoute{Route: route} - rules := &rules.Rules{Route: serverRoute} - - expression := test.expression - routeResult, err := rules.Parse(expression) - - if err != nil { - t.Fatalf("Error while building route for %s: %+v", expression, err) - } - - request := th.MustNewRequest(http.MethodGet, test.requestURL, nil) - routeMatch := routeResult.Match(request, &mux.RouteMatch{Route: routeResult}) - - if !routeMatch { - t.Fatalf("Rule %s doesn't match", expression) - } - - server := new(Server) - - server.wireFrontendBackend(serverRoute, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, test.expectedURL, r.URL.String(), "URL") - })) - serverRoute.Route.GetHandler().ServeHTTP(nil, request) - }) - } -} - -func TestServerLoadConfigHealthCheckOptions(t *testing.T) { - healthChecks := []*types.HealthCheck{ - nil, - { - Path: "/path", - }, - } - - for _, lbMethod := range []string{"Wrr", "Drr"} { - for _, healthCheck := range healthChecks { - t.Run(fmt.Sprintf("%s/hc=%t", lbMethod, healthCheck != nil), func(t *testing.T) { - globalConfig := configuration.GlobalConfiguration{ - HealthCheck: &configuration.HealthCheckConfig{Interval: flaeg.Duration(5 * time.Second)}, - } - entryPoints := map[string]EntryPoint{ - "http": { - Configuration: &configuration.EntryPoint{ - ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true}, - }, - }, - } - - dynamicConfigs := types.Configurations{ - "config": &types.Configuration{ - Frontends: map[string]*types.Frontend{ - "frontend": { - EntryPoints: []string{"http"}, - Backend: "backend", - }, - }, - Backends: map[string]*types.Backend{ - "backend": { - Servers: map[string]types.Server{ - "server": { - URL: "http://localhost", - }, - }, - LoadBalancer: &types.LoadBalancer{ - Method: lbMethod, - }, - HealthCheck: healthCheck, - }, - }, - TLS: []*tls.Configuration{ - { - Certificate: &tls.Certificate{ - CertFile: localhostCert, - KeyFile: localhostKey, - }, - EntryPoints: []string{"http"}, - }, - }, - }, - } - - srv := NewServer(globalConfig, nil, entryPoints) - - _, err := srv.loadConfig(dynamicConfigs, globalConfig) - require.NoError(t, err) - - expectedNumHealthCheckBackends := 0 - if healthCheck != nil { - expectedNumHealthCheckBackends = 1 - } - assert.Len(t, healthcheck.GetHealthCheck(th.NewCollectingHealthCheckMetrics()).Backends, expectedNumHealthCheckBackends, "health check backends") - }) - } - } -} - -func TestServerParseHealthCheckOptions(t *testing.T) { - lb := &testLoadBalancer{} - globalInterval := 15 * time.Second - - testCases := []struct { - desc string - hc *types.HealthCheck - expectedOpts *healthcheck.Options - }{ - { - desc: "nil health check", - hc: nil, - expectedOpts: nil, - }, - { - desc: "empty path", - hc: &types.HealthCheck{ - Path: "", - }, - expectedOpts: nil, - }, - { - desc: "unparseable interval", - hc: &types.HealthCheck{ - Path: "/path", - Interval: "unparseable", - }, - expectedOpts: &healthcheck.Options{ - Path: "/path", - Interval: globalInterval, - LB: lb, - }, - }, - { - desc: "sub-zero interval", - hc: &types.HealthCheck{ - Path: "/path", - Interval: "-42s", - }, - expectedOpts: &healthcheck.Options{ - Path: "/path", - Interval: globalInterval, - LB: lb, - }, - }, - { - desc: "parseable interval", - hc: &types.HealthCheck{ - Path: "/path", - Interval: "5m", - }, - expectedOpts: &healthcheck.Options{ - Path: "/path", - Interval: 5 * time.Minute, - LB: lb, - }, - }, - } - - for _, test := range testCases { - test := test - t.Run(test.desc, func(t *testing.T) { - t.Parallel() - - opts := parseHealthCheckOptions(lb, "backend", test.hc, &configuration.HealthCheckConfig{Interval: flaeg.Duration(globalInterval)}) - assert.Equal(t, test.expectedOpts, opts, "health check options") - }) - } -} - -func TestBuildIPWhiteLister(t *testing.T) { - testCases := []struct { - desc string - whitelistSourceRange []string - whiteList *types.WhiteList - middlewareConfigured bool - errMessage string - }{ - { - desc: "no whitelists configured", - whitelistSourceRange: nil, - middlewareConfigured: false, - errMessage: "", - }, - { - desc: "whitelists configured (deprecated)", - whitelistSourceRange: []string{ - "1.2.3.4/24", - "fe80::/16", - }, - middlewareConfigured: true, - errMessage: "", - }, - { - desc: "invalid whitelists configured (deprecated)", - whitelistSourceRange: []string{ - "foo", - }, - middlewareConfigured: false, - errMessage: "parsing CIDR whitelist [foo]: parsing CIDR white list : invalid CIDR address: foo", - }, - { - desc: "whitelists configured", - whiteList: &types.WhiteList{ - SourceRange: []string{ - "1.2.3.4/24", - "fe80::/16", - }, - UseXForwardedFor: false, - }, - middlewareConfigured: true, - errMessage: "", - }, - { - desc: "invalid whitelists configured (deprecated)", - whiteList: &types.WhiteList{ - SourceRange: []string{ - "foo", - }, - UseXForwardedFor: false, - }, - middlewareConfigured: false, - errMessage: "parsing CIDR whitelist [foo]: parsing CIDR white list : invalid CIDR address: foo", - }, - } - - for _, test := range testCases { - test := test - t.Run(test.desc, func(t *testing.T) { - t.Parallel() - - middleware, err := buildIPWhiteLister(test.whiteList, test.whitelistSourceRange) - - if test.errMessage != "" { - require.EqualError(t, err, test.errMessage) - } else { - assert.NoError(t, err) - - if test.middlewareConfigured { - require.NotNil(t, middleware, "not expected middleware to be configured") - } else { - require.Nil(t, middleware, "expected middleware to be configured") - } - } - }) - } -} - -func TestServerLoadConfigEmptyBasicAuth(t *testing.T) { - globalConfig := configuration.GlobalConfiguration{ - EntryPoints: configuration.EntryPoints{ - "http": &configuration.EntryPoint{ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true}}, - }, - } - - dynamicConfigs := types.Configurations{ - "config": &types.Configuration{ - Frontends: map[string]*types.Frontend{ - "frontend": { - EntryPoints: []string{"http"}, - Backend: "backend", - BasicAuth: []string{""}, - }, - }, - Backends: map[string]*types.Backend{ - "backend": { - Servers: map[string]types.Server{ - "server": { - URL: "http://localhost", - }, - }, - LoadBalancer: &types.LoadBalancer{ - Method: "Wrr", - }, - }, - }, - }, - } - - srv := NewServer(globalConfig, nil, nil) - _, err := srv.loadConfig(dynamicConfigs, globalConfig) - require.NoError(t, err) -} - -func TestServerLoadCertificateWithDefaultEntryPoint(t *testing.T) { - globalConfig := configuration.GlobalConfiguration{ - DefaultEntryPoints: []string{"http", "https"}, - } - entryPoints := map[string]EntryPoint{ - "https": {Configuration: &configuration.EntryPoint{TLS: &tls.TLS{}}}, - "http": {Configuration: &configuration.EntryPoint{}}, - } - - dynamicConfigs := types.Configurations{ - "config": &types.Configuration{ - TLS: []*tls.Configuration{ - { - Certificate: &tls.Certificate{ - CertFile: localhostCert, - KeyFile: localhostKey, - }, - }, - }, - }, - } - - srv := NewServer(globalConfig, nil, entryPoints) - if mapEntryPoints, err := srv.loadConfig(dynamicConfigs, globalConfig); err != nil { - t.Fatalf("got error: %s", err) - } else if mapEntryPoints["https"].certs.Get() == nil { - t.Fatal("got error: https entryPoint must have TLS certificates.") - } -} - -func TestConfigureBackends(t *testing.T) { - validMethod := "Drr" - defaultMethod := "wrr" - - testCases := []struct { - desc string - lb *types.LoadBalancer - expectedMethod string - expectedStickiness *types.Stickiness - }{ - { - desc: "valid load balancer method with sticky enabled", - lb: &types.LoadBalancer{ - Method: validMethod, - Stickiness: &types.Stickiness{}, - }, - expectedMethod: validMethod, - expectedStickiness: &types.Stickiness{}, - }, - { - desc: "valid load balancer method with sticky disabled", - lb: &types.LoadBalancer{ - Method: validMethod, - Stickiness: nil, - }, - expectedMethod: validMethod, - }, - { - desc: "invalid load balancer method with sticky enabled", - lb: &types.LoadBalancer{ - Method: "Invalid", - Stickiness: &types.Stickiness{}, - }, - expectedMethod: defaultMethod, - expectedStickiness: &types.Stickiness{}, - }, - { - desc: "invalid load balancer method with sticky disabled", - lb: &types.LoadBalancer{ - Method: "Invalid", - Stickiness: nil, - }, - expectedMethod: defaultMethod, - }, - { - desc: "missing load balancer", - lb: nil, - expectedMethod: defaultMethod, - }, - } - - for _, test := range testCases { - test := test - t.Run(test.desc, func(t *testing.T) { - t.Parallel() - backend := &types.Backend{ - LoadBalancer: test.lb, - } - - configureBackends(map[string]*types.Backend{ - "backend": backend, - }) - - expected := types.LoadBalancer{ - Method: test.expectedMethod, - Stickiness: test.expectedStickiness, - } - - assert.Equal(t, expected, *backend.LoadBalancer) - }) - } -} - -func TestServerEntryPointWhitelistConfig(t *testing.T) { - testCases := []struct { - desc string - entrypoint *configuration.EntryPoint - expectMiddleware bool - }{ - { - desc: "no whitelist middleware if no config on entrypoint", - entrypoint: &configuration.EntryPoint{ - Address: ":0", - ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true}, - }, - expectMiddleware: false, - }, - { - desc: "whitelist middleware should be added if configured on entrypoint", - entrypoint: &configuration.EntryPoint{ - Address: ":0", - WhitelistSourceRange: []string{ - "127.0.0.1/32", - }, - ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true}, - }, - expectMiddleware: true, - }, - } - - for _, test := range testCases { - test := test - t.Run(test.desc, func(t *testing.T) { - t.Parallel() - - srv := Server{ - globalConfiguration: configuration.GlobalConfiguration{}, - metricsRegistry: metrics.NewVoidRegistry(), - entryPoints: map[string]EntryPoint{ - "test": { - Configuration: test.entrypoint, - }, - }, - } - - srv.serverEntryPoints = srv.buildEntryPoints() - srvEntryPoint := srv.setupServerEntryPoint("test", srv.serverEntryPoints["test"]) - handler := srvEntryPoint.httpServer.Handler.(*mux.Router).NotFoundHandler.(*negroni.Negroni) - - found := false - for _, handler := range handler.Handlers() { - if reflect.TypeOf(handler) == reflect.TypeOf((*middlewares.IPWhiteLister)(nil)) { - found = true - } - } - - if found && !test.expectMiddleware { - t.Error("ip whitelist middleware was installed even though it should not") - } - - if !found && test.expectMiddleware { - t.Error("ip whitelist middleware was not installed even though it should have") - } - }) - } -} - func TestServerResponseEmptyBackend(t *testing.T) { const requestPath = "/path" const routeRule = "Path:" + requestPath @@ -962,157 +348,6 @@ func TestServerResponseEmptyBackend(t *testing.T) { } } -func TestReuseBackend(t *testing.T) { - testServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.WriteHeader(http.StatusOK) - })) - defer testServer.Close() - - globalConfig := configuration.GlobalConfiguration{ - DefaultEntryPoints: []string{"http"}, - } - - entryPoints := map[string]EntryPoint{ - "http": {Configuration: &configuration.EntryPoint{ - ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true}, - }}, - } - - dynamicConfigs := types.Configurations{ - "config": th.BuildConfiguration( - th.WithFrontends( - th.WithFrontend("backend", - th.WithFrontendName("frontend0"), - th.WithEntryPoints("http"), - th.WithRoutes(th.WithRoute("/ok", "Path: /ok"))), - th.WithFrontend("backend", - th.WithFrontendName("frontend1"), - th.WithEntryPoints("http"), - th.WithRoutes(th.WithRoute("/unauthorized", "Path: /unauthorized")), - th.WithBasicAuth("foo", "bar")), - ), - th.WithBackends(th.WithBackendNew("backend", - th.WithLBMethod("wrr"), - th.WithServersNew(th.WithServerNew(testServer.URL))), - ), - ), - } - - srv := NewServer(globalConfig, nil, entryPoints) - - serverEntryPoints, err := srv.loadConfig(dynamicConfigs, globalConfig) - if err != nil { - t.Fatalf("error loading config: %s", err) - } - - // Test that the /ok path returns a status 200. - responseRecorderOk := &httptest.ResponseRecorder{} - requestOk := httptest.NewRequest(http.MethodGet, testServer.URL+"/ok", nil) - serverEntryPoints["http"].httpRouter.ServeHTTP(responseRecorderOk, requestOk) - - assert.Equal(t, http.StatusOK, responseRecorderOk.Result().StatusCode, "status code") - - // Test that the /unauthorized path returns a 401 because of - // the basic authentication defined on the frontend. - responseRecorderUnauthorized := &httptest.ResponseRecorder{} - requestUnauthorized := httptest.NewRequest(http.MethodGet, testServer.URL+"/unauthorized", nil) - serverEntryPoints["http"].httpRouter.ServeHTTP(responseRecorderUnauthorized, requestUnauthorized) - - assert.Equal(t, http.StatusUnauthorized, responseRecorderUnauthorized.Result().StatusCode, "status code") -} - -func TestBuildRedirectHandler(t *testing.T) { - srv := Server{ - globalConfiguration: configuration.GlobalConfiguration{}, - entryPoints: map[string]EntryPoint{ - "http": {Configuration: &configuration.EntryPoint{Address: ":80"}}, - "https": {Configuration: &configuration.EntryPoint{Address: ":443", TLS: &tls.TLS{}}}, - }, - } - - testCases := []struct { - desc string - srcEntryPointName string - url string - entryPoint *configuration.EntryPoint - redirect *types.Redirect - expectedURL string - }{ - { - desc: "redirect regex", - srcEntryPointName: "http", - url: "http://foo.com", - redirect: &types.Redirect{ - Regex: `^(?:http?:\/\/)(foo)(\.com)$`, - Replacement: "https://$1{{\"bar\"}}$2", - }, - entryPoint: &configuration.EntryPoint{ - Address: ":80", - Redirect: &types.Redirect{ - Regex: `^(?:http?:\/\/)(foo)(\.com)$`, - Replacement: "https://$1{{\"bar\"}}$2", - }, - }, - expectedURL: "https://foobar.com", - }, - { - desc: "redirect entry point", - srcEntryPointName: "http", - url: "http://foo:80", - redirect: &types.Redirect{ - EntryPoint: "https", - }, - entryPoint: &configuration.EntryPoint{ - Address: ":80", - Redirect: &types.Redirect{ - EntryPoint: "https", - }, - }, - expectedURL: "https://foo:443", - }, - { - desc: "redirect entry point with regex (ignored)", - srcEntryPointName: "http", - url: "http://foo.com:80", - redirect: &types.Redirect{ - EntryPoint: "https", - Regex: `^(?:http?:\/\/)(foo)(\.com)$`, - Replacement: "https://$1{{\"bar\"}}$2", - }, - entryPoint: &configuration.EntryPoint{ - Address: ":80", - Redirect: &types.Redirect{ - EntryPoint: "https", - Regex: `^(?:http?:\/\/)(foo)(\.com)$`, - Replacement: "https://$1{{\"bar\"}}$2", - }, - }, - expectedURL: "https://foo.com:443", - }, - } - - for _, test := range testCases { - test := test - t.Run(test.desc, func(t *testing.T) { - t.Parallel() - - rewrite, err := srv.buildRedirectHandler(test.srcEntryPointName, test.redirect) - require.NoError(t, err) - - req := th.MustNewRequest(http.MethodGet, test.url, nil) - recorder := httptest.NewRecorder() - - rewrite.ServeHTTP(recorder, req, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Location", "fail") - })) - - location, err := recorder.Result().Location() - require.NoError(t, err) - assert.Equal(t, test.expectedURL, location.String()) - }) - } -} - type mockContext struct { headers http.Header }