Extreme Makeover: server refactoring
This commit is contained in:
parent
bddb4cc33c
commit
eac20d61df
19 changed files with 2356 additions and 1965 deletions
|
@ -19,12 +19,18 @@ import (
|
|||
var singleton *HealthCheck
|
||||
var once sync.Once
|
||||
|
||||
// GetHealthCheck returns the health check which is guaranteed to be a singleton.
|
||||
func GetHealthCheck(metrics metricsRegistry) *HealthCheck {
|
||||
once.Do(func() {
|
||||
singleton = newHealthCheck(metrics)
|
||||
})
|
||||
return singleton
|
||||
// BalancerHandler includes functionality for load-balancing management.
|
||||
type BalancerHandler interface {
|
||||
ServeHTTP(w http.ResponseWriter, req *http.Request)
|
||||
Servers() []*url.URL
|
||||
RemoveServer(u *url.URL) error
|
||||
UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error
|
||||
}
|
||||
|
||||
// metricsRegistry is a local interface in the health check package, exposing only the required metrics
|
||||
// necessary for the health check package. This makes it easier for the tests.
|
||||
type metricsRegistry interface {
|
||||
BackendServerUpGauge() metrics.Gauge
|
||||
}
|
||||
|
||||
// Options are the public health check options.
|
||||
|
@ -36,59 +42,59 @@ type Options struct {
|
|||
Port int
|
||||
Transport http.RoundTripper
|
||||
Interval time.Duration
|
||||
LB LoadBalancer
|
||||
LB BalancerHandler
|
||||
}
|
||||
|
||||
func (opt Options) String() string {
|
||||
return fmt.Sprintf("[Hostname: %s Headers: %v Path: %s Port: %d Interval: %s]", opt.Hostname, opt.Headers, opt.Path, opt.Port, opt.Interval)
|
||||
}
|
||||
|
||||
// BackendHealthCheck HealthCheck configuration for a backend
|
||||
type BackendHealthCheck struct {
|
||||
// BackendConfig HealthCheck configuration for a backend
|
||||
type BackendConfig struct {
|
||||
Options
|
||||
name string
|
||||
disabledURLs []*url.URL
|
||||
requestTimeout time.Duration
|
||||
}
|
||||
|
||||
func (b *BackendConfig) newRequest(serverURL *url.URL) (*http.Request, error) {
|
||||
u := &url.URL{}
|
||||
*u = *serverURL
|
||||
|
||||
if len(b.Scheme) > 0 {
|
||||
u.Scheme = b.Scheme
|
||||
}
|
||||
|
||||
if b.Port != 0 {
|
||||
u.Host = net.JoinHostPort(u.Hostname(), strconv.Itoa(b.Port))
|
||||
}
|
||||
|
||||
u.Path += b.Path
|
||||
|
||||
return http.NewRequest(http.MethodGet, u.String(), nil)
|
||||
}
|
||||
|
||||
// this function adds additional http headers and hostname to http.request
|
||||
func (b *BackendConfig) addHeadersAndHost(req *http.Request) *http.Request {
|
||||
if b.Options.Hostname != "" {
|
||||
req.Host = b.Options.Hostname
|
||||
}
|
||||
|
||||
for k, v := range b.Options.Headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
// HealthCheck struct
|
||||
type HealthCheck struct {
|
||||
Backends map[string]*BackendHealthCheck
|
||||
Backends map[string]*BackendConfig
|
||||
metrics metricsRegistry
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// LoadBalancer includes functionality for load-balancing management.
|
||||
type LoadBalancer interface {
|
||||
RemoveServer(u *url.URL) error
|
||||
UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error
|
||||
Servers() []*url.URL
|
||||
}
|
||||
|
||||
func newHealthCheck(metrics metricsRegistry) *HealthCheck {
|
||||
return &HealthCheck{
|
||||
Backends: make(map[string]*BackendHealthCheck),
|
||||
metrics: metrics,
|
||||
}
|
||||
}
|
||||
|
||||
// metricsRegistry is a local interface in the health check package, exposing only the required metrics
|
||||
// necessary for the health check package. This makes it easier for the tests.
|
||||
type metricsRegistry interface {
|
||||
BackendServerUpGauge() metrics.Gauge
|
||||
}
|
||||
|
||||
// NewBackendHealthCheck Instantiate a new BackendHealthCheck
|
||||
func NewBackendHealthCheck(options Options, backendName string) *BackendHealthCheck {
|
||||
return &BackendHealthCheck{
|
||||
Options: options,
|
||||
name: backendName,
|
||||
requestTimeout: 5 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// SetBackendsConfiguration set backends configuration
|
||||
func (hc *HealthCheck) SetBackendsConfiguration(parentCtx context.Context, backends map[string]*BackendHealthCheck) {
|
||||
func (hc *HealthCheck) SetBackendsConfiguration(parentCtx context.Context, backends map[string]*BackendConfig) {
|
||||
hc.Backends = backends
|
||||
if hc.cancel != nil {
|
||||
hc.cancel()
|
||||
|
@ -104,7 +110,7 @@ func (hc *HealthCheck) SetBackendsConfiguration(parentCtx context.Context, backe
|
|||
}
|
||||
}
|
||||
|
||||
func (hc *HealthCheck) execute(ctx context.Context, backend *BackendHealthCheck) {
|
||||
func (hc *HealthCheck) execute(ctx context.Context, backend *BackendConfig) {
|
||||
log.Debugf("Initial health check for backend: %q", backend.name)
|
||||
hc.checkBackend(backend)
|
||||
ticker := time.NewTicker(backend.Interval)
|
||||
|
@ -121,7 +127,7 @@ func (hc *HealthCheck) execute(ctx context.Context, backend *BackendHealthCheck)
|
|||
}
|
||||
}
|
||||
|
||||
func (hc *HealthCheck) checkBackend(backend *BackendHealthCheck) {
|
||||
func (hc *HealthCheck) checkBackend(backend *BackendConfig) {
|
||||
enabledURLs := backend.LB.Servers()
|
||||
var newDisabledURLs []*url.URL
|
||||
for _, url := range backend.disabledURLs {
|
||||
|
@ -152,38 +158,33 @@ func (hc *HealthCheck) checkBackend(backend *BackendHealthCheck) {
|
|||
}
|
||||
}
|
||||
|
||||
func (b *BackendHealthCheck) newRequest(serverURL *url.URL) (*http.Request, error) {
|
||||
u := &url.URL{}
|
||||
*u = *serverURL
|
||||
|
||||
if len(b.Scheme) > 0 {
|
||||
u.Scheme = b.Scheme
|
||||
}
|
||||
|
||||
if b.Port != 0 {
|
||||
u.Host = net.JoinHostPort(u.Hostname(), strconv.Itoa(b.Port))
|
||||
}
|
||||
|
||||
u.Path += b.Path
|
||||
|
||||
return http.NewRequest(http.MethodGet, u.String(), nil)
|
||||
// GetHealthCheck returns the health check which is guaranteed to be a singleton.
|
||||
func GetHealthCheck(metrics metricsRegistry) *HealthCheck {
|
||||
once.Do(func() {
|
||||
singleton = newHealthCheck(metrics)
|
||||
})
|
||||
return singleton
|
||||
}
|
||||
|
||||
// this function adds additional http headers and hostname to http.request
|
||||
func (b *BackendHealthCheck) addHeadersAndHost(req *http.Request) *http.Request {
|
||||
if b.Options.Hostname != "" {
|
||||
req.Host = b.Options.Hostname
|
||||
func newHealthCheck(metrics metricsRegistry) *HealthCheck {
|
||||
return &HealthCheck{
|
||||
Backends: make(map[string]*BackendConfig),
|
||||
metrics: metrics,
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range b.Options.Headers {
|
||||
req.Header.Set(k, v)
|
||||
// NewBackendConfig Instantiate a new BackendConfig
|
||||
func NewBackendConfig(options Options, backendName string) *BackendConfig {
|
||||
return &BackendConfig{
|
||||
Options: options,
|
||||
name: backendName,
|
||||
requestTimeout: 5 * time.Second,
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
// checkHealth returns a nil error in case it was successful and otherwise
|
||||
// a non-nil error with a meaningful description why the health check failed.
|
||||
func checkHealth(serverURL *url.URL, backend *BackendHealthCheck) error {
|
||||
func checkHealth(serverURL *url.URL, backend *BackendConfig) error {
|
||||
req, err := backend.newRequest(serverURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create HTTP request: %s", err)
|
||||
|
|
|
@ -102,7 +102,7 @@ func TestSetBackendsConfiguration(t *testing.T) {
|
|||
defer ts.Close()
|
||||
|
||||
lb := &testLoadBalancer{RWMutex: &sync.RWMutex{}}
|
||||
backend := NewBackendHealthCheck(Options{
|
||||
backend := NewBackendConfig(Options{
|
||||
Path: "/path",
|
||||
Interval: healthCheckInterval,
|
||||
LB: lb,
|
||||
|
@ -117,7 +117,7 @@ func TestSetBackendsConfiguration(t *testing.T) {
|
|||
|
||||
collectingMetrics := testhelpers.NewCollectingHealthCheckMetrics()
|
||||
check := HealthCheck{
|
||||
Backends: make(map[string]*BackendHealthCheck),
|
||||
Backends: make(map[string]*BackendConfig),
|
||||
metrics: collectingMetrics,
|
||||
}
|
||||
|
||||
|
@ -209,7 +209,7 @@ func TestNewRequest(t *testing.T) {
|
|||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend := NewBackendHealthCheck(test.options, "backendName")
|
||||
backend := NewBackendConfig(test.options, "backendName")
|
||||
|
||||
u, err := url.Parse(test.serverURL)
|
||||
require.NoError(t, err)
|
||||
|
@ -279,7 +279,7 @@ func TestAddHeadersAndHost(t *testing.T) {
|
|||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend := NewBackendHealthCheck(test.options, "backendName")
|
||||
backend := NewBackendConfig(test.options, "backendName")
|
||||
|
||||
u, err := url.Parse(test.serverURL)
|
||||
require.NoError(t, err)
|
||||
|
@ -305,6 +305,10 @@ type testLoadBalancer struct {
|
|||
servers []*url.URL
|
||||
}
|
||||
|
||||
func (lb *testLoadBalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
// noop
|
||||
}
|
||||
|
||||
func (lb *testLoadBalancer) RemoveServer(u *url.URL) error {
|
||||
lb.Lock()
|
||||
defer lb.Unlock()
|
||||
|
|
|
@ -102,7 +102,7 @@ func (s *AccessLogSuite) TestAccessLogAuthFrontend(c *check.C) {
|
|||
formatOnly: false,
|
||||
code: "401",
|
||||
user: "-",
|
||||
frontendName: "Auth for frontend-Host-frontend-auth-docker-local",
|
||||
frontendName: "Basic Auth for frontend-Host-frontend-auth-docker-local",
|
||||
backendURL: "/",
|
||||
},
|
||||
}
|
||||
|
@ -354,7 +354,7 @@ func (s *AccessLogSuite) TestAccessLogEntrypointRedirect(c *check.C) {
|
|||
formatOnly: false,
|
||||
code: "302",
|
||||
user: "-",
|
||||
frontendName: "entrypoint redirect for frontend-",
|
||||
frontendName: "entrypoint redirect for httpRedirect",
|
||||
backendURL: "/",
|
||||
},
|
||||
{
|
||||
|
|
|
@ -31,38 +31,43 @@ func NewAuthenticator(authConfig *types.Auth, tracingMiddleware *tracing.Tracing
|
|||
if authConfig == nil {
|
||||
return nil, fmt.Errorf("error creating Authenticator: auth is nil")
|
||||
}
|
||||
|
||||
var err error
|
||||
authenticator := Authenticator{}
|
||||
tracingAuthenticator := tracingAuthenticator{}
|
||||
authenticator := &Authenticator{}
|
||||
tracingAuth := tracingAuthenticator{}
|
||||
|
||||
if authConfig.Basic != nil {
|
||||
authenticator.users, err = parserBasicUsers(authConfig.Basic)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
basicAuth := goauth.NewBasicAuthenticator("traefik", authenticator.secretBasic)
|
||||
tracingAuthenticator.handler = createAuthBasicHandler(basicAuth, authConfig)
|
||||
tracingAuthenticator.name = "Auth Basic"
|
||||
tracingAuthenticator.clientSpanKind = false
|
||||
tracingAuth.handler = createAuthBasicHandler(basicAuth, authConfig)
|
||||
tracingAuth.name = "Auth Basic"
|
||||
tracingAuth.clientSpanKind = false
|
||||
} else if authConfig.Digest != nil {
|
||||
authenticator.users, err = parserDigestUsers(authConfig.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
digestAuth := goauth.NewDigestAuthenticator("traefik", authenticator.secretDigest)
|
||||
tracingAuthenticator.handler = createAuthDigestHandler(digestAuth, authConfig)
|
||||
tracingAuthenticator.name = "Auth Digest"
|
||||
tracingAuthenticator.clientSpanKind = false
|
||||
tracingAuth.handler = createAuthDigestHandler(digestAuth, authConfig)
|
||||
tracingAuth.name = "Auth Digest"
|
||||
tracingAuth.clientSpanKind = false
|
||||
} else if authConfig.Forward != nil {
|
||||
tracingAuthenticator.handler = createAuthForwardHandler(authConfig)
|
||||
tracingAuthenticator.name = "Auth Forward"
|
||||
tracingAuthenticator.clientSpanKind = true
|
||||
tracingAuth.handler = createAuthForwardHandler(authConfig)
|
||||
tracingAuth.name = "Auth Forward"
|
||||
tracingAuth.clientSpanKind = true
|
||||
}
|
||||
|
||||
if tracingMiddleware != nil {
|
||||
authenticator.handler = tracingMiddleware.NewNegroniHandlerWrapper(tracingAuthenticator.name, tracingAuthenticator.handler, tracingAuthenticator.clientSpanKind)
|
||||
authenticator.handler = tracingMiddleware.NewNegroniHandlerWrapper(tracingAuth.name, tracingAuth.handler, tracingAuth.clientSpanKind)
|
||||
} else {
|
||||
authenticator.handler = tracingAuthenticator.handler
|
||||
authenticator.handler = tracingAuth.handler
|
||||
}
|
||||
return &authenticator, nil
|
||||
return authenticator, nil
|
||||
}
|
||||
|
||||
func createAuthForwardHandler(authConfig *types.Auth) negroni.HandlerFunc {
|
||||
|
|
|
@ -23,14 +23,14 @@ func NewCircuitBreaker(next http.Handler, expression string, options ...cbreaker
|
|||
|
||||
// NewCircuitBreakerOptions returns a new CircuitBreakerOption
|
||||
func NewCircuitBreakerOptions(expression string) cbreaker.CircuitBreakerOption {
|
||||
return cbreaker.Fallback(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
tracing.LogEventf(r, "blocked by circuitbreaker (%q)", expression)
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
w.Write([]byte(http.StatusText(http.StatusServiceUnavailable)))
|
||||
}))
|
||||
return cbreaker.Fallback(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
tracing.LogEventf(r, "blocked by circuit-breaker (%q)", expression)
|
||||
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
w.Write([]byte(http.StatusText(http.StatusServiceUnavailable)))
|
||||
}))
|
||||
}
|
||||
|
||||
func (cb *CircuitBreaker) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
func (cb *CircuitBreaker) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
cb.circuitBreaker.ServeHTTP(rw, r)
|
||||
}
|
||||
|
|
|
@ -10,19 +10,18 @@ import (
|
|||
// has at least one active Server in respect to the healthchecks and if this
|
||||
// is not the case, it will stop the middleware chain and respond with 503.
|
||||
type EmptyBackendHandler struct {
|
||||
lb healthcheck.LoadBalancer
|
||||
next http.Handler
|
||||
next healthcheck.BalancerHandler
|
||||
}
|
||||
|
||||
// NewEmptyBackendHandler creates a new EmptyBackendHandler instance.
|
||||
func NewEmptyBackendHandler(lb healthcheck.LoadBalancer, next http.Handler) *EmptyBackendHandler {
|
||||
return &EmptyBackendHandler{lb: lb, next: next}
|
||||
func NewEmptyBackendHandler(lb healthcheck.BalancerHandler) *EmptyBackendHandler {
|
||||
return &EmptyBackendHandler{next: lb}
|
||||
}
|
||||
|
||||
// ServeHTTP responds with 503 when there is no active Server and otherwise
|
||||
// invokes the next handler in the middleware chain.
|
||||
func (h *EmptyBackendHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
if len(h.lb.Servers()) == 0 {
|
||||
if len(h.next.Servers()) == 0 {
|
||||
rw.WriteHeader(http.StatusServiceUnavailable)
|
||||
rw.Write([]byte(http.StatusText(http.StatusServiceUnavailable)))
|
||||
} else {
|
||||
|
|
|
@ -32,10 +32,7 @@ func TestEmptyBackendHandler(t *testing.T) {
|
|||
t.Run(fmt.Sprintf("amount servers %d", test.amountServer), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler := NewEmptyBackendHandler(&healthCheckLoadBalancer{test.amountServer}, nextHandler)
|
||||
handler := NewEmptyBackendHandler(&healthCheckLoadBalancer{test.amountServer})
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
|
@ -53,12 +50,8 @@ type healthCheckLoadBalancer struct {
|
|||
amountServer int
|
||||
}
|
||||
|
||||
func (lb *healthCheckLoadBalancer) RemoveServer(u *url.URL) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (lb *healthCheckLoadBalancer) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error {
|
||||
return nil
|
||||
func (lb *healthCheckLoadBalancer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (lb *healthCheckLoadBalancer) Servers() []*url.URL {
|
||||
|
@ -68,3 +61,23 @@ func (lb *healthCheckLoadBalancer) Servers() []*url.URL {
|
|||
}
|
||||
return servers
|
||||
}
|
||||
|
||||
func (lb *healthCheckLoadBalancer) RemoveServer(u *url.URL) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (lb *healthCheckLoadBalancer) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (lb *healthCheckLoadBalancer) ServerWeight(u *url.URL) (int, bool) {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func (lb *healthCheckLoadBalancer) NextServer() (*url.URL, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (lb *healthCheckLoadBalancer) Next() http.Handler {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -6,20 +6,6 @@ import (
|
|||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
// NegroniHandlerWrapper is used to wrap negroni handler middleware
|
||||
type NegroniHandlerWrapper struct {
|
||||
name string
|
||||
next negroni.Handler
|
||||
clientSpanKind bool
|
||||
}
|
||||
|
||||
// HTTPHandlerWrapper is used to wrap http handler middleware
|
||||
type HTTPHandlerWrapper struct {
|
||||
name string
|
||||
handler http.Handler
|
||||
clientSpanKind bool
|
||||
}
|
||||
|
||||
// NewNegroniHandlerWrapper return a negroni.Handler struct
|
||||
func (t *Tracing) NewNegroniHandlerWrapper(name string, handler negroni.Handler, clientSpanKind bool) negroni.Handler {
|
||||
if t.IsEnabled() && handler != nil {
|
||||
|
@ -44,6 +30,13 @@ func (t *Tracing) NewHTTPHandlerWrapper(name string, handler http.Handler, clien
|
|||
return handler
|
||||
}
|
||||
|
||||
// NegroniHandlerWrapper is used to wrap negroni handler middleware
|
||||
type NegroniHandlerWrapper struct {
|
||||
name string
|
||||
next negroni.Handler
|
||||
clientSpanKind bool
|
||||
}
|
||||
|
||||
func (t *NegroniHandlerWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
var finish func()
|
||||
_, r, finish = StartSpan(r, t.name, t.clientSpanKind)
|
||||
|
@ -54,6 +47,13 @@ func (t *NegroniHandlerWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Reques
|
|||
}
|
||||
}
|
||||
|
||||
// HTTPHandlerWrapper is used to wrap http handler middleware
|
||||
type HTTPHandlerWrapper struct {
|
||||
name string
|
||||
handler http.Handler
|
||||
clientSpanKind bool
|
||||
}
|
||||
|
||||
func (t *HTTPHandlerWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
var finish func()
|
||||
_, r, finish = StartSpan(r, t.name, t.clientSpanKind)
|
||||
|
|
|
@ -1,9 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func notFoundHandler(w http.ResponseWriter, r *http.Request) {
|
||||
http.NotFound(w, r)
|
||||
}
|
|
@ -2,7 +2,7 @@ package server
|
|||
|
||||
import "sync"
|
||||
|
||||
const bufferPoolSize int = 32 * 1024
|
||||
const bufferPoolSize = 32 * 1024
|
||||
|
||||
func newBufferPool() *bufferPool {
|
||||
return &bufferPool{
|
||||
|
|
1134
server/server.go
1134
server/server.go
File diff suppressed because it is too large
Load diff
581
server/server_configuration.go
Normal file
581
server/server_configuration.go
Normal file
|
@ -0,0 +1,581 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/containous/mux"
|
||||
"github.com/containous/traefik/configuration"
|
||||
"github.com/containous/traefik/healthcheck"
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/metrics"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/rules"
|
||||
"github.com/containous/traefik/safe"
|
||||
traefiktls "github.com/containous/traefik/tls"
|
||||
"github.com/containous/traefik/types"
|
||||
"github.com/eapache/channels"
|
||||
"github.com/urfave/negroni"
|
||||
"github.com/vulcand/oxy/forward"
|
||||
"github.com/vulcand/oxy/utils"
|
||||
)
|
||||
|
||||
// loadConfiguration manages dynamically frontends, backends and TLS configurations
|
||||
func (s *Server) loadConfiguration(configMsg types.ConfigMessage) {
|
||||
currentConfigurations := s.currentConfigurations.Get().(types.Configurations)
|
||||
|
||||
// Copy configurations to new map so we don't change current if LoadConfig fails
|
||||
newConfigurations := make(types.Configurations)
|
||||
for k, v := range currentConfigurations {
|
||||
newConfigurations[k] = v
|
||||
}
|
||||
newConfigurations[configMsg.ProviderName] = configMsg.Configuration
|
||||
|
||||
s.metricsRegistry.ConfigReloadsCounter().Add(1)
|
||||
|
||||
newServerEntryPoints, err := s.loadConfig(newConfigurations, s.globalConfiguration)
|
||||
if err != nil {
|
||||
s.metricsRegistry.ConfigReloadsFailureCounter().Add(1)
|
||||
s.metricsRegistry.LastConfigReloadFailureGauge().Set(float64(time.Now().Unix()))
|
||||
log.Error("Error loading new configuration, aborted ", err)
|
||||
return
|
||||
}
|
||||
|
||||
s.metricsRegistry.LastConfigReloadSuccessGauge().Set(float64(time.Now().Unix()))
|
||||
|
||||
for newServerEntryPointName, newServerEntryPoint := range newServerEntryPoints {
|
||||
s.serverEntryPoints[newServerEntryPointName].httpRouter.UpdateHandler(newServerEntryPoint.httpRouter.GetHandler())
|
||||
|
||||
if s.entryPoints[newServerEntryPointName].Configuration.TLS == nil {
|
||||
if newServerEntryPoint.certs.Get() != nil {
|
||||
log.Debugf("Certificates not added to non-TLS entryPoint %s.", newServerEntryPointName)
|
||||
}
|
||||
} else {
|
||||
s.serverEntryPoints[newServerEntryPointName].certs.Set(newServerEntryPoint.certs.Get())
|
||||
}
|
||||
log.Infof("Server configuration reloaded on %s", s.serverEntryPoints[newServerEntryPointName].httpServer.Addr)
|
||||
}
|
||||
|
||||
s.currentConfigurations.Set(newConfigurations)
|
||||
|
||||
for _, listener := range s.configurationListeners {
|
||||
listener(*configMsg.Configuration)
|
||||
}
|
||||
|
||||
s.postLoadConfiguration()
|
||||
}
|
||||
|
||||
// loadConfig returns a new gorilla.mux Route from the specified global configuration and the dynamic
|
||||
// provider configurations.
|
||||
func (s *Server) loadConfig(configurations types.Configurations, globalConfiguration configuration.GlobalConfiguration) (map[string]*serverEntryPoint, error) {
|
||||
redirectHandlers, err := s.buildEntryPointRedirect()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serverEntryPoints := s.buildServerEntryPoints()
|
||||
errorHandler := NewRecordingErrorHandler(middlewares.DefaultNetErrorRecorder{})
|
||||
|
||||
backendsHandlers := map[string]http.Handler{}
|
||||
backendsHealthCheck := map[string]*healthcheck.BackendConfig{}
|
||||
|
||||
var postConfigs []handlerPostConfig
|
||||
|
||||
for providerName, config := range configurations {
|
||||
frontendNames := sortedFrontendNamesForConfig(config)
|
||||
|
||||
for _, frontendName := range frontendNames {
|
||||
frontendPostConfigs, err := s.loadFrontendConfig(providerName, frontendName, config,
|
||||
redirectHandlers, serverEntryPoints, errorHandler,
|
||||
backendsHandlers, backendsHealthCheck)
|
||||
if err != nil {
|
||||
log.Errorf("%v. Skipping frontend %s...", err, frontendName)
|
||||
}
|
||||
|
||||
if len(frontendPostConfigs) > 0 {
|
||||
postConfigs = append(postConfigs, frontendPostConfigs...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, postConfig := range postConfigs {
|
||||
err := postConfig(backendsHandlers)
|
||||
if err != nil {
|
||||
log.Errorf("middleware post configuration error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
healthcheck.GetHealthCheck(s.metricsRegistry).SetBackendsConfiguration(s.routinesPool.Ctx(), backendsHealthCheck)
|
||||
|
||||
// Get new certificates list sorted per entrypoints
|
||||
// Update certificates
|
||||
entryPointsCertificates, err := s.loadHTTPSConfiguration(configurations, globalConfiguration.DefaultEntryPoints)
|
||||
// FIXME error management
|
||||
|
||||
// Sort routes and update certificates
|
||||
for serverEntryPointName, serverEntryPoint := range serverEntryPoints {
|
||||
serverEntryPoint.httpRouter.GetHandler().SortRoutes()
|
||||
if _, exists := entryPointsCertificates[serverEntryPointName]; exists {
|
||||
serverEntryPoint.certs.Set(entryPointsCertificates[serverEntryPointName])
|
||||
}
|
||||
}
|
||||
|
||||
return serverEntryPoints, err
|
||||
}
|
||||
|
||||
func (s *Server) loadFrontendConfig(
|
||||
providerName string, frontendName string, config *types.Configuration,
|
||||
redirectHandlers map[string]negroni.Handler, serverEntryPoints map[string]*serverEntryPoint, errorHandler *RecordingErrorHandler,
|
||||
backendsHandlers map[string]http.Handler, backendsHealthCheck map[string]*healthcheck.BackendConfig,
|
||||
) ([]handlerPostConfig, error) {
|
||||
|
||||
frontend := config.Frontends[frontendName]
|
||||
|
||||
if len(frontend.EntryPoints) == 0 {
|
||||
return nil, fmt.Errorf("no entrypoint defined for frontend %s", frontendName)
|
||||
}
|
||||
|
||||
backend := config.Backends[frontend.Backend]
|
||||
if backend == nil {
|
||||
return nil, fmt.Errorf("undefined backend '%s' for frontend %s", frontend.Backend, frontendName)
|
||||
}
|
||||
|
||||
frontendHash, err := frontend.Hash()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calculating hash value for frontend %s: %v", frontendName, err)
|
||||
}
|
||||
|
||||
var postConfigs []handlerPostConfig
|
||||
|
||||
for _, entryPointName := range frontend.EntryPoints {
|
||||
log.Debugf("Wiring frontend %s to entryPoint %s", frontendName, entryPointName)
|
||||
|
||||
entryPoint := s.entryPoints[entryPointName].Configuration
|
||||
|
||||
if backendsHandlers[entryPointName+providerName+frontendHash] == nil {
|
||||
log.Debugf("Creating backend %s", frontend.Backend)
|
||||
|
||||
handlers, responseModifier, postConfig, err := s.buildMiddlewares(frontendName, frontend, config.Backends, entryPointName, entryPoint, providerName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if postConfig != nil {
|
||||
postConfigs = append(postConfigs, postConfig)
|
||||
}
|
||||
|
||||
fwd, err := s.buildForwarder(entryPointName, entryPoint, frontendName, frontend, errorHandler, responseModifier)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create the forwarder for frontend %s: %v", frontendName, err)
|
||||
}
|
||||
|
||||
lb, healthCheckConfig, err := s.buildBalancerMiddlewares(frontendName, frontend, backend, fwd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if healthCheckConfig != nil {
|
||||
backendsHealthCheck[entryPointName+providerName+frontendHash] = healthCheckConfig
|
||||
}
|
||||
|
||||
n := negroni.New()
|
||||
|
||||
if _, exist := redirectHandlers[entryPointName]; exist {
|
||||
n.Use(redirectHandlers[entryPointName])
|
||||
}
|
||||
|
||||
for _, handler := range handlers {
|
||||
n.Use(handler)
|
||||
}
|
||||
|
||||
n.UseHandler(lb)
|
||||
|
||||
backendsHandlers[entryPointName+providerName+frontendHash] = n
|
||||
} else {
|
||||
log.Debugf("Reusing backend %s [%s - %s - %s - %s]",
|
||||
frontend.Backend, entryPointName, providerName, frontendName, frontendHash)
|
||||
}
|
||||
|
||||
serverRoute, err := buildServerRoute(serverEntryPoints[entryPointName], frontendName, frontend)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
handler := buildMatcherMiddlewares(serverRoute, backendsHandlers[entryPointName+providerName+frontendHash])
|
||||
serverRoute.Route.Handler(handler)
|
||||
|
||||
err = serverRoute.Route.GetError()
|
||||
if err != nil {
|
||||
// FIXME error management
|
||||
log.Errorf("Error building route: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return postConfigs, nil
|
||||
}
|
||||
|
||||
func (s *Server) buildForwarder(entryPointName string, entryPoint *configuration.EntryPoint,
|
||||
frontendName string, frontend *types.Frontend,
|
||||
errorHandler utils.ErrorHandler, responseModifier modifyResponse) (http.Handler, error) {
|
||||
|
||||
roundTripper, err := s.getRoundTripper(entryPointName, frontend.PassTLSCert, entryPoint.TLS)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create RoundTripper for frontend %s: %v", frontendName, err)
|
||||
}
|
||||
|
||||
rewriter, err := NewHeaderRewriter(entryPoint.ForwardedHeaders.TrustedIPs, entryPoint.ForwardedHeaders.Insecure)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating rewriter for frontend %s: %v", frontendName, err)
|
||||
}
|
||||
|
||||
var fwd http.Handler
|
||||
fwd, err = forward.New(
|
||||
forward.Stream(true),
|
||||
forward.PassHostHeader(frontend.PassHostHeader),
|
||||
forward.RoundTripper(roundTripper),
|
||||
forward.ErrorHandler(errorHandler),
|
||||
forward.Rewriter(rewriter),
|
||||
forward.ResponseModifier(responseModifier),
|
||||
forward.BufferPool(s.bufferPool),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating forwarder for frontend %s: %v", frontendName, err)
|
||||
}
|
||||
|
||||
if s.tracingMiddleware.IsEnabled() {
|
||||
tm := s.tracingMiddleware.NewForwarderMiddleware(frontendName, frontend.Backend)
|
||||
|
||||
next := fwd
|
||||
fwd = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
tm.ServeHTTP(w, r, next.ServeHTTP)
|
||||
})
|
||||
}
|
||||
|
||||
return fwd, nil
|
||||
}
|
||||
|
||||
func buildServerRoute(serverEntryPoint *serverEntryPoint, frontendName string, frontend *types.Frontend) (*types.ServerRoute, error) {
|
||||
serverRoute := &types.ServerRoute{Route: serverEntryPoint.httpRouter.GetHandler().NewRoute().Name(frontendName)}
|
||||
|
||||
priority := 0
|
||||
for routeName, route := range frontend.Routes {
|
||||
rls := rules.Rules{Route: serverRoute}
|
||||
newRoute, err := rls.Parse(route.Rule)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating route for frontend %s: %v", frontendName, err)
|
||||
}
|
||||
|
||||
serverRoute.Route = newRoute
|
||||
|
||||
priority += len(route.Rule)
|
||||
log.Debugf("Creating route %s %s", routeName, route.Rule)
|
||||
}
|
||||
|
||||
if frontend.Priority > 0 {
|
||||
serverRoute.Route.Priority(frontend.Priority)
|
||||
} else {
|
||||
serverRoute.Route.Priority(priority)
|
||||
}
|
||||
|
||||
return serverRoute, nil
|
||||
}
|
||||
|
||||
func (s *Server) preLoadConfiguration(configMsg types.ConfigMessage) {
|
||||
providersThrottleDuration := time.Duration(s.globalConfiguration.ProvidersThrottleDuration)
|
||||
s.defaultConfigurationValues(configMsg.Configuration)
|
||||
currentConfigurations := s.currentConfigurations.Get().(types.Configurations)
|
||||
|
||||
jsonConf, _ := json.Marshal(configMsg.Configuration)
|
||||
|
||||
log.Debugf("Configuration received from provider %s: %s", configMsg.ProviderName, string(jsonConf))
|
||||
|
||||
if configMsg.Configuration == nil || configMsg.Configuration.Backends == nil && configMsg.Configuration.Frontends == nil && configMsg.Configuration.TLS == nil {
|
||||
log.Infof("Skipping empty Configuration for provider %s", configMsg.ProviderName)
|
||||
return
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(currentConfigurations[configMsg.ProviderName], configMsg.Configuration) {
|
||||
log.Infof("Skipping same configuration for provider %s", configMsg.ProviderName)
|
||||
return
|
||||
}
|
||||
|
||||
providerConfigUpdateCh, ok := s.providerConfigUpdateMap[configMsg.ProviderName]
|
||||
if !ok {
|
||||
providerConfigUpdateCh = make(chan types.ConfigMessage)
|
||||
s.providerConfigUpdateMap[configMsg.ProviderName] = providerConfigUpdateCh
|
||||
s.routinesPool.Go(func(stop chan bool) {
|
||||
s.throttleProviderConfigReload(providersThrottleDuration, s.configurationValidatedChan, providerConfigUpdateCh, stop)
|
||||
})
|
||||
}
|
||||
|
||||
providerConfigUpdateCh <- configMsg
|
||||
}
|
||||
|
||||
func (s *Server) defaultConfigurationValues(configuration *types.Configuration) {
|
||||
if configuration == nil || configuration.Frontends == nil {
|
||||
return
|
||||
}
|
||||
s.configureFrontends(configuration.Frontends)
|
||||
configureBackends(configuration.Backends)
|
||||
}
|
||||
|
||||
func (s *Server) configureFrontends(frontends map[string]*types.Frontend) {
|
||||
defaultEntrypoints := s.globalConfiguration.DefaultEntryPoints
|
||||
|
||||
for frontendName, frontend := range frontends {
|
||||
// default endpoints if not defined in frontends
|
||||
if len(frontend.EntryPoints) == 0 {
|
||||
frontend.EntryPoints = defaultEntrypoints
|
||||
}
|
||||
|
||||
frontendEntryPoints, undefinedEntryPoints := s.filterEntryPoints(frontend.EntryPoints)
|
||||
if len(undefinedEntryPoints) > 0 {
|
||||
log.Errorf("Undefined entry point(s) '%s' for frontend %s", strings.Join(undefinedEntryPoints, ","), frontendName)
|
||||
}
|
||||
|
||||
frontend.EntryPoints = frontendEntryPoints
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) filterEntryPoints(entryPoints []string) ([]string, []string) {
|
||||
var frontendEntryPoints []string
|
||||
var undefinedEntryPoints []string
|
||||
|
||||
for _, fepName := range entryPoints {
|
||||
var exist bool
|
||||
|
||||
for epName := range s.entryPoints {
|
||||
if epName == fepName {
|
||||
exist = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if exist {
|
||||
frontendEntryPoints = append(frontendEntryPoints, fepName)
|
||||
} else {
|
||||
undefinedEntryPoints = append(undefinedEntryPoints, fepName)
|
||||
}
|
||||
}
|
||||
|
||||
return frontendEntryPoints, undefinedEntryPoints
|
||||
}
|
||||
|
||||
func configureBackends(backends map[string]*types.Backend) {
|
||||
for backendName := range backends {
|
||||
backend := backends[backendName]
|
||||
if backend.LoadBalancer != nil && backend.LoadBalancer.Sticky {
|
||||
log.Warnf("Deprecated configuration found: %s. Please use %s.", "backend.LoadBalancer.Sticky", "backend.LoadBalancer.Stickiness")
|
||||
}
|
||||
|
||||
_, err := types.NewLoadBalancerMethod(backend.LoadBalancer)
|
||||
if err == nil {
|
||||
if backend.LoadBalancer != nil && backend.LoadBalancer.Stickiness == nil && backend.LoadBalancer.Sticky {
|
||||
backend.LoadBalancer.Stickiness = &types.Stickiness{
|
||||
CookieName: "_TRAEFIK_BACKEND",
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Debugf("Backend %s: %v", backendName, err)
|
||||
|
||||
var stickiness *types.Stickiness
|
||||
if backend.LoadBalancer != nil {
|
||||
if backend.LoadBalancer.Stickiness == nil {
|
||||
if backend.LoadBalancer.Sticky {
|
||||
stickiness = &types.Stickiness{
|
||||
CookieName: "_TRAEFIK_BACKEND",
|
||||
}
|
||||
}
|
||||
} else {
|
||||
stickiness = backend.LoadBalancer.Stickiness
|
||||
}
|
||||
}
|
||||
backend.LoadBalancer = &types.LoadBalancer{
|
||||
Method: "wrr",
|
||||
Stickiness: stickiness,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) listenConfigurations(stop chan bool) {
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case configMsg, ok := <-s.configurationValidatedChan:
|
||||
if !ok || configMsg.Configuration == nil {
|
||||
return
|
||||
}
|
||||
s.loadConfiguration(configMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// throttleProviderConfigReload throttles the configuration reload speed for a single provider.
|
||||
// It will immediately publish a new configuration and then only publish the next configuration after the throttle duration.
|
||||
// Note that in the case it receives N new configs in the timeframe of the throttle duration after publishing,
|
||||
// it will publish the last of the newly received configurations.
|
||||
func (s *Server) throttleProviderConfigReload(throttle time.Duration, publish chan<- types.ConfigMessage, in <-chan types.ConfigMessage, stop chan bool) {
|
||||
ring := channels.NewRingChannel(1)
|
||||
defer ring.Close()
|
||||
|
||||
s.routinesPool.Go(func(stop chan bool) {
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case nextConfig := <-ring.Out():
|
||||
publish <- nextConfig.(types.ConfigMessage)
|
||||
time.Sleep(throttle)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case nextConfig := <-in:
|
||||
ring.In() <- nextConfig
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func buildMatcherMiddlewares(serverRoute *types.ServerRoute, handler http.Handler) http.Handler {
|
||||
// path replace - This needs to always be the very last on the handler chain (first in the order in this function)
|
||||
// -- Replacing Path should happen at the very end of the Modifier chain, after all the Matcher+Modifiers ran
|
||||
if len(serverRoute.ReplacePath) > 0 {
|
||||
handler = &middlewares.ReplacePath{
|
||||
Path: serverRoute.ReplacePath,
|
||||
Handler: handler,
|
||||
}
|
||||
}
|
||||
|
||||
if len(serverRoute.ReplacePathRegex) > 0 {
|
||||
sp := strings.Split(serverRoute.ReplacePathRegex, " ")
|
||||
if len(sp) == 2 {
|
||||
handler = middlewares.NewReplacePathRegexHandler(sp[0], sp[1], handler)
|
||||
} else {
|
||||
log.Warnf("Invalid syntax for ReplacePathRegex: %s. Separate the regular expression and the replacement by a space.", serverRoute.ReplacePathRegex)
|
||||
}
|
||||
}
|
||||
|
||||
// add prefix - This needs to always be right before ReplacePath on the chain (second in order in this function)
|
||||
// -- Adding Path Prefix should happen after all *Strip Matcher+Modifiers ran, but before Replace (in case it's configured)
|
||||
if len(serverRoute.AddPrefix) > 0 {
|
||||
handler = &middlewares.AddPrefix{
|
||||
Prefix: serverRoute.AddPrefix,
|
||||
Handler: handler,
|
||||
}
|
||||
}
|
||||
|
||||
// strip prefix
|
||||
if len(serverRoute.StripPrefixes) > 0 {
|
||||
handler = &middlewares.StripPrefix{
|
||||
Prefixes: serverRoute.StripPrefixes,
|
||||
Handler: handler,
|
||||
}
|
||||
}
|
||||
|
||||
// strip prefix with regex
|
||||
if len(serverRoute.StripPrefixesRegex) > 0 {
|
||||
handler = middlewares.NewStripPrefixRegex(handler, serverRoute.StripPrefixesRegex)
|
||||
}
|
||||
|
||||
return handler
|
||||
}
|
||||
|
||||
func (s *Server) postLoadConfiguration() {
|
||||
if s.metricsRegistry.IsEnabled() {
|
||||
activeConfig := s.currentConfigurations.Get().(types.Configurations)
|
||||
metrics.OnConfigurationUpdate(activeConfig)
|
||||
}
|
||||
|
||||
if s.globalConfiguration.ACME == nil || s.leadership == nil || !s.leadership.IsLeader() {
|
||||
return
|
||||
}
|
||||
|
||||
if s.globalConfiguration.ACME.OnHostRule {
|
||||
currentConfigurations := s.currentConfigurations.Get().(types.Configurations)
|
||||
for _, config := range currentConfigurations {
|
||||
for _, frontend := range config.Frontends {
|
||||
|
||||
// check if one of the frontend entrypoints is configured with TLS
|
||||
// and is configured with ACME
|
||||
acmeEnabled := false
|
||||
for _, entryPoint := range frontend.EntryPoints {
|
||||
if s.globalConfiguration.ACME.EntryPoint == entryPoint && s.entryPoints[entryPoint].Configuration.TLS != nil {
|
||||
acmeEnabled = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if acmeEnabled {
|
||||
for _, route := range frontend.Routes {
|
||||
rls := rules.Rules{}
|
||||
domains, err := rls.ParseDomains(route.Rule)
|
||||
if err != nil {
|
||||
log.Errorf("Error parsing domains: %v", err)
|
||||
} else {
|
||||
s.globalConfiguration.ACME.LoadCertificateForDomains(domains)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// loadHTTPSConfiguration add/delete HTTPS certificate managed dynamically
|
||||
func (s *Server) loadHTTPSConfiguration(configurations types.Configurations, defaultEntryPoints configuration.DefaultEntryPoints) (map[string]map[string]*tls.Certificate, error) {
|
||||
newEPCertificates := make(map[string]map[string]*tls.Certificate)
|
||||
// Get all certificates
|
||||
for _, config := range configurations {
|
||||
if config.TLS != nil && len(config.TLS) > 0 {
|
||||
if err := traefiktls.SortTLSPerEntryPoints(config.TLS, newEPCertificates, defaultEntryPoints); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return newEPCertificates, nil
|
||||
}
|
||||
|
||||
func (s *Server) buildServerEntryPoints() map[string]*serverEntryPoint {
|
||||
serverEntryPoints := make(map[string]*serverEntryPoint)
|
||||
for entryPointName, entryPoint := range s.entryPoints {
|
||||
serverEntryPoints[entryPointName] = &serverEntryPoint{
|
||||
httpRouter: middlewares.NewHandlerSwitcher(s.buildDefaultHTTPRouter()),
|
||||
onDemandListener: entryPoint.OnDemandListener,
|
||||
}
|
||||
if entryPoint.CertificateStore != nil {
|
||||
serverEntryPoints[entryPointName].certs = entryPoint.CertificateStore.DynamicCerts
|
||||
} else {
|
||||
serverEntryPoints[entryPointName].certs = &safe.Safe{}
|
||||
}
|
||||
}
|
||||
return serverEntryPoints
|
||||
}
|
||||
|
||||
func (s *Server) buildDefaultHTTPRouter() *mux.Router {
|
||||
rt := mux.NewRouter()
|
||||
rt.NotFoundHandler = s.wrapHTTPHandlerWithAccessLog(http.HandlerFunc(http.NotFound), "backend not found")
|
||||
rt.StrictSlash(true)
|
||||
rt.SkipClean(true)
|
||||
return rt
|
||||
}
|
||||
|
||||
func sortedFrontendNamesForConfig(configuration *types.Configuration) []string {
|
||||
var keys []string
|
||||
for key := range configuration.Frontends {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
484
server/server_configuration_test.go
Normal file
484
server/server_configuration_test.go
Normal file
|
@ -0,0 +1,484 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/containous/flaeg"
|
||||
"github.com/containous/mux"
|
||||
"github.com/containous/traefik/configuration"
|
||||
"github.com/containous/traefik/healthcheck"
|
||||
"github.com/containous/traefik/rules"
|
||||
th "github.com/containous/traefik/testhelpers"
|
||||
"github.com/containous/traefik/tls"
|
||||
"github.com/containous/traefik/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/vulcand/oxy/roundrobin"
|
||||
)
|
||||
|
||||
// LocalhostCert is a PEM-encoded TLS cert with SAN IPs
|
||||
// "127.0.0.1" and "[::1]", expiring at Jan 29 16:00:00 2084 GMT.
|
||||
// generated from src/crypto/tls:
|
||||
// go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
|
||||
var (
|
||||
localhostCert = tls.FileOrContent(`-----BEGIN CERTIFICATE-----
|
||||
MIICEzCCAXygAwIBAgIQMIMChMLGrR+QvmQvpwAU6zANBgkqhkiG9w0BAQsFADAS
|
||||
MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw
|
||||
MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB
|
||||
iQKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9SjY1bIw4
|
||||
iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZBl2+XsDul
|
||||
rKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQABo2gwZjAO
|
||||
BgNVHQ8BAf8EBAMCAqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUw
|
||||
AwEB/zAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAAAAAAAAAA
|
||||
AAAAATANBgkqhkiG9w0BAQsFAAOBgQCEcetwO59EWk7WiJsG4x8SY+UIAA+flUI9
|
||||
tyC4lNhbcF2Idq9greZwbYCqTTTr2XiRNSMLCOjKyI7ukPoPjo16ocHj+P3vZGfs
|
||||
h1fIw3cSS2OolhloGw/XM6RWPWtPAlGykKLciQrBru5NAPvCMsb/I1DAceTiotQM
|
||||
fblo6RBxUQ==
|
||||
-----END CERTIFICATE-----`)
|
||||
|
||||
// LocalhostKey is the private key for localhostCert.
|
||||
localhostKey = tls.FileOrContent(`-----BEGIN RSA PRIVATE KEY-----
|
||||
MIICXgIBAAKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9
|
||||
SjY1bIw4iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZB
|
||||
l2+XsDulrKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQAB
|
||||
AoGAGRzwwir7XvBOAy5tM/uV6e+Zf6anZzus1s1Y1ClbjbE6HXbnWWF/wbZGOpet
|
||||
3Zm4vD6MXc7jpTLryzTQIvVdfQbRc6+MUVeLKwZatTXtdZrhu+Jk7hx0nTPy8Jcb
|
||||
uJqFk541aEw+mMogY/xEcfbWd6IOkp+4xqjlFLBEDytgbIECQQDvH/E6nk+hgN4H
|
||||
qzzVtxxr397vWrjrIgPbJpQvBsafG7b0dA4AFjwVbFLmQcj2PprIMmPcQrooz8vp
|
||||
jy4SHEg1AkEA/v13/5M47K9vCxmb8QeD/asydfsgS5TeuNi8DoUBEmiSJwma7FXY
|
||||
fFUtxuvL7XvjwjN5B30pNEbc6Iuyt7y4MQJBAIt21su4b3sjXNueLKH85Q+phy2U
|
||||
fQtuUE9txblTu14q3N7gHRZB4ZMhFYyDy8CKrN2cPg/Fvyt0Xlp/DoCzjA0CQQDU
|
||||
y2ptGsuSmgUtWj3NM9xuwYPm+Z/F84K6+ARYiZ6PYj013sovGKUFfYAqVXVlxtIX
|
||||
qyUBnu3X9ps8ZfjLZO7BAkEAlT4R5Yl6cGhaJQYZHOde3JEMhNRcVFMO8dJDaFeo
|
||||
f9Oeos0UUothgiDktdQHxdNEwLjQf7lJJBzV+5OtwswCWA==
|
||||
-----END RSA PRIVATE KEY-----`)
|
||||
)
|
||||
|
||||
type testLoadBalancer struct{}
|
||||
|
||||
func (lb *testLoadBalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
// noop
|
||||
}
|
||||
|
||||
func (lb *testLoadBalancer) RemoveServer(u *url.URL) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (lb *testLoadBalancer) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (lb *testLoadBalancer) Servers() []*url.URL {
|
||||
return []*url.URL{}
|
||||
}
|
||||
|
||||
func TestServerLoadConfigHealthCheckOptions(t *testing.T) {
|
||||
healthChecks := []*types.HealthCheck{
|
||||
nil,
|
||||
{
|
||||
Path: "/path",
|
||||
},
|
||||
}
|
||||
|
||||
for _, lbMethod := range []string{"Wrr", "Drr"} {
|
||||
for _, healthCheck := range healthChecks {
|
||||
t.Run(fmt.Sprintf("%s/hc=%t", lbMethod, healthCheck != nil), func(t *testing.T) {
|
||||
globalConfig := configuration.GlobalConfiguration{
|
||||
HealthCheck: &configuration.HealthCheckConfig{Interval: flaeg.Duration(5 * time.Second)},
|
||||
}
|
||||
entryPoints := map[string]EntryPoint{
|
||||
"http": {
|
||||
Configuration: &configuration.EntryPoint{
|
||||
ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
dynamicConfigs := types.Configurations{
|
||||
"config": &types.Configuration{
|
||||
Frontends: map[string]*types.Frontend{
|
||||
"frontend": {
|
||||
EntryPoints: []string{"http"},
|
||||
Backend: "backend",
|
||||
},
|
||||
},
|
||||
Backends: map[string]*types.Backend{
|
||||
"backend": {
|
||||
Servers: map[string]types.Server{
|
||||
"server": {
|
||||
URL: "http://localhost",
|
||||
},
|
||||
},
|
||||
LoadBalancer: &types.LoadBalancer{
|
||||
Method: lbMethod,
|
||||
},
|
||||
HealthCheck: healthCheck,
|
||||
},
|
||||
},
|
||||
TLS: []*tls.Configuration{
|
||||
{
|
||||
Certificate: &tls.Certificate{
|
||||
CertFile: localhostCert,
|
||||
KeyFile: localhostKey,
|
||||
},
|
||||
EntryPoints: []string{"http"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
srv := NewServer(globalConfig, nil, entryPoints)
|
||||
|
||||
_, err := srv.loadConfig(dynamicConfigs, globalConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedNumHealthCheckBackends := 0
|
||||
if healthCheck != nil {
|
||||
expectedNumHealthCheckBackends = 1
|
||||
}
|
||||
assert.Len(t, healthcheck.GetHealthCheck(th.NewCollectingHealthCheckMetrics()).Backends, expectedNumHealthCheckBackends, "health check backends")
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerLoadConfigEmptyBasicAuth(t *testing.T) {
|
||||
globalConfig := configuration.GlobalConfiguration{
|
||||
EntryPoints: configuration.EntryPoints{
|
||||
"http": &configuration.EntryPoint{ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true}},
|
||||
},
|
||||
}
|
||||
|
||||
dynamicConfigs := types.Configurations{
|
||||
"config": &types.Configuration{
|
||||
Frontends: map[string]*types.Frontend{
|
||||
"frontend": {
|
||||
EntryPoints: []string{"http"},
|
||||
Backend: "backend",
|
||||
BasicAuth: []string{""},
|
||||
},
|
||||
},
|
||||
Backends: map[string]*types.Backend{
|
||||
"backend": {
|
||||
Servers: map[string]types.Server{
|
||||
"server": {
|
||||
URL: "http://localhost",
|
||||
},
|
||||
},
|
||||
LoadBalancer: &types.LoadBalancer{
|
||||
Method: "Wrr",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
entryPoints := map[string]EntryPoint{}
|
||||
for key, value := range globalConfig.EntryPoints {
|
||||
entryPoints[key] = EntryPoint{
|
||||
Configuration: value,
|
||||
}
|
||||
}
|
||||
|
||||
srv := NewServer(globalConfig, nil, entryPoints)
|
||||
_, err := srv.loadConfig(dynamicConfigs, globalConfig)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestServerLoadCertificateWithDefaultEntryPoint(t *testing.T) {
|
||||
globalConfig := configuration.GlobalConfiguration{
|
||||
DefaultEntryPoints: []string{"http", "https"},
|
||||
}
|
||||
entryPoints := map[string]EntryPoint{
|
||||
"https": {Configuration: &configuration.EntryPoint{TLS: &tls.TLS{}}},
|
||||
"http": {Configuration: &configuration.EntryPoint{}},
|
||||
}
|
||||
|
||||
dynamicConfigs := types.Configurations{
|
||||
"config": &types.Configuration{
|
||||
TLS: []*tls.Configuration{
|
||||
{
|
||||
Certificate: &tls.Certificate{
|
||||
CertFile: localhostCert,
|
||||
KeyFile: localhostKey,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
srv := NewServer(globalConfig, nil, entryPoints)
|
||||
if mapEntryPoints, err := srv.loadConfig(dynamicConfigs, globalConfig); err != nil {
|
||||
t.Fatalf("got error: %s", err)
|
||||
} else if mapEntryPoints["https"].certs.Get() == nil {
|
||||
t.Fatal("got error: https entryPoint must have TLS certificates.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReuseBackend(t *testing.T) {
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
globalConfig := configuration.GlobalConfiguration{
|
||||
DefaultEntryPoints: []string{"http"},
|
||||
}
|
||||
|
||||
entryPoints := map[string]EntryPoint{
|
||||
"http": {Configuration: &configuration.EntryPoint{
|
||||
ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true},
|
||||
}},
|
||||
}
|
||||
|
||||
dynamicConfigs := types.Configurations{
|
||||
"config": th.BuildConfiguration(
|
||||
th.WithFrontends(
|
||||
th.WithFrontend("backend",
|
||||
th.WithFrontendName("frontend0"),
|
||||
th.WithEntryPoints("http"),
|
||||
th.WithRoutes(th.WithRoute("/ok", "Path: /ok"))),
|
||||
th.WithFrontend("backend",
|
||||
th.WithFrontendName("frontend1"),
|
||||
th.WithEntryPoints("http"),
|
||||
th.WithRoutes(th.WithRoute("/unauthorized", "Path: /unauthorized")),
|
||||
th.WithBasicAuth("foo", "bar")),
|
||||
),
|
||||
th.WithBackends(th.WithBackendNew("backend",
|
||||
th.WithLBMethod("wrr"),
|
||||
th.WithServersNew(th.WithServerNew(testServer.URL))),
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
srv := NewServer(globalConfig, nil, entryPoints)
|
||||
|
||||
serverEntryPoints, err := srv.loadConfig(dynamicConfigs, globalConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("error loading config: %s", err)
|
||||
}
|
||||
|
||||
// Test that the /ok path returns a status 200.
|
||||
responseRecorderOk := &httptest.ResponseRecorder{}
|
||||
requestOk := httptest.NewRequest(http.MethodGet, testServer.URL+"/ok", nil)
|
||||
serverEntryPoints["http"].httpRouter.ServeHTTP(responseRecorderOk, requestOk)
|
||||
|
||||
assert.Equal(t, http.StatusOK, responseRecorderOk.Result().StatusCode, "status code")
|
||||
|
||||
// Test that the /unauthorized path returns a 401 because of
|
||||
// the basic authentication defined on the frontend.
|
||||
responseRecorderUnauthorized := &httptest.ResponseRecorder{}
|
||||
requestUnauthorized := httptest.NewRequest(http.MethodGet, testServer.URL+"/unauthorized", nil)
|
||||
serverEntryPoints["http"].httpRouter.ServeHTTP(responseRecorderUnauthorized, requestUnauthorized)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, responseRecorderUnauthorized.Result().StatusCode, "status code")
|
||||
}
|
||||
|
||||
func TestThrottleProviderConfigReload(t *testing.T) {
|
||||
throttleDuration := 30 * time.Millisecond
|
||||
publishConfig := make(chan types.ConfigMessage)
|
||||
providerConfig := make(chan types.ConfigMessage)
|
||||
stop := make(chan bool)
|
||||
defer func() {
|
||||
stop <- true
|
||||
}()
|
||||
|
||||
globalConfig := configuration.GlobalConfiguration{}
|
||||
server := NewServer(globalConfig, nil, nil)
|
||||
|
||||
go server.throttleProviderConfigReload(throttleDuration, publishConfig, providerConfig, stop)
|
||||
|
||||
publishedConfigCount := 0
|
||||
stopConsumeConfigs := make(chan bool)
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case <-stopConsumeConfigs:
|
||||
return
|
||||
case <-publishConfig:
|
||||
publishedConfigCount++
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// publish 5 new configs, one new config each 10 milliseconds
|
||||
for i := 0; i < 5; i++ {
|
||||
providerConfig <- types.ConfigMessage{}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
// after 50 milliseconds 5 new configs were published
|
||||
// with a throttle duration of 30 milliseconds this means, we should have received 2 new configs
|
||||
assert.Equal(t, 2, publishedConfigCount, "times configs were published")
|
||||
|
||||
stopConsumeConfigs <- true
|
||||
|
||||
select {
|
||||
case <-publishConfig:
|
||||
// There should be exactly one more message that we receive after ~60 milliseconds since the start of the test.
|
||||
select {
|
||||
case <-publishConfig:
|
||||
t.Error("extra config publication found")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
return
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Last config was not published in time")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerMultipleFrontendRules(t *testing.T) {
|
||||
testCases := []struct {
|
||||
expression string
|
||||
requestURL string
|
||||
expectedURL string
|
||||
}{
|
||||
{
|
||||
expression: "Host:foo.bar",
|
||||
requestURL: "http://foo.bar",
|
||||
expectedURL: "http://foo.bar",
|
||||
},
|
||||
{
|
||||
expression: "PathPrefix:/management;ReplacePath:/health",
|
||||
requestURL: "http://foo.bar/management",
|
||||
expectedURL: "http://foo.bar/health",
|
||||
},
|
||||
{
|
||||
expression: "Host:foo.bar;AddPrefix:/blah",
|
||||
requestURL: "http://foo.bar/baz",
|
||||
expectedURL: "http://foo.bar/blah/baz",
|
||||
},
|
||||
{
|
||||
expression: "PathPrefixStripRegex:/one/{two}/{three:[0-9]+}",
|
||||
requestURL: "http://foo.bar/one/some/12345/four",
|
||||
expectedURL: "http://foo.bar/four",
|
||||
},
|
||||
{
|
||||
expression: "PathPrefixStripRegex:/one/{two}/{three:[0-9]+};AddPrefix:/zero",
|
||||
requestURL: "http://foo.bar/one/some/12345/four",
|
||||
expectedURL: "http://foo.bar/zero/four",
|
||||
},
|
||||
{
|
||||
expression: "AddPrefix:/blah;ReplacePath:/baz",
|
||||
requestURL: "http://foo.bar/hello",
|
||||
expectedURL: "http://foo.bar/baz",
|
||||
},
|
||||
{
|
||||
expression: "PathPrefixStrip:/management;ReplacePath:/health",
|
||||
requestURL: "http://foo.bar/management",
|
||||
expectedURL: "http://foo.bar/health",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.expression, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
router := mux.NewRouter()
|
||||
route := router.NewRoute()
|
||||
serverRoute := &types.ServerRoute{Route: route}
|
||||
rls := &rules.Rules{Route: serverRoute}
|
||||
|
||||
expression := test.expression
|
||||
routeResult, err := rls.Parse(expression)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Error while building route for %s: %+v", expression, err)
|
||||
}
|
||||
|
||||
request := th.MustNewRequest(http.MethodGet, test.requestURL, nil)
|
||||
routeMatch := routeResult.Match(request, &mux.RouteMatch{Route: routeResult})
|
||||
|
||||
if !routeMatch {
|
||||
t.Fatalf("Rule %s doesn't match", expression)
|
||||
}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, test.expectedURL, r.URL.String(), "URL")
|
||||
})
|
||||
|
||||
hd := buildMatcherMiddlewares(serverRoute, handler)
|
||||
serverRoute.Route.Handler(hd)
|
||||
|
||||
serverRoute.Route.GetHandler().ServeHTTP(nil, request)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerBuildHealthCheckOptions(t *testing.T) {
|
||||
lb := &testLoadBalancer{}
|
||||
globalInterval := 15 * time.Second
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
hc *types.HealthCheck
|
||||
expectedOpts *healthcheck.Options
|
||||
}{
|
||||
{
|
||||
desc: "nil health check",
|
||||
hc: nil,
|
||||
expectedOpts: nil,
|
||||
},
|
||||
{
|
||||
desc: "empty path",
|
||||
hc: &types.HealthCheck{
|
||||
Path: "",
|
||||
},
|
||||
expectedOpts: nil,
|
||||
},
|
||||
{
|
||||
desc: "unparseable interval",
|
||||
hc: &types.HealthCheck{
|
||||
Path: "/path",
|
||||
Interval: "unparseable",
|
||||
},
|
||||
expectedOpts: &healthcheck.Options{
|
||||
Path: "/path",
|
||||
Interval: globalInterval,
|
||||
LB: lb,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "sub-zero interval",
|
||||
hc: &types.HealthCheck{
|
||||
Path: "/path",
|
||||
Interval: "-42s",
|
||||
},
|
||||
expectedOpts: &healthcheck.Options{
|
||||
Path: "/path",
|
||||
Interval: globalInterval,
|
||||
LB: lb,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "parseable interval",
|
||||
hc: &types.HealthCheck{
|
||||
Path: "/path",
|
||||
Interval: "5m",
|
||||
},
|
||||
expectedOpts: &healthcheck.Options{
|
||||
Path: "/path",
|
||||
Interval: 5 * time.Minute,
|
||||
LB: lb,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
opts := buildHealthCheckOptions(lb, "backend", test.hc, &configuration.HealthCheckConfig{Interval: flaeg.Duration(globalInterval)})
|
||||
assert.Equal(t, test.expectedOpts, opts, "health check options")
|
||||
})
|
||||
}
|
||||
}
|
428
server/server_loadbalancer.go
Normal file
428
server/server_loadbalancer.go
Normal file
|
@ -0,0 +1,428 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/containous/traefik/configuration"
|
||||
"github.com/containous/traefik/healthcheck"
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/middlewares/accesslog"
|
||||
"github.com/containous/traefik/server/cookie"
|
||||
traefiktls "github.com/containous/traefik/tls"
|
||||
"github.com/containous/traefik/types"
|
||||
"github.com/vulcand/oxy/buffer"
|
||||
"github.com/vulcand/oxy/connlimit"
|
||||
"github.com/vulcand/oxy/ratelimit"
|
||||
"github.com/vulcand/oxy/roundrobin"
|
||||
"github.com/vulcand/oxy/utils"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
type h2cTransportWrapper struct {
|
||||
*http2.Transport
|
||||
}
|
||||
|
||||
func (t *h2cTransportWrapper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
req.URL.Scheme = "http"
|
||||
return t.Transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
func (s *Server) buildBalancerMiddlewares(frontendName string, frontend *types.Frontend, backend *types.Backend, fwd http.Handler) (http.Handler, *healthcheck.BackendConfig, error) {
|
||||
balancer, err := s.buildLoadBalancer(frontendName, frontend.Backend, backend, fwd)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Health Check
|
||||
var backendHealthCheck *healthcheck.BackendConfig
|
||||
if hcOpts := buildHealthCheckOptions(balancer, frontend.Backend, backend.HealthCheck, s.globalConfiguration.HealthCheck); hcOpts != nil {
|
||||
log.Debugf("Setting up backend health check %s", *hcOpts)
|
||||
|
||||
hcOpts.Transport = s.defaultForwardingRoundTripper
|
||||
backendHealthCheck = healthcheck.NewBackendConfig(*hcOpts, frontend.Backend)
|
||||
}
|
||||
|
||||
// Empty (backend with no servers)
|
||||
var lb http.Handler = middlewares.NewEmptyBackendHandler(balancer)
|
||||
|
||||
// Rate Limit
|
||||
if frontend.RateLimit != nil && len(frontend.RateLimit.RateSet) > 0 {
|
||||
handler, err := buildRateLimiter(lb, frontend.RateLimit)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error creating rate limiter: %v", err)
|
||||
}
|
||||
|
||||
lb = s.wrapHTTPHandlerWithAccessLog(
|
||||
s.tracingMiddleware.NewHTTPHandlerWrapper("Rate limit", handler, false),
|
||||
fmt.Sprintf("rate limit for %s", frontendName),
|
||||
)
|
||||
}
|
||||
|
||||
// Max Connections
|
||||
if backend.MaxConn != nil && backend.MaxConn.Amount != 0 {
|
||||
log.Debugf("Creating load-balancer connection limit")
|
||||
|
||||
handler, err := buildMaxConn(lb, backend.MaxConn)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
lb = s.wrapHTTPHandlerWithAccessLog(handler, fmt.Sprintf("connection limit for %s", frontendName))
|
||||
}
|
||||
|
||||
// Retry
|
||||
if s.globalConfiguration.Retry != nil {
|
||||
handler := s.buildRetryMiddleware(lb, s.globalConfiguration.Retry, len(backend.Servers), frontend.Backend)
|
||||
lb = s.tracingMiddleware.NewHTTPHandlerWrapper("Retry", handler, false)
|
||||
}
|
||||
|
||||
// Buffering
|
||||
if backend.Buffering != nil {
|
||||
handler, err := buildBufferingMiddleware(lb, backend.Buffering)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error setting up buffering middleware: %s", err)
|
||||
}
|
||||
|
||||
// TODO refactor ?
|
||||
lb = handler
|
||||
}
|
||||
|
||||
// Circuit Breaker
|
||||
if backend.CircuitBreaker != nil {
|
||||
log.Debugf("Creating circuit breaker %s", backend.CircuitBreaker.Expression)
|
||||
|
||||
expression := backend.CircuitBreaker.Expression
|
||||
circuitBreaker, err := middlewares.NewCircuitBreaker(lb, expression, middlewares.NewCircuitBreakerOptions(expression))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error creating circuit breaker: %v", err)
|
||||
}
|
||||
|
||||
lb = s.tracingMiddleware.NewHTTPHandlerWrapper("Circuit breaker", circuitBreaker, false)
|
||||
}
|
||||
|
||||
return lb, backendHealthCheck, nil
|
||||
}
|
||||
|
||||
func (s *Server) buildLoadBalancer(frontendName string, backendName string, backend *types.Backend, fwd http.Handler) (healthcheck.BalancerHandler, error) {
|
||||
var rr *roundrobin.RoundRobin
|
||||
var saveFrontend http.Handler
|
||||
|
||||
if s.accessLoggerMiddleware != nil {
|
||||
saveBackend := accesslog.NewSaveBackend(fwd, backendName)
|
||||
saveFrontend = accesslog.NewSaveFrontend(saveBackend, frontendName)
|
||||
rr, _ = roundrobin.New(saveFrontend)
|
||||
} else {
|
||||
rr, _ = roundrobin.New(fwd)
|
||||
}
|
||||
|
||||
var stickySession *roundrobin.StickySession
|
||||
var cookieName string
|
||||
if stickiness := backend.LoadBalancer.Stickiness; stickiness != nil {
|
||||
cookieName = cookie.GetName(stickiness.CookieName, backendName)
|
||||
stickySession = roundrobin.NewStickySession(cookieName)
|
||||
}
|
||||
|
||||
lbMethod, err := types.NewLoadBalancerMethod(backend.LoadBalancer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error loading load balancer method '%+v' for frontend %s: %v", backend.LoadBalancer, frontendName, err)
|
||||
}
|
||||
|
||||
var lb healthcheck.BalancerHandler
|
||||
|
||||
switch lbMethod {
|
||||
case types.Drr:
|
||||
log.Debug("Creating load-balancer drr")
|
||||
|
||||
if stickySession != nil {
|
||||
log.Debugf("Sticky session with cookie %v", cookieName)
|
||||
|
||||
lb, err = roundrobin.NewRebalancer(rr, roundrobin.RebalancerStickySession(stickySession))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
lb, err = roundrobin.NewRebalancer(rr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
case types.Wrr:
|
||||
log.Debug("Creating load-balancer wrr")
|
||||
|
||||
if stickySession != nil {
|
||||
log.Debugf("Sticky session with cookie %v", cookieName)
|
||||
|
||||
if s.accessLoggerMiddleware != nil {
|
||||
lb, err = roundrobin.New(saveFrontend, roundrobin.EnableStickySession(stickySession))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
lb, err = roundrobin.New(fwd, roundrobin.EnableStickySession(stickySession))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
lb = rr
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid load-balancing method %q", lbMethod)
|
||||
}
|
||||
|
||||
if err := s.configureLBServers(lb, backend, backendName); err != nil {
|
||||
return nil, fmt.Errorf("error configuring load balancer for frontend %s: %v", frontendName, err)
|
||||
}
|
||||
|
||||
return lb, nil
|
||||
}
|
||||
|
||||
func (s *Server) configureLBServers(lb healthcheck.BalancerHandler, backend *types.Backend, backendName string) error {
|
||||
for name, srv := range backend.Servers {
|
||||
u, err := url.Parse(srv.URL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing server URL %s: %v", srv.URL, err)
|
||||
}
|
||||
|
||||
log.Debugf("Creating server %s at %s with weight %d", name, u, srv.Weight)
|
||||
|
||||
if err := lb.UpsertServer(u, roundrobin.Weight(srv.Weight)); err != nil {
|
||||
return fmt.Errorf("error adding server %s to load balancer: %v", srv.URL, err)
|
||||
}
|
||||
|
||||
s.metricsRegistry.BackendServerUpGauge().With("backend", backendName, "url", srv.URL).Set(1)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getRoundTripper will either use server.defaultForwardingRoundTripper or create a new one
|
||||
// given a custom TLS configuration is passed and the passTLSCert option is set to true.
|
||||
func (s *Server) getRoundTripper(entryPointName string, passTLSCert bool, tls *traefiktls.TLS) (http.RoundTripper, error) {
|
||||
if passTLSCert {
|
||||
tlsConfig, err := createClientTLSConfig(entryPointName, tls)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create TLSClientConfig: %v", err)
|
||||
}
|
||||
|
||||
transport, err := createHTTPTransport(s.globalConfiguration)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP transport: %v", err)
|
||||
}
|
||||
|
||||
transport.TLSClientConfig = tlsConfig
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
return s.defaultForwardingRoundTripper, nil
|
||||
}
|
||||
|
||||
// createHTTPTransport creates an http.Transport configured with the GlobalConfiguration settings.
|
||||
// For the settings that can't be configured in Traefik it uses the default http.Transport settings.
|
||||
// An exception to this is the MaxIdleConns setting as we only provide the option MaxIdleConnsPerHost
|
||||
// in Traefik at this point in time. Setting this value to the default of 100 could lead to confusing
|
||||
// behaviour and backwards compatibility issues.
|
||||
func createHTTPTransport(globalConfiguration configuration.GlobalConfiguration) (*http.Transport, error) {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: configuration.DefaultDialTimeout,
|
||||
KeepAlive: 30 * time.Second,
|
||||
DualStack: true,
|
||||
}
|
||||
|
||||
if globalConfiguration.ForwardingTimeouts != nil {
|
||||
dialer.Timeout = time.Duration(globalConfiguration.ForwardingTimeouts.DialTimeout)
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: dialer.DialContext,
|
||||
MaxIdleConnsPerHost: globalConfiguration.MaxIdleConnsPerHost,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
transport.RegisterProtocol("h2c", &h2cTransportWrapper{
|
||||
Transport: &http2.Transport{
|
||||
DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) {
|
||||
return net.Dial(netw, addr)
|
||||
},
|
||||
AllowHTTP: true,
|
||||
},
|
||||
})
|
||||
|
||||
if globalConfiguration.ForwardingTimeouts != nil {
|
||||
transport.ResponseHeaderTimeout = time.Duration(globalConfiguration.ForwardingTimeouts.ResponseHeaderTimeout)
|
||||
}
|
||||
|
||||
if globalConfiguration.InsecureSkipVerify {
|
||||
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
}
|
||||
|
||||
if len(globalConfiguration.RootCAs) > 0 {
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
RootCAs: createRootCACertPool(globalConfiguration.RootCAs),
|
||||
}
|
||||
}
|
||||
|
||||
err := http2.ConfigureTransport(transport)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
func createRootCACertPool(rootCAs traefiktls.RootCAs) *x509.CertPool {
|
||||
roots := x509.NewCertPool()
|
||||
|
||||
for _, cert := range rootCAs {
|
||||
certContent, err := cert.Read()
|
||||
if err != nil {
|
||||
log.Error("Error while read RootCAs", err)
|
||||
continue
|
||||
}
|
||||
roots.AppendCertsFromPEM(certContent)
|
||||
}
|
||||
|
||||
return roots
|
||||
}
|
||||
|
||||
func createClientTLSConfig(entryPointName string, tlsOption *traefiktls.TLS) (*tls.Config, error) {
|
||||
if tlsOption == nil {
|
||||
return nil, errors.New("no TLS provided")
|
||||
}
|
||||
|
||||
config, err := tlsOption.Certificates.CreateTLSConfig(entryPointName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(tlsOption.ClientCAFiles) > 0 {
|
||||
log.Warnf("Deprecated configuration found during client TLS configuration creation: %s. Please use %s (which allows to make the CA Files optional).", "tls.ClientCAFiles", "tls.ClientCA.files")
|
||||
tlsOption.ClientCA.Files = tlsOption.ClientCAFiles
|
||||
tlsOption.ClientCA.Optional = false
|
||||
}
|
||||
|
||||
if len(tlsOption.ClientCA.Files) > 0 {
|
||||
pool := x509.NewCertPool()
|
||||
for _, caFile := range tlsOption.ClientCA.Files {
|
||||
data, err := ioutil.ReadFile(caFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !pool.AppendCertsFromPEM(data) {
|
||||
return nil, fmt.Errorf("invalid certificate(s) in %s", caFile)
|
||||
}
|
||||
}
|
||||
config.RootCAs = pool
|
||||
}
|
||||
|
||||
config.BuildNameToCertificate()
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (s *Server) buildRetryMiddleware(handler http.Handler, retry *configuration.Retry, countServers int, backendName string) http.Handler {
|
||||
retryListeners := middlewares.RetryListeners{}
|
||||
if s.metricsRegistry.IsEnabled() {
|
||||
retryListeners = append(retryListeners, middlewares.NewMetricsRetryListener(s.metricsRegistry, backendName))
|
||||
}
|
||||
if s.accessLoggerMiddleware != nil {
|
||||
retryListeners = append(retryListeners, &accesslog.SaveRetries{})
|
||||
}
|
||||
|
||||
retryAttempts := countServers
|
||||
if retry.Attempts > 0 {
|
||||
retryAttempts = retry.Attempts
|
||||
}
|
||||
|
||||
log.Debugf("Creating retries max attempts %d", retryAttempts)
|
||||
|
||||
return middlewares.NewRetry(retryAttempts, handler, retryListeners)
|
||||
}
|
||||
|
||||
func buildRateLimiter(handler http.Handler, rlConfig *types.RateLimit) (http.Handler, error) {
|
||||
extractFunc, err := utils.NewExtractor(rlConfig.ExtractorFunc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debugf("Creating load-balancer rate limiter")
|
||||
|
||||
rateSet := ratelimit.NewRateSet()
|
||||
for _, rate := range rlConfig.RateSet {
|
||||
if err := rateSet.Add(time.Duration(rate.Period), rate.Average, rate.Burst); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return ratelimit.New(handler, extractFunc, rateSet)
|
||||
}
|
||||
|
||||
func buildBufferingMiddleware(handler http.Handler, config *types.Buffering) (http.Handler, error) {
|
||||
log.Debugf("Setting up buffering: request limits: %d (mem), %d (max), response limits: %d (mem), %d (max) with retry: '%s'",
|
||||
config.MemRequestBodyBytes, config.MaxRequestBodyBytes, config.MemResponseBodyBytes,
|
||||
config.MaxResponseBodyBytes, config.RetryExpression)
|
||||
|
||||
return buffer.New(
|
||||
handler,
|
||||
buffer.MemRequestBodyBytes(config.MemRequestBodyBytes),
|
||||
buffer.MaxRequestBodyBytes(config.MaxRequestBodyBytes),
|
||||
buffer.MemResponseBodyBytes(config.MemResponseBodyBytes),
|
||||
buffer.MaxResponseBodyBytes(config.MaxResponseBodyBytes),
|
||||
buffer.CondSetter(len(config.RetryExpression) > 0, buffer.Retry(config.RetryExpression)),
|
||||
)
|
||||
}
|
||||
|
||||
func buildMaxConn(lb http.Handler, maxConns *types.MaxConn) (http.Handler, error) {
|
||||
extractFunc, err := utils.NewExtractor(maxConns.ExtractorFunc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating connection limit: %v", err)
|
||||
}
|
||||
|
||||
log.Debugf("Creating load-balancer connection limit")
|
||||
|
||||
handler, err := connlimit.New(lb, extractFunc, maxConns.Amount)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating connection limit: %v", err)
|
||||
}
|
||||
|
||||
return handler, nil
|
||||
}
|
||||
|
||||
func buildHealthCheckOptions(lb healthcheck.BalancerHandler, backend string, hc *types.HealthCheck, hcConfig *configuration.HealthCheckConfig) *healthcheck.Options {
|
||||
if hc == nil || hc.Path == "" || hcConfig == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
interval := time.Duration(hcConfig.Interval)
|
||||
if hc.Interval != "" {
|
||||
intervalOverride, err := time.ParseDuration(hc.Interval)
|
||||
if err != nil {
|
||||
log.Errorf("Illegal health check interval for backend '%s': %s", backend, err)
|
||||
} else if intervalOverride <= 0 {
|
||||
log.Errorf("Health check interval smaller than zero for backend '%s', backend", backend)
|
||||
} else {
|
||||
interval = intervalOverride
|
||||
}
|
||||
}
|
||||
|
||||
return &healthcheck.Options{
|
||||
Scheme: hc.Scheme,
|
||||
Path: hc.Path,
|
||||
Port: hc.Port,
|
||||
Interval: interval,
|
||||
LB: lb,
|
||||
Hostname: hc.Hostname,
|
||||
Headers: hc.Headers,
|
||||
}
|
||||
}
|
81
server/server_loadbalancer_test.go
Normal file
81
server/server_loadbalancer_test.go
Normal file
|
@ -0,0 +1,81 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConfigureBackends(t *testing.T) {
|
||||
validMethod := "Drr"
|
||||
defaultMethod := "wrr"
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
lb *types.LoadBalancer
|
||||
expectedMethod string
|
||||
expectedStickiness *types.Stickiness
|
||||
}{
|
||||
{
|
||||
desc: "valid load balancer method with sticky enabled",
|
||||
lb: &types.LoadBalancer{
|
||||
Method: validMethod,
|
||||
Stickiness: &types.Stickiness{},
|
||||
},
|
||||
expectedMethod: validMethod,
|
||||
expectedStickiness: &types.Stickiness{},
|
||||
},
|
||||
{
|
||||
desc: "valid load balancer method with sticky disabled",
|
||||
lb: &types.LoadBalancer{
|
||||
Method: validMethod,
|
||||
Stickiness: nil,
|
||||
},
|
||||
expectedMethod: validMethod,
|
||||
},
|
||||
{
|
||||
desc: "invalid load balancer method with sticky enabled",
|
||||
lb: &types.LoadBalancer{
|
||||
Method: "Invalid",
|
||||
Stickiness: &types.Stickiness{},
|
||||
},
|
||||
expectedMethod: defaultMethod,
|
||||
expectedStickiness: &types.Stickiness{},
|
||||
},
|
||||
{
|
||||
desc: "invalid load balancer method with sticky disabled",
|
||||
lb: &types.LoadBalancer{
|
||||
Method: "Invalid",
|
||||
Stickiness: nil,
|
||||
},
|
||||
expectedMethod: defaultMethod,
|
||||
},
|
||||
{
|
||||
desc: "missing load balancer",
|
||||
lb: nil,
|
||||
expectedMethod: defaultMethod,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
backend := &types.Backend{
|
||||
LoadBalancer: test.lb,
|
||||
}
|
||||
|
||||
configureBackends(map[string]*types.Backend{
|
||||
"backend": backend,
|
||||
})
|
||||
|
||||
expected := types.LoadBalancer{
|
||||
Method: test.expectedMethod,
|
||||
Stickiness: test.expectedStickiness,
|
||||
}
|
||||
|
||||
assert.Equal(t, expected, *backend.LoadBalancer)
|
||||
})
|
||||
}
|
||||
}
|
316
server/server_middlewares.go
Normal file
316
server/server_middlewares.go
Normal file
|
@ -0,0 +1,316 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/configuration"
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/middlewares/accesslog"
|
||||
mauth "github.com/containous/traefik/middlewares/auth"
|
||||
"github.com/containous/traefik/middlewares/errorpages"
|
||||
"github.com/containous/traefik/middlewares/redirect"
|
||||
"github.com/containous/traefik/types"
|
||||
thoas_stats "github.com/thoas/stats"
|
||||
"github.com/unrolled/secure"
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
type handlerPostConfig func(backendsHandlers map[string]http.Handler) error
|
||||
|
||||
type modifyResponse func(*http.Response) error
|
||||
|
||||
func (s *Server) buildMiddlewares(frontendName string, frontend *types.Frontend,
|
||||
backends map[string]*types.Backend,
|
||||
entryPointName string, entryPoint *configuration.EntryPoint,
|
||||
providerName string) ([]negroni.Handler, modifyResponse, handlerPostConfig, error) {
|
||||
|
||||
var middle []negroni.Handler
|
||||
var postConfig handlerPostConfig
|
||||
|
||||
// Error pages
|
||||
if len(frontend.Errors) > 0 {
|
||||
handlers, err := buildErrorPagesMiddleware(frontendName, frontend, backends, entryPointName, providerName)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
postConfig = errorPagesPostConfig(handlers)
|
||||
|
||||
for _, handler := range handlers {
|
||||
middle = append(middle, handler)
|
||||
}
|
||||
}
|
||||
|
||||
// Metrics
|
||||
if s.metricsRegistry.IsEnabled() {
|
||||
handler := middlewares.NewBackendMetricsMiddleware(s.metricsRegistry, frontend.Backend)
|
||||
middle = append(middle, handler)
|
||||
}
|
||||
|
||||
// Whitelist
|
||||
ipWhitelistMiddleware, err := buildIPWhiteLister(frontend.WhiteList, frontend.WhitelistSourceRange)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("error creating IP Whitelister: %s", err)
|
||||
}
|
||||
if ipWhitelistMiddleware != nil {
|
||||
log.Debugf("Configured IP Whitelists: %v", frontend.WhiteList.SourceRange)
|
||||
|
||||
handler := s.tracingMiddleware.NewNegroniHandlerWrapper(
|
||||
"IP whitelist",
|
||||
s.wrapNegroniHandlerWithAccessLog(ipWhitelistMiddleware, fmt.Sprintf("ipwhitelister for %s", frontendName)),
|
||||
false)
|
||||
middle = append(middle, handler)
|
||||
}
|
||||
|
||||
// Redirect
|
||||
if frontend.Redirect != nil && entryPointName != frontend.Redirect.EntryPoint {
|
||||
rewrite, err := s.buildRedirectHandler(entryPointName, frontend.Redirect)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("error creating Frontend Redirect: %v", err)
|
||||
}
|
||||
|
||||
handler := s.wrapNegroniHandlerWithAccessLog(rewrite, fmt.Sprintf("frontend redirect for %s", frontendName))
|
||||
middle = append(middle, handler)
|
||||
|
||||
log.Debugf("Frontend %s redirect created", frontendName)
|
||||
}
|
||||
|
||||
// Header
|
||||
headerMiddleware := middlewares.NewHeaderFromStruct(frontend.Headers)
|
||||
if headerMiddleware != nil {
|
||||
log.Debugf("Adding header middleware for frontend %s", frontendName)
|
||||
|
||||
handler := s.tracingMiddleware.NewNegroniHandlerWrapper("Header", headerMiddleware, false)
|
||||
middle = append(middle, handler)
|
||||
}
|
||||
|
||||
// Secure
|
||||
secureMiddleware := middlewares.NewSecure(frontend.Headers)
|
||||
if secureMiddleware != nil {
|
||||
log.Debugf("Adding secure middleware for frontend %s", frontendName)
|
||||
|
||||
handler := negroni.HandlerFunc(secureMiddleware.HandlerFuncWithNextForRequestOnly)
|
||||
middle = append(middle, handler)
|
||||
}
|
||||
|
||||
// Basic auth
|
||||
if len(frontend.BasicAuth) > 0 {
|
||||
log.Debugf("Adding basic authentication for frontend %s", frontendName)
|
||||
|
||||
authMiddleware, err := s.buildBasicAuthMiddleware(frontend.BasicAuth)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
handler := s.wrapNegroniHandlerWithAccessLog(authMiddleware, fmt.Sprintf("Basic Auth for %s", frontendName))
|
||||
middle = append(middle, handler)
|
||||
}
|
||||
|
||||
return middle, buildModifyResponse(secureMiddleware, headerMiddleware), postConfig, nil
|
||||
}
|
||||
|
||||
func (s *Server) buildServerEntryPointMiddlewares(serverEntryPointName string, serverEntryPoint *serverEntryPoint) ([]negroni.Handler, error) {
|
||||
serverMiddlewares := []negroni.Handler{middlewares.NegroniRecoverHandler()}
|
||||
|
||||
if s.tracingMiddleware.IsEnabled() {
|
||||
serverMiddlewares = append(serverMiddlewares, s.tracingMiddleware.NewEntryPoint(serverEntryPointName))
|
||||
}
|
||||
|
||||
if s.accessLoggerMiddleware != nil {
|
||||
serverMiddlewares = append(serverMiddlewares, s.accessLoggerMiddleware)
|
||||
}
|
||||
|
||||
if s.metricsRegistry.IsEnabled() {
|
||||
serverMiddlewares = append(serverMiddlewares, middlewares.NewEntryPointMetricsMiddleware(s.metricsRegistry, serverEntryPointName))
|
||||
}
|
||||
|
||||
if s.globalConfiguration.API != nil {
|
||||
if s.globalConfiguration.API.Stats == nil {
|
||||
s.globalConfiguration.API.Stats = thoas_stats.New()
|
||||
}
|
||||
serverMiddlewares = append(serverMiddlewares, s.globalConfiguration.API.Stats)
|
||||
if s.globalConfiguration.API.Statistics != nil {
|
||||
if s.globalConfiguration.API.StatsRecorder == nil {
|
||||
s.globalConfiguration.API.StatsRecorder = middlewares.NewStatsRecorder(s.globalConfiguration.API.Statistics.RecentErrors)
|
||||
}
|
||||
serverMiddlewares = append(serverMiddlewares, s.globalConfiguration.API.StatsRecorder)
|
||||
}
|
||||
}
|
||||
|
||||
if s.entryPoints[serverEntryPointName].Configuration.Auth != nil {
|
||||
authMiddleware, err := mauth.NewAuthenticator(s.entryPoints[serverEntryPointName].Configuration.Auth, s.tracingMiddleware)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create authentication middleware: %v", err)
|
||||
}
|
||||
serverMiddlewares = append(serverMiddlewares, s.wrapNegroniHandlerWithAccessLog(authMiddleware, fmt.Sprintf("Auth for entrypoint %s", serverEntryPointName)))
|
||||
}
|
||||
|
||||
if s.entryPoints[serverEntryPointName].Configuration.Compress {
|
||||
serverMiddlewares = append(serverMiddlewares, &middlewares.Compress{})
|
||||
}
|
||||
|
||||
ipWhitelistMiddleware, err := buildIPWhiteLister(
|
||||
s.entryPoints[serverEntryPointName].Configuration.WhiteList,
|
||||
s.entryPoints[serverEntryPointName].Configuration.WhitelistSourceRange)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create ip whitelist middleware: %v", err)
|
||||
}
|
||||
if ipWhitelistMiddleware != nil {
|
||||
serverMiddlewares = append(serverMiddlewares, s.wrapNegroniHandlerWithAccessLog(ipWhitelistMiddleware, fmt.Sprintf("ipwhitelister for entrypoint %s", serverEntryPointName)))
|
||||
}
|
||||
|
||||
return serverMiddlewares, nil
|
||||
}
|
||||
|
||||
func errorPagesPostConfig(epHandlers []*errorpages.Handler) handlerPostConfig {
|
||||
return func(backendsHandlers map[string]http.Handler) error {
|
||||
for _, errorPageHandler := range epHandlers {
|
||||
if handler, ok := backendsHandlers[errorPageHandler.BackendName]; ok {
|
||||
err := errorPageHandler.PostLoad(handler)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to configure error pages for backend %s: %v", errorPageHandler.BackendName, err)
|
||||
}
|
||||
} else {
|
||||
err := errorPageHandler.PostLoad(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to configure error pages for %s: %v", errorPageHandler.FallbackURL, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func buildErrorPagesMiddleware(frontendName string, frontend *types.Frontend, backends map[string]*types.Backend, entryPointName string, providerName string) ([]*errorpages.Handler, error) {
|
||||
var errorPageHandlers []*errorpages.Handler
|
||||
|
||||
for errorPageName, errorPage := range frontend.Errors {
|
||||
if frontend.Backend == errorPage.Backend {
|
||||
log.Errorf("Error when creating error page %q for frontend %q: error pages backend %q is the same as backend for the frontend (infinite call risk).",
|
||||
errorPageName, frontendName, errorPage.Backend)
|
||||
} else if backends[errorPage.Backend] == nil {
|
||||
log.Errorf("Error when creating error page %q for frontend %q: the backend %q doesn't exist.",
|
||||
errorPageName, frontendName, errorPage.Backend)
|
||||
} else {
|
||||
errorPagesHandler, err := errorpages.NewHandler(errorPage, entryPointName+providerName+errorPage.Backend)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating error pages: %v", err)
|
||||
}
|
||||
|
||||
if errorPageServer, ok := backends[errorPage.Backend].Servers["error"]; ok {
|
||||
errorPagesHandler.FallbackURL = errorPageServer.URL
|
||||
}
|
||||
|
||||
errorPageHandlers = append(errorPageHandlers, errorPagesHandler)
|
||||
}
|
||||
}
|
||||
|
||||
return errorPageHandlers, nil
|
||||
}
|
||||
|
||||
func (s *Server) buildBasicAuthMiddleware(authData []string) (*mauth.Authenticator, error) {
|
||||
users := types.Users{}
|
||||
for _, user := range authData {
|
||||
users = append(users, user)
|
||||
}
|
||||
|
||||
auth := &types.Auth{}
|
||||
auth.Basic = &types.Basic{
|
||||
Users: users,
|
||||
}
|
||||
|
||||
authMiddleware, err := mauth.NewAuthenticator(auth, s.tracingMiddleware)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating Basic Auth: %v", err)
|
||||
}
|
||||
|
||||
return authMiddleware, nil
|
||||
}
|
||||
|
||||
func (s *Server) buildEntryPointRedirect() (map[string]negroni.Handler, error) {
|
||||
redirectHandlers := map[string]negroni.Handler{}
|
||||
|
||||
for entryPointName, ep := range s.entryPoints {
|
||||
entryPoint := ep.Configuration
|
||||
|
||||
if entryPoint.Redirect != nil && entryPointName != entryPoint.Redirect.EntryPoint {
|
||||
handler, err := s.buildRedirectHandler(entryPointName, entryPoint.Redirect)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error loading configuration for entrypoint %s: %v", entryPointName, err)
|
||||
}
|
||||
|
||||
handlerToUse := s.wrapNegroniHandlerWithAccessLog(handler, fmt.Sprintf("entrypoint redirect for %s", entryPointName))
|
||||
redirectHandlers[entryPointName] = handlerToUse
|
||||
}
|
||||
}
|
||||
|
||||
return redirectHandlers, nil
|
||||
}
|
||||
|
||||
func (s *Server) buildRedirectHandler(srcEntryPointName string, opt *types.Redirect) (negroni.Handler, error) {
|
||||
// entry point redirect
|
||||
if len(opt.EntryPoint) > 0 {
|
||||
entryPoint := s.entryPoints[opt.EntryPoint].Configuration
|
||||
if entryPoint == nil {
|
||||
return nil, fmt.Errorf("unknown target entrypoint %q", srcEntryPointName)
|
||||
}
|
||||
log.Debugf("Creating entry point redirect %s -> %s", srcEntryPointName, opt.EntryPoint)
|
||||
return redirect.NewEntryPointHandler(entryPoint, opt.Permanent)
|
||||
}
|
||||
|
||||
// regex redirect
|
||||
redirection, err := redirect.NewRegexHandler(opt.Regex, opt.Replacement, opt.Permanent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Debugf("Creating regex redirect %s -> %s -> %s", srcEntryPointName, opt.Regex, opt.Replacement)
|
||||
|
||||
return redirection, nil
|
||||
}
|
||||
|
||||
func buildIPWhiteLister(whiteList *types.WhiteList, wlRange []string) (*middlewares.IPWhiteLister, error) {
|
||||
if whiteList != nil &&
|
||||
len(whiteList.SourceRange) > 0 {
|
||||
return middlewares.NewIPWhiteLister(whiteList.SourceRange, whiteList.UseXForwardedFor)
|
||||
} else if len(wlRange) > 0 {
|
||||
return middlewares.NewIPWhiteLister(wlRange, false)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *Server) wrapNegroniHandlerWithAccessLog(handler negroni.Handler, frontendName string) negroni.Handler {
|
||||
if s.accessLoggerMiddleware != nil {
|
||||
saveBackend := accesslog.NewSaveNegroniBackend(handler, "Træfik")
|
||||
saveFrontend := accesslog.NewSaveNegroniFrontend(saveBackend, frontendName)
|
||||
return saveFrontend
|
||||
}
|
||||
return handler
|
||||
}
|
||||
|
||||
func (s *Server) wrapHTTPHandlerWithAccessLog(handler http.Handler, frontendName string) http.Handler {
|
||||
if s.accessLoggerMiddleware != nil {
|
||||
saveBackend := accesslog.NewSaveBackend(handler, "Træfik")
|
||||
saveFrontend := accesslog.NewSaveFrontend(saveBackend, frontendName)
|
||||
return saveFrontend
|
||||
}
|
||||
return handler
|
||||
}
|
||||
|
||||
func buildModifyResponse(secure *secure.Secure, header *middlewares.HeaderStruct) func(res *http.Response) error {
|
||||
return func(res *http.Response) error {
|
||||
if secure != nil {
|
||||
if err := secure.ModifyResponseHeaders(res); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if header != nil {
|
||||
if err := header.ModifyResponseHeaders(res); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
253
server/server_middlewares_test.go
Normal file
253
server/server_middlewares_test.go
Normal file
|
@ -0,0 +1,253 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/mux"
|
||||
"github.com/containous/traefik/configuration"
|
||||
"github.com/containous/traefik/metrics"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
th "github.com/containous/traefik/testhelpers"
|
||||
"github.com/containous/traefik/tls"
|
||||
"github.com/containous/traefik/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
func TestServerEntryPointWhitelistConfig(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
entrypoint *configuration.EntryPoint
|
||||
expectMiddleware bool
|
||||
}{
|
||||
{
|
||||
desc: "no whitelist middleware if no config on entrypoint",
|
||||
entrypoint: &configuration.EntryPoint{
|
||||
Address: ":0",
|
||||
ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true},
|
||||
},
|
||||
expectMiddleware: false,
|
||||
},
|
||||
{
|
||||
desc: "whitelist middleware should be added if configured on entrypoint",
|
||||
entrypoint: &configuration.EntryPoint{
|
||||
Address: ":0",
|
||||
WhitelistSourceRange: []string{
|
||||
"127.0.0.1/32",
|
||||
},
|
||||
ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true},
|
||||
},
|
||||
expectMiddleware: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := Server{
|
||||
globalConfiguration: configuration.GlobalConfiguration{},
|
||||
metricsRegistry: metrics.NewVoidRegistry(),
|
||||
entryPoints: map[string]EntryPoint{
|
||||
"test": {
|
||||
Configuration: test.entrypoint,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
srv.serverEntryPoints = srv.buildServerEntryPoints()
|
||||
srvEntryPoint := srv.setupServerEntryPoint("test", srv.serverEntryPoints["test"])
|
||||
handler := srvEntryPoint.httpServer.Handler.(*mux.Router).NotFoundHandler.(*negroni.Negroni)
|
||||
|
||||
found := false
|
||||
for _, handler := range handler.Handlers() {
|
||||
if reflect.TypeOf(handler) == reflect.TypeOf((*middlewares.IPWhiteLister)(nil)) {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
|
||||
if found && !test.expectMiddleware {
|
||||
t.Error("ip whitelist middleware was installed even though it should not")
|
||||
}
|
||||
|
||||
if !found && test.expectMiddleware {
|
||||
t.Error("ip whitelist middleware was not installed even though it should have")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildIPWhiteLister(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
whitelistSourceRange []string
|
||||
whiteList *types.WhiteList
|
||||
middlewareConfigured bool
|
||||
errMessage string
|
||||
}{
|
||||
{
|
||||
desc: "no whitelists configured",
|
||||
whitelistSourceRange: nil,
|
||||
middlewareConfigured: false,
|
||||
errMessage: "",
|
||||
},
|
||||
{
|
||||
desc: "whitelists configured (deprecated)",
|
||||
whitelistSourceRange: []string{
|
||||
"1.2.3.4/24",
|
||||
"fe80::/16",
|
||||
},
|
||||
middlewareConfigured: true,
|
||||
errMessage: "",
|
||||
},
|
||||
{
|
||||
desc: "invalid whitelists configured (deprecated)",
|
||||
whitelistSourceRange: []string{
|
||||
"foo",
|
||||
},
|
||||
middlewareConfigured: false,
|
||||
errMessage: "parsing CIDR whitelist [foo]: parsing CIDR white list <nil>: invalid CIDR address: foo",
|
||||
},
|
||||
{
|
||||
desc: "whitelists configured",
|
||||
whiteList: &types.WhiteList{
|
||||
SourceRange: []string{
|
||||
"1.2.3.4/24",
|
||||
"fe80::/16",
|
||||
},
|
||||
UseXForwardedFor: false,
|
||||
},
|
||||
middlewareConfigured: true,
|
||||
errMessage: "",
|
||||
},
|
||||
{
|
||||
desc: "invalid whitelists configured (deprecated)",
|
||||
whiteList: &types.WhiteList{
|
||||
SourceRange: []string{
|
||||
"foo",
|
||||
},
|
||||
UseXForwardedFor: false,
|
||||
},
|
||||
middlewareConfigured: false,
|
||||
errMessage: "parsing CIDR whitelist [foo]: parsing CIDR white list <nil>: invalid CIDR address: foo",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
middleware, err := buildIPWhiteLister(test.whiteList, test.whitelistSourceRange)
|
||||
|
||||
if test.errMessage != "" {
|
||||
require.EqualError(t, err, test.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
|
||||
if test.middlewareConfigured {
|
||||
require.NotNil(t, middleware, "not expected middleware to be configured")
|
||||
} else {
|
||||
require.Nil(t, middleware, "expected middleware to be configured")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRedirectHandler(t *testing.T) {
|
||||
srv := Server{
|
||||
globalConfiguration: configuration.GlobalConfiguration{},
|
||||
entryPoints: map[string]EntryPoint{
|
||||
"http": {Configuration: &configuration.EntryPoint{Address: ":80"}},
|
||||
"https": {Configuration: &configuration.EntryPoint{Address: ":443", TLS: &tls.TLS{}}},
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
srcEntryPointName string
|
||||
url string
|
||||
entryPoint *configuration.EntryPoint
|
||||
redirect *types.Redirect
|
||||
expectedURL string
|
||||
}{
|
||||
{
|
||||
desc: "redirect regex",
|
||||
srcEntryPointName: "http",
|
||||
url: "http://foo.com",
|
||||
redirect: &types.Redirect{
|
||||
Regex: `^(?:http?:\/\/)(foo)(\.com)$`,
|
||||
Replacement: "https://$1{{\"bar\"}}$2",
|
||||
},
|
||||
entryPoint: &configuration.EntryPoint{
|
||||
Address: ":80",
|
||||
Redirect: &types.Redirect{
|
||||
Regex: `^(?:http?:\/\/)(foo)(\.com)$`,
|
||||
Replacement: "https://$1{{\"bar\"}}$2",
|
||||
},
|
||||
},
|
||||
expectedURL: "https://foobar.com",
|
||||
},
|
||||
{
|
||||
desc: "redirect entry point",
|
||||
srcEntryPointName: "http",
|
||||
url: "http://foo:80",
|
||||
redirect: &types.Redirect{
|
||||
EntryPoint: "https",
|
||||
},
|
||||
entryPoint: &configuration.EntryPoint{
|
||||
Address: ":80",
|
||||
Redirect: &types.Redirect{
|
||||
EntryPoint: "https",
|
||||
},
|
||||
},
|
||||
expectedURL: "https://foo:443",
|
||||
},
|
||||
{
|
||||
desc: "redirect entry point with regex (ignored)",
|
||||
srcEntryPointName: "http",
|
||||
url: "http://foo.com:80",
|
||||
redirect: &types.Redirect{
|
||||
EntryPoint: "https",
|
||||
Regex: `^(?:http?:\/\/)(foo)(\.com)$`,
|
||||
Replacement: "https://$1{{\"bar\"}}$2",
|
||||
},
|
||||
entryPoint: &configuration.EntryPoint{
|
||||
Address: ":80",
|
||||
Redirect: &types.Redirect{
|
||||
EntryPoint: "https",
|
||||
Regex: `^(?:http?:\/\/)(foo)(\.com)$`,
|
||||
Replacement: "https://$1{{\"bar\"}}$2",
|
||||
},
|
||||
},
|
||||
expectedURL: "https://foo.com:443",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rewrite, err := srv.buildRedirectHandler(test.srcEntryPointName, test.redirect)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := th.MustNewRequest(http.MethodGet, test.url, nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
rewrite.ServeHTTP(recorder, req, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Add("Location", "fail")
|
||||
}))
|
||||
|
||||
location, err := recorder.Result().Location()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, test.expectedURL, location.String())
|
||||
})
|
||||
}
|
||||
}
|
|
@ -22,12 +22,12 @@ func (s *Server) listenSignals() {
|
|||
|
||||
if s.accessLoggerMiddleware != nil {
|
||||
if err := s.accessLoggerMiddleware.Rotate(); err != nil {
|
||||
log.Errorf("Error rotating access log: %s", err)
|
||||
log.Errorf("Error rotating access log: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := log.RotateFile(); err != nil {
|
||||
log.Errorf("Error rotating traefik log: %s", err)
|
||||
log.Errorf("Error rotating traefik log: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,83 +2,22 @@ package server
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/containous/flaeg"
|
||||
"github.com/containous/mux"
|
||||
"github.com/containous/traefik/configuration"
|
||||
"github.com/containous/traefik/healthcheck"
|
||||
"github.com/containous/traefik/metrics"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/rules"
|
||||
th "github.com/containous/traefik/testhelpers"
|
||||
"github.com/containous/traefik/tls"
|
||||
"github.com/containous/traefik/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/unrolled/secure"
|
||||
"github.com/urfave/negroni"
|
||||
"github.com/vulcand/oxy/roundrobin"
|
||||
)
|
||||
|
||||
// LocalhostCert is a PEM-encoded TLS cert with SAN IPs
|
||||
// "127.0.0.1" and "[::1]", expiring at Jan 29 16:00:00 2084 GMT.
|
||||
// generated from src/crypto/tls:
|
||||
// go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
|
||||
var (
|
||||
localhostCert = tls.FileOrContent(`-----BEGIN CERTIFICATE-----
|
||||
MIICEzCCAXygAwIBAgIQMIMChMLGrR+QvmQvpwAU6zANBgkqhkiG9w0BAQsFADAS
|
||||
MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw
|
||||
MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB
|
||||
iQKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9SjY1bIw4
|
||||
iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZBl2+XsDul
|
||||
rKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQABo2gwZjAO
|
||||
BgNVHQ8BAf8EBAMCAqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUw
|
||||
AwEB/zAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAAAAAAAAAA
|
||||
AAAAATANBgkqhkiG9w0BAQsFAAOBgQCEcetwO59EWk7WiJsG4x8SY+UIAA+flUI9
|
||||
tyC4lNhbcF2Idq9greZwbYCqTTTr2XiRNSMLCOjKyI7ukPoPjo16ocHj+P3vZGfs
|
||||
h1fIw3cSS2OolhloGw/XM6RWPWtPAlGykKLciQrBru5NAPvCMsb/I1DAceTiotQM
|
||||
fblo6RBxUQ==
|
||||
-----END CERTIFICATE-----`)
|
||||
|
||||
// LocalhostKey is the private key for localhostCert.
|
||||
localhostKey = tls.FileOrContent(`-----BEGIN RSA PRIVATE KEY-----
|
||||
MIICXgIBAAKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9
|
||||
SjY1bIw4iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZB
|
||||
l2+XsDulrKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQAB
|
||||
AoGAGRzwwir7XvBOAy5tM/uV6e+Zf6anZzus1s1Y1ClbjbE6HXbnWWF/wbZGOpet
|
||||
3Zm4vD6MXc7jpTLryzTQIvVdfQbRc6+MUVeLKwZatTXtdZrhu+Jk7hx0nTPy8Jcb
|
||||
uJqFk541aEw+mMogY/xEcfbWd6IOkp+4xqjlFLBEDytgbIECQQDvH/E6nk+hgN4H
|
||||
qzzVtxxr397vWrjrIgPbJpQvBsafG7b0dA4AFjwVbFLmQcj2PprIMmPcQrooz8vp
|
||||
jy4SHEg1AkEA/v13/5M47K9vCxmb8QeD/asydfsgS5TeuNi8DoUBEmiSJwma7FXY
|
||||
fFUtxuvL7XvjwjN5B30pNEbc6Iuyt7y4MQJBAIt21su4b3sjXNueLKH85Q+phy2U
|
||||
fQtuUE9txblTu14q3N7gHRZB4ZMhFYyDy8CKrN2cPg/Fvyt0Xlp/DoCzjA0CQQDU
|
||||
y2ptGsuSmgUtWj3NM9xuwYPm+Z/F84K6+ARYiZ6PYj013sovGKUFfYAqVXVlxtIX
|
||||
qyUBnu3X9ps8ZfjLZO7BAkEAlT4R5Yl6cGhaJQYZHOde3JEMhNRcVFMO8dJDaFeo
|
||||
f9Oeos0UUothgiDktdQHxdNEwLjQf7lJJBzV+5OtwswCWA==
|
||||
-----END RSA PRIVATE KEY-----`)
|
||||
)
|
||||
|
||||
type testLoadBalancer struct{}
|
||||
|
||||
func (lb *testLoadBalancer) RemoveServer(u *url.URL) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (lb *testLoadBalancer) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (lb *testLoadBalancer) Servers() []*url.URL {
|
||||
return []*url.URL{}
|
||||
}
|
||||
|
||||
func TestPrepareServerTimeouts(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
|
@ -282,559 +221,6 @@ func setupListenProvider(throttleDuration time.Duration) (server *Server, stop c
|
|||
return server, stop, invokeStopChan
|
||||
}
|
||||
|
||||
func TestThrottleProviderConfigReload(t *testing.T) {
|
||||
throttleDuration := 30 * time.Millisecond
|
||||
publishConfig := make(chan types.ConfigMessage)
|
||||
providerConfig := make(chan types.ConfigMessage)
|
||||
stop := make(chan bool)
|
||||
defer func() {
|
||||
stop <- true
|
||||
}()
|
||||
|
||||
globalConfig := configuration.GlobalConfiguration{}
|
||||
server := NewServer(globalConfig, nil, nil)
|
||||
|
||||
go server.throttleProviderConfigReload(throttleDuration, publishConfig, providerConfig, stop)
|
||||
|
||||
publishedConfigCount := 0
|
||||
stopConsumeConfigs := make(chan bool)
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case <-stopConsumeConfigs:
|
||||
return
|
||||
case <-publishConfig:
|
||||
publishedConfigCount++
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// publish 5 new configs, one new config each 10 milliseconds
|
||||
for i := 0; i < 5; i++ {
|
||||
providerConfig <- types.ConfigMessage{}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
// after 50 milliseconds 5 new configs were published
|
||||
// with a throttle duration of 30 milliseconds this means, we should have received 2 new configs
|
||||
assert.Equal(t, 2, publishedConfigCount, "times configs were published")
|
||||
|
||||
stopConsumeConfigs <- true
|
||||
|
||||
select {
|
||||
case <-publishConfig:
|
||||
// There should be exactly one more message that we receive after ~60 milliseconds since the start of the test.
|
||||
select {
|
||||
case <-publishConfig:
|
||||
t.Error("extra config publication found")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
return
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Last config was not published in time")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerMultipleFrontendRules(t *testing.T) {
|
||||
testCases := []struct {
|
||||
expression string
|
||||
requestURL string
|
||||
expectedURL string
|
||||
}{
|
||||
{
|
||||
expression: "Host:foo.bar",
|
||||
requestURL: "http://foo.bar",
|
||||
expectedURL: "http://foo.bar",
|
||||
},
|
||||
{
|
||||
expression: "PathPrefix:/management;ReplacePath:/health",
|
||||
requestURL: "http://foo.bar/management",
|
||||
expectedURL: "http://foo.bar/health",
|
||||
},
|
||||
{
|
||||
expression: "Host:foo.bar;AddPrefix:/blah",
|
||||
requestURL: "http://foo.bar/baz",
|
||||
expectedURL: "http://foo.bar/blah/baz",
|
||||
},
|
||||
{
|
||||
expression: "PathPrefixStripRegex:/one/{two}/{three:[0-9]+}",
|
||||
requestURL: "http://foo.bar/one/some/12345/four",
|
||||
expectedURL: "http://foo.bar/four",
|
||||
},
|
||||
{
|
||||
expression: "PathPrefixStripRegex:/one/{two}/{three:[0-9]+};AddPrefix:/zero",
|
||||
requestURL: "http://foo.bar/one/some/12345/four",
|
||||
expectedURL: "http://foo.bar/zero/four",
|
||||
},
|
||||
{
|
||||
expression: "AddPrefix:/blah;ReplacePath:/baz",
|
||||
requestURL: "http://foo.bar/hello",
|
||||
expectedURL: "http://foo.bar/baz",
|
||||
},
|
||||
{
|
||||
expression: "PathPrefixStrip:/management;ReplacePath:/health",
|
||||
requestURL: "http://foo.bar/management",
|
||||
expectedURL: "http://foo.bar/health",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.expression, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
router := mux.NewRouter()
|
||||
route := router.NewRoute()
|
||||
serverRoute := &types.ServerRoute{Route: route}
|
||||
rules := &rules.Rules{Route: serverRoute}
|
||||
|
||||
expression := test.expression
|
||||
routeResult, err := rules.Parse(expression)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Error while building route for %s: %+v", expression, err)
|
||||
}
|
||||
|
||||
request := th.MustNewRequest(http.MethodGet, test.requestURL, nil)
|
||||
routeMatch := routeResult.Match(request, &mux.RouteMatch{Route: routeResult})
|
||||
|
||||
if !routeMatch {
|
||||
t.Fatalf("Rule %s doesn't match", expression)
|
||||
}
|
||||
|
||||
server := new(Server)
|
||||
|
||||
server.wireFrontendBackend(serverRoute, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, test.expectedURL, r.URL.String(), "URL")
|
||||
}))
|
||||
serverRoute.Route.GetHandler().ServeHTTP(nil, request)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerLoadConfigHealthCheckOptions(t *testing.T) {
|
||||
healthChecks := []*types.HealthCheck{
|
||||
nil,
|
||||
{
|
||||
Path: "/path",
|
||||
},
|
||||
}
|
||||
|
||||
for _, lbMethod := range []string{"Wrr", "Drr"} {
|
||||
for _, healthCheck := range healthChecks {
|
||||
t.Run(fmt.Sprintf("%s/hc=%t", lbMethod, healthCheck != nil), func(t *testing.T) {
|
||||
globalConfig := configuration.GlobalConfiguration{
|
||||
HealthCheck: &configuration.HealthCheckConfig{Interval: flaeg.Duration(5 * time.Second)},
|
||||
}
|
||||
entryPoints := map[string]EntryPoint{
|
||||
"http": {
|
||||
Configuration: &configuration.EntryPoint{
|
||||
ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
dynamicConfigs := types.Configurations{
|
||||
"config": &types.Configuration{
|
||||
Frontends: map[string]*types.Frontend{
|
||||
"frontend": {
|
||||
EntryPoints: []string{"http"},
|
||||
Backend: "backend",
|
||||
},
|
||||
},
|
||||
Backends: map[string]*types.Backend{
|
||||
"backend": {
|
||||
Servers: map[string]types.Server{
|
||||
"server": {
|
||||
URL: "http://localhost",
|
||||
},
|
||||
},
|
||||
LoadBalancer: &types.LoadBalancer{
|
||||
Method: lbMethod,
|
||||
},
|
||||
HealthCheck: healthCheck,
|
||||
},
|
||||
},
|
||||
TLS: []*tls.Configuration{
|
||||
{
|
||||
Certificate: &tls.Certificate{
|
||||
CertFile: localhostCert,
|
||||
KeyFile: localhostKey,
|
||||
},
|
||||
EntryPoints: []string{"http"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
srv := NewServer(globalConfig, nil, entryPoints)
|
||||
|
||||
_, err := srv.loadConfig(dynamicConfigs, globalConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedNumHealthCheckBackends := 0
|
||||
if healthCheck != nil {
|
||||
expectedNumHealthCheckBackends = 1
|
||||
}
|
||||
assert.Len(t, healthcheck.GetHealthCheck(th.NewCollectingHealthCheckMetrics()).Backends, expectedNumHealthCheckBackends, "health check backends")
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerParseHealthCheckOptions(t *testing.T) {
|
||||
lb := &testLoadBalancer{}
|
||||
globalInterval := 15 * time.Second
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
hc *types.HealthCheck
|
||||
expectedOpts *healthcheck.Options
|
||||
}{
|
||||
{
|
||||
desc: "nil health check",
|
||||
hc: nil,
|
||||
expectedOpts: nil,
|
||||
},
|
||||
{
|
||||
desc: "empty path",
|
||||
hc: &types.HealthCheck{
|
||||
Path: "",
|
||||
},
|
||||
expectedOpts: nil,
|
||||
},
|
||||
{
|
||||
desc: "unparseable interval",
|
||||
hc: &types.HealthCheck{
|
||||
Path: "/path",
|
||||
Interval: "unparseable",
|
||||
},
|
||||
expectedOpts: &healthcheck.Options{
|
||||
Path: "/path",
|
||||
Interval: globalInterval,
|
||||
LB: lb,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "sub-zero interval",
|
||||
hc: &types.HealthCheck{
|
||||
Path: "/path",
|
||||
Interval: "-42s",
|
||||
},
|
||||
expectedOpts: &healthcheck.Options{
|
||||
Path: "/path",
|
||||
Interval: globalInterval,
|
||||
LB: lb,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "parseable interval",
|
||||
hc: &types.HealthCheck{
|
||||
Path: "/path",
|
||||
Interval: "5m",
|
||||
},
|
||||
expectedOpts: &healthcheck.Options{
|
||||
Path: "/path",
|
||||
Interval: 5 * time.Minute,
|
||||
LB: lb,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
opts := parseHealthCheckOptions(lb, "backend", test.hc, &configuration.HealthCheckConfig{Interval: flaeg.Duration(globalInterval)})
|
||||
assert.Equal(t, test.expectedOpts, opts, "health check options")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildIPWhiteLister(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
whitelistSourceRange []string
|
||||
whiteList *types.WhiteList
|
||||
middlewareConfigured bool
|
||||
errMessage string
|
||||
}{
|
||||
{
|
||||
desc: "no whitelists configured",
|
||||
whitelistSourceRange: nil,
|
||||
middlewareConfigured: false,
|
||||
errMessage: "",
|
||||
},
|
||||
{
|
||||
desc: "whitelists configured (deprecated)",
|
||||
whitelistSourceRange: []string{
|
||||
"1.2.3.4/24",
|
||||
"fe80::/16",
|
||||
},
|
||||
middlewareConfigured: true,
|
||||
errMessage: "",
|
||||
},
|
||||
{
|
||||
desc: "invalid whitelists configured (deprecated)",
|
||||
whitelistSourceRange: []string{
|
||||
"foo",
|
||||
},
|
||||
middlewareConfigured: false,
|
||||
errMessage: "parsing CIDR whitelist [foo]: parsing CIDR white list <nil>: invalid CIDR address: foo",
|
||||
},
|
||||
{
|
||||
desc: "whitelists configured",
|
||||
whiteList: &types.WhiteList{
|
||||
SourceRange: []string{
|
||||
"1.2.3.4/24",
|
||||
"fe80::/16",
|
||||
},
|
||||
UseXForwardedFor: false,
|
||||
},
|
||||
middlewareConfigured: true,
|
||||
errMessage: "",
|
||||
},
|
||||
{
|
||||
desc: "invalid whitelists configured (deprecated)",
|
||||
whiteList: &types.WhiteList{
|
||||
SourceRange: []string{
|
||||
"foo",
|
||||
},
|
||||
UseXForwardedFor: false,
|
||||
},
|
||||
middlewareConfigured: false,
|
||||
errMessage: "parsing CIDR whitelist [foo]: parsing CIDR white list <nil>: invalid CIDR address: foo",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
middleware, err := buildIPWhiteLister(test.whiteList, test.whitelistSourceRange)
|
||||
|
||||
if test.errMessage != "" {
|
||||
require.EqualError(t, err, test.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
|
||||
if test.middlewareConfigured {
|
||||
require.NotNil(t, middleware, "not expected middleware to be configured")
|
||||
} else {
|
||||
require.Nil(t, middleware, "expected middleware to be configured")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerLoadConfigEmptyBasicAuth(t *testing.T) {
|
||||
globalConfig := configuration.GlobalConfiguration{
|
||||
EntryPoints: configuration.EntryPoints{
|
||||
"http": &configuration.EntryPoint{ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true}},
|
||||
},
|
||||
}
|
||||
|
||||
dynamicConfigs := types.Configurations{
|
||||
"config": &types.Configuration{
|
||||
Frontends: map[string]*types.Frontend{
|
||||
"frontend": {
|
||||
EntryPoints: []string{"http"},
|
||||
Backend: "backend",
|
||||
BasicAuth: []string{""},
|
||||
},
|
||||
},
|
||||
Backends: map[string]*types.Backend{
|
||||
"backend": {
|
||||
Servers: map[string]types.Server{
|
||||
"server": {
|
||||
URL: "http://localhost",
|
||||
},
|
||||
},
|
||||
LoadBalancer: &types.LoadBalancer{
|
||||
Method: "Wrr",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
srv := NewServer(globalConfig, nil, nil)
|
||||
_, err := srv.loadConfig(dynamicConfigs, globalConfig)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestServerLoadCertificateWithDefaultEntryPoint(t *testing.T) {
|
||||
globalConfig := configuration.GlobalConfiguration{
|
||||
DefaultEntryPoints: []string{"http", "https"},
|
||||
}
|
||||
entryPoints := map[string]EntryPoint{
|
||||
"https": {Configuration: &configuration.EntryPoint{TLS: &tls.TLS{}}},
|
||||
"http": {Configuration: &configuration.EntryPoint{}},
|
||||
}
|
||||
|
||||
dynamicConfigs := types.Configurations{
|
||||
"config": &types.Configuration{
|
||||
TLS: []*tls.Configuration{
|
||||
{
|
||||
Certificate: &tls.Certificate{
|
||||
CertFile: localhostCert,
|
||||
KeyFile: localhostKey,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
srv := NewServer(globalConfig, nil, entryPoints)
|
||||
if mapEntryPoints, err := srv.loadConfig(dynamicConfigs, globalConfig); err != nil {
|
||||
t.Fatalf("got error: %s", err)
|
||||
} else if mapEntryPoints["https"].certs.Get() == nil {
|
||||
t.Fatal("got error: https entryPoint must have TLS certificates.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureBackends(t *testing.T) {
|
||||
validMethod := "Drr"
|
||||
defaultMethod := "wrr"
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
lb *types.LoadBalancer
|
||||
expectedMethod string
|
||||
expectedStickiness *types.Stickiness
|
||||
}{
|
||||
{
|
||||
desc: "valid load balancer method with sticky enabled",
|
||||
lb: &types.LoadBalancer{
|
||||
Method: validMethod,
|
||||
Stickiness: &types.Stickiness{},
|
||||
},
|
||||
expectedMethod: validMethod,
|
||||
expectedStickiness: &types.Stickiness{},
|
||||
},
|
||||
{
|
||||
desc: "valid load balancer method with sticky disabled",
|
||||
lb: &types.LoadBalancer{
|
||||
Method: validMethod,
|
||||
Stickiness: nil,
|
||||
},
|
||||
expectedMethod: validMethod,
|
||||
},
|
||||
{
|
||||
desc: "invalid load balancer method with sticky enabled",
|
||||
lb: &types.LoadBalancer{
|
||||
Method: "Invalid",
|
||||
Stickiness: &types.Stickiness{},
|
||||
},
|
||||
expectedMethod: defaultMethod,
|
||||
expectedStickiness: &types.Stickiness{},
|
||||
},
|
||||
{
|
||||
desc: "invalid load balancer method with sticky disabled",
|
||||
lb: &types.LoadBalancer{
|
||||
Method: "Invalid",
|
||||
Stickiness: nil,
|
||||
},
|
||||
expectedMethod: defaultMethod,
|
||||
},
|
||||
{
|
||||
desc: "missing load balancer",
|
||||
lb: nil,
|
||||
expectedMethod: defaultMethod,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
backend := &types.Backend{
|
||||
LoadBalancer: test.lb,
|
||||
}
|
||||
|
||||
configureBackends(map[string]*types.Backend{
|
||||
"backend": backend,
|
||||
})
|
||||
|
||||
expected := types.LoadBalancer{
|
||||
Method: test.expectedMethod,
|
||||
Stickiness: test.expectedStickiness,
|
||||
}
|
||||
|
||||
assert.Equal(t, expected, *backend.LoadBalancer)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerEntryPointWhitelistConfig(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
entrypoint *configuration.EntryPoint
|
||||
expectMiddleware bool
|
||||
}{
|
||||
{
|
||||
desc: "no whitelist middleware if no config on entrypoint",
|
||||
entrypoint: &configuration.EntryPoint{
|
||||
Address: ":0",
|
||||
ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true},
|
||||
},
|
||||
expectMiddleware: false,
|
||||
},
|
||||
{
|
||||
desc: "whitelist middleware should be added if configured on entrypoint",
|
||||
entrypoint: &configuration.EntryPoint{
|
||||
Address: ":0",
|
||||
WhitelistSourceRange: []string{
|
||||
"127.0.0.1/32",
|
||||
},
|
||||
ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true},
|
||||
},
|
||||
expectMiddleware: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := Server{
|
||||
globalConfiguration: configuration.GlobalConfiguration{},
|
||||
metricsRegistry: metrics.NewVoidRegistry(),
|
||||
entryPoints: map[string]EntryPoint{
|
||||
"test": {
|
||||
Configuration: test.entrypoint,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
srv.serverEntryPoints = srv.buildEntryPoints()
|
||||
srvEntryPoint := srv.setupServerEntryPoint("test", srv.serverEntryPoints["test"])
|
||||
handler := srvEntryPoint.httpServer.Handler.(*mux.Router).NotFoundHandler.(*negroni.Negroni)
|
||||
|
||||
found := false
|
||||
for _, handler := range handler.Handlers() {
|
||||
if reflect.TypeOf(handler) == reflect.TypeOf((*middlewares.IPWhiteLister)(nil)) {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
|
||||
if found && !test.expectMiddleware {
|
||||
t.Error("ip whitelist middleware was installed even though it should not")
|
||||
}
|
||||
|
||||
if !found && test.expectMiddleware {
|
||||
t.Error("ip whitelist middleware was not installed even though it should have")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerResponseEmptyBackend(t *testing.T) {
|
||||
const requestPath = "/path"
|
||||
const routeRule = "Path:" + requestPath
|
||||
|
@ -962,157 +348,6 @@ func TestServerResponseEmptyBackend(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestReuseBackend(t *testing.T) {
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
globalConfig := configuration.GlobalConfiguration{
|
||||
DefaultEntryPoints: []string{"http"},
|
||||
}
|
||||
|
||||
entryPoints := map[string]EntryPoint{
|
||||
"http": {Configuration: &configuration.EntryPoint{
|
||||
ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true},
|
||||
}},
|
||||
}
|
||||
|
||||
dynamicConfigs := types.Configurations{
|
||||
"config": th.BuildConfiguration(
|
||||
th.WithFrontends(
|
||||
th.WithFrontend("backend",
|
||||
th.WithFrontendName("frontend0"),
|
||||
th.WithEntryPoints("http"),
|
||||
th.WithRoutes(th.WithRoute("/ok", "Path: /ok"))),
|
||||
th.WithFrontend("backend",
|
||||
th.WithFrontendName("frontend1"),
|
||||
th.WithEntryPoints("http"),
|
||||
th.WithRoutes(th.WithRoute("/unauthorized", "Path: /unauthorized")),
|
||||
th.WithBasicAuth("foo", "bar")),
|
||||
),
|
||||
th.WithBackends(th.WithBackendNew("backend",
|
||||
th.WithLBMethod("wrr"),
|
||||
th.WithServersNew(th.WithServerNew(testServer.URL))),
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
srv := NewServer(globalConfig, nil, entryPoints)
|
||||
|
||||
serverEntryPoints, err := srv.loadConfig(dynamicConfigs, globalConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("error loading config: %s", err)
|
||||
}
|
||||
|
||||
// Test that the /ok path returns a status 200.
|
||||
responseRecorderOk := &httptest.ResponseRecorder{}
|
||||
requestOk := httptest.NewRequest(http.MethodGet, testServer.URL+"/ok", nil)
|
||||
serverEntryPoints["http"].httpRouter.ServeHTTP(responseRecorderOk, requestOk)
|
||||
|
||||
assert.Equal(t, http.StatusOK, responseRecorderOk.Result().StatusCode, "status code")
|
||||
|
||||
// Test that the /unauthorized path returns a 401 because of
|
||||
// the basic authentication defined on the frontend.
|
||||
responseRecorderUnauthorized := &httptest.ResponseRecorder{}
|
||||
requestUnauthorized := httptest.NewRequest(http.MethodGet, testServer.URL+"/unauthorized", nil)
|
||||
serverEntryPoints["http"].httpRouter.ServeHTTP(responseRecorderUnauthorized, requestUnauthorized)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, responseRecorderUnauthorized.Result().StatusCode, "status code")
|
||||
}
|
||||
|
||||
func TestBuildRedirectHandler(t *testing.T) {
|
||||
srv := Server{
|
||||
globalConfiguration: configuration.GlobalConfiguration{},
|
||||
entryPoints: map[string]EntryPoint{
|
||||
"http": {Configuration: &configuration.EntryPoint{Address: ":80"}},
|
||||
"https": {Configuration: &configuration.EntryPoint{Address: ":443", TLS: &tls.TLS{}}},
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
srcEntryPointName string
|
||||
url string
|
||||
entryPoint *configuration.EntryPoint
|
||||
redirect *types.Redirect
|
||||
expectedURL string
|
||||
}{
|
||||
{
|
||||
desc: "redirect regex",
|
||||
srcEntryPointName: "http",
|
||||
url: "http://foo.com",
|
||||
redirect: &types.Redirect{
|
||||
Regex: `^(?:http?:\/\/)(foo)(\.com)$`,
|
||||
Replacement: "https://$1{{\"bar\"}}$2",
|
||||
},
|
||||
entryPoint: &configuration.EntryPoint{
|
||||
Address: ":80",
|
||||
Redirect: &types.Redirect{
|
||||
Regex: `^(?:http?:\/\/)(foo)(\.com)$`,
|
||||
Replacement: "https://$1{{\"bar\"}}$2",
|
||||
},
|
||||
},
|
||||
expectedURL: "https://foobar.com",
|
||||
},
|
||||
{
|
||||
desc: "redirect entry point",
|
||||
srcEntryPointName: "http",
|
||||
url: "http://foo:80",
|
||||
redirect: &types.Redirect{
|
||||
EntryPoint: "https",
|
||||
},
|
||||
entryPoint: &configuration.EntryPoint{
|
||||
Address: ":80",
|
||||
Redirect: &types.Redirect{
|
||||
EntryPoint: "https",
|
||||
},
|
||||
},
|
||||
expectedURL: "https://foo:443",
|
||||
},
|
||||
{
|
||||
desc: "redirect entry point with regex (ignored)",
|
||||
srcEntryPointName: "http",
|
||||
url: "http://foo.com:80",
|
||||
redirect: &types.Redirect{
|
||||
EntryPoint: "https",
|
||||
Regex: `^(?:http?:\/\/)(foo)(\.com)$`,
|
||||
Replacement: "https://$1{{\"bar\"}}$2",
|
||||
},
|
||||
entryPoint: &configuration.EntryPoint{
|
||||
Address: ":80",
|
||||
Redirect: &types.Redirect{
|
||||
EntryPoint: "https",
|
||||
Regex: `^(?:http?:\/\/)(foo)(\.com)$`,
|
||||
Replacement: "https://$1{{\"bar\"}}$2",
|
||||
},
|
||||
},
|
||||
expectedURL: "https://foo.com:443",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rewrite, err := srv.buildRedirectHandler(test.srcEntryPointName, test.redirect)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := th.MustNewRequest(http.MethodGet, test.url, nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
rewrite.ServeHTTP(recorder, req, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Add("Location", "fail")
|
||||
}))
|
||||
|
||||
location, err := recorder.Result().Location()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, test.expectedURL, location.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockContext struct {
|
||||
headers http.Header
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue