diff --git a/cmd/traefik/traefik.go b/cmd/traefik/traefik.go index 81bb76a15..089e2c338 100644 --- a/cmd/traefik/traefik.go +++ b/cmd/traefik/traefik.go @@ -94,8 +94,7 @@ Complete documentation is available at https://traefik.io`, Config: traefikConfiguration, DefaultPointersConfig: traefikPointersConfiguration, Run: func() error { - runCmd(&traefikConfiguration.Configuration, traefikConfiguration.ConfigFile) - return nil + return runCmd(&traefikConfiguration.Configuration, traefikConfiguration.ConfigFile) }, } @@ -192,7 +191,7 @@ Complete documentation is available at https://traefik.io`, os.Exit(0) } -func runCmd(staticConfiguration *static.Configuration, configFile string) { +func runCmd(staticConfiguration *static.Configuration, configFile string) error { configureLogging(staticConfiguration) if len(configFile) > 0 { @@ -247,8 +246,7 @@ func runCmd(staticConfiguration *static.Configuration, configFile string) { serverEntryPoint, err := server.NewEntryPoint(ctx, config) if err != nil { - logger.Errorf("Error while building entryPoint: %v", err) - continue + return fmt.Errorf("error while building entryPoint %s: %v", entryPointName, err) } serverEntryPoint.RouteAppenderFactory = router.NewRouteAppenderFactory(*staticConfiguration, entryPointName, acmeProvider) @@ -315,6 +313,7 @@ func runCmd(staticConfiguration *static.Configuration, configFile string) { svr.Wait() log.WithoutContext().Info("Shutting down") logrus.Exit(0) + return nil } func configureLogging(staticConfiguration *static.Configuration) { diff --git a/config/static/entrypoints.go b/config/static/entrypoints.go index b9d4b772b..190456556 100644 --- a/config/static/entrypoints.go +++ b/config/static/entrypoints.go @@ -10,10 +10,17 @@ import ( // EntryPoint holds the entry point configuration. type EntryPoint struct { - Address string - Transport *EntryPointsTransport - TLS *tls.TLS - ProxyProtocol *ProxyProtocol + Address string + Transport *EntryPointsTransport + TLS *tls.TLS + ProxyProtocol *ProxyProtocol + ForwardedHeaders *ForwardedHeaders +} + +// ForwardedHeaders Trust client forwarding headers. +type ForwardedHeaders struct { + Insecure bool + TrustedIPs []string } // ProxyProtocol contains Proxy-Protocol configuration. @@ -64,9 +71,10 @@ func (ep *EntryPoints) Set(value string) error { } (*ep)[result["name"]] = &EntryPoint{ - Address: result["address"], - TLS: configTLS, - ProxyProtocol: makeEntryPointProxyProtocol(result), + Address: result["address"], + TLS: configTLS, + ProxyProtocol: makeEntryPointProxyProtocol(result), + ForwardedHeaders: makeEntryPointForwardedHeaders(result), } return nil @@ -167,3 +175,15 @@ func toBool(conf map[string]string, key string) bool { } return false } + +func makeEntryPointForwardedHeaders(result map[string]string) *ForwardedHeaders { + forwardedHeaders := &ForwardedHeaders{} + forwardedHeaders.Insecure = toBool(result, "forwardedheaders_insecure") + + fhTrustedIPs := result["forwardedheaders_trustedips"] + if len(fhTrustedIPs) > 0 { + forwardedHeaders.TrustedIPs = strings.Split(fhTrustedIPs, ",") + } + + return forwardedHeaders +} diff --git a/config/static/entrypoints_test.go b/config/static/entrypoints_test.go index 8efd4f480..50f5d287e 100644 --- a/config/static/entrypoints_test.go +++ b/config/static/entrypoints_test.go @@ -206,6 +206,7 @@ func TestEntryPoints_Set(t *testing.T) { Insecure: false, TrustedIPs: []string{"192.168.0.1"}, }, + ForwardedHeaders: &ForwardedHeaders{}, // FIXME Test ServersTransport }, }, @@ -234,6 +235,7 @@ func TestEntryPoints_Set(t *testing.T) { Insecure: false, TrustedIPs: []string{"192.168.0.1"}, }, + ForwardedHeaders: &ForwardedHeaders{}, // FIXME Test ServersTransport }, }, @@ -241,14 +243,17 @@ func TestEntryPoints_Set(t *testing.T) { name: "default", expression: "Name:foo", expectedEntryPointName: "foo", - expectedEntryPoint: &EntryPoint{}, + expectedEntryPoint: &EntryPoint{ + ForwardedHeaders: &ForwardedHeaders{}, + }, }, { name: "ProxyProtocol insecure true", expression: "Name:foo ProxyProtocol.insecure:true", expectedEntryPointName: "foo", expectedEntryPoint: &EntryPoint{ - ProxyProtocol: &ProxyProtocol{Insecure: true}, + ProxyProtocol: &ProxyProtocol{Insecure: true}, + ForwardedHeaders: &ForwardedHeaders{}, }, }, { @@ -256,7 +261,8 @@ func TestEntryPoints_Set(t *testing.T) { expression: "Name:foo ProxyProtocol.insecure:false", expectedEntryPointName: "foo", expectedEntryPoint: &EntryPoint{ - ProxyProtocol: &ProxyProtocol{}, + ProxyProtocol: &ProxyProtocol{}, + ForwardedHeaders: &ForwardedHeaders{}, }, }, { @@ -267,6 +273,7 @@ func TestEntryPoints_Set(t *testing.T) { ProxyProtocol: &ProxyProtocol{ TrustedIPs: []string{"10.0.0.3/24", "20.0.0.3/24"}, }, + ForwardedHeaders: &ForwardedHeaders{}, }, }, } diff --git a/config/static/static_config.go b/config/static/static_config.go index 5a606ecb8..9d38b70e4 100644 --- a/config/static/static_config.go +++ b/config/static/static_config.go @@ -184,7 +184,10 @@ func (c *Configuration) SetEffectiveConfiguration(configFile string) { entryPoint.Transport.RespondingTimeouts = &RespondingTimeouts{ IdleTimeout: parse.Duration(DefaultIdleTimeout), } + } + if entryPoint.ForwardedHeaders == nil { + entryPoint.ForwardedHeaders = &ForwardedHeaders{} } } diff --git a/middlewares/forwardedheaders/forwarded_header.go b/middlewares/forwardedheaders/forwarded_header.go new file mode 100644 index 000000000..3f0667e7d --- /dev/null +++ b/middlewares/forwardedheaders/forwarded_header.go @@ -0,0 +1,51 @@ +package forwardedheaders + +import ( + "net/http" + + "github.com/containous/traefik/ip" + "github.com/vulcand/oxy/forward" + "github.com/vulcand/oxy/utils" +) + +// XForwarded filter for XForwarded headers. +type XForwarded struct { + insecure bool + trustedIps []string + ipChecker *ip.Checker + next http.Handler +} + +// NewXForwarded creates a new XForwarded. +func NewXForwarded(insecure bool, trustedIps []string, next http.Handler) (*XForwarded, error) { + var ipChecker *ip.Checker + if len(trustedIps) > 0 { + var err error + ipChecker, err = ip.NewChecker(trustedIps) + if err != nil { + return nil, err + } + } + + return &XForwarded{ + insecure: insecure, + trustedIps: trustedIps, + ipChecker: ipChecker, + next: next, + }, nil +} + +func (x *XForwarded) isTrustedIP(ip string) bool { + if x.ipChecker == nil { + return false + } + return x.ipChecker.IsAuthorized(ip) == nil +} + +func (x *XForwarded) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !x.insecure && !x.isTrustedIP(r.RemoteAddr) { + utils.RemoveHeaders(r.Header, forward.XHeaders...) + } + + x.next.ServeHTTP(w, r) +} diff --git a/middlewares/forwardedheaders/forwarded_header_test.go b/middlewares/forwardedheaders/forwarded_header_test.go new file mode 100644 index 000000000..1062c8ea1 --- /dev/null +++ b/middlewares/forwardedheaders/forwarded_header_test.go @@ -0,0 +1,128 @@ +package forwardedheaders + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestServeHTTP(t *testing.T) { + testCases := []struct { + desc string + insecure bool + trustedIps []string + incomingHeaders map[string]string + remoteAddr string + expectedHeaders map[string]string + }{ + { + desc: "all Empty", + insecure: true, + trustedIps: nil, + remoteAddr: "", + incomingHeaders: map[string]string{}, + expectedHeaders: map[string]string{ + "X-Forwarded-for": "", + }, + }, + { + desc: "insecure true with incoming X-Forwarded-For", + insecure: true, + trustedIps: nil, + remoteAddr: "", + incomingHeaders: map[string]string{ + "X-Forwarded-for": "10.0.1.0, 10.0.1.12", + }, + expectedHeaders: map[string]string{ + "X-Forwarded-for": "10.0.1.0, 10.0.1.12", + }, + }, + { + desc: "insecure false with incoming X-Forwarded-For", + insecure: false, + trustedIps: nil, + remoteAddr: "", + incomingHeaders: map[string]string{ + "X-Forwarded-for": "10.0.1.0, 10.0.1.12", + }, + expectedHeaders: map[string]string{ + "X-Forwarded-for": "", + }, + }, + { + desc: "insecure false with incoming X-Forwarded-For and valid Trusted Ips", + insecure: false, + trustedIps: []string{"10.0.1.100"}, + remoteAddr: "10.0.1.100:80", + incomingHeaders: map[string]string{ + "X-Forwarded-for": "10.0.1.0, 10.0.1.12", + }, + expectedHeaders: map[string]string{ + "X-Forwarded-for": "10.0.1.0, 10.0.1.12", + }, + }, + { + desc: "insecure false with incoming X-Forwarded-For and invalid Trusted Ips", + insecure: false, + trustedIps: []string{"10.0.1.100"}, + remoteAddr: "10.0.1.101:80", + incomingHeaders: map[string]string{ + "X-Forwarded-for": "10.0.1.0, 10.0.1.12", + }, + expectedHeaders: map[string]string{ + "X-Forwarded-for": "", + }, + }, + { + desc: "insecure false with incoming X-Forwarded-For and valid Trusted Ips CIDR", + insecure: false, + trustedIps: []string{"1.2.3.4/24"}, + remoteAddr: "1.2.3.156:80", + incomingHeaders: map[string]string{ + "X-Forwarded-for": "10.0.1.0, 10.0.1.12", + }, + expectedHeaders: map[string]string{ + "X-Forwarded-for": "10.0.1.0, 10.0.1.12", + }, + }, + { + desc: "insecure false with incoming X-Forwarded-For and invalid Trusted Ips CIDR", + insecure: false, + trustedIps: []string{"1.2.3.4/24"}, + remoteAddr: "10.0.1.101:80", + incomingHeaders: map[string]string{ + "X-Forwarded-for": "10.0.1.0, 10.0.1.12", + }, + expectedHeaders: map[string]string{ + "X-Forwarded-for": "", + }, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + req, err := http.NewRequest(http.MethodGet, "", nil) + require.NoError(t, err) + + req.RemoteAddr = test.remoteAddr + + for k, v := range test.incomingHeaders { + req.Header.Set(k, v) + } + + m, err := NewXForwarded(test.insecure, test.trustedIps, http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})) + require.NoError(t, err) + + m.ServeHTTP(nil, req) + + for k, v := range test.expectedHeaders { + assert.Equal(t, v, req.Header.Get(k)) + } + }) + } +} diff --git a/server/server.go b/server/server.go index c0768be29..33ceb36a9 100644 --- a/server/server.go +++ b/server/server.go @@ -218,7 +218,7 @@ func (s *Server) startHTTPServers() { // Use an empty configuration in order to initialize the default handlers with internal routes handlers := s.applyConfiguration(context.Background(), config.Configuration{}) for entryPointName, handler := range handlers { - s.entryPoints[entryPointName].httpRouter.UpdateHandler(handler) + s.entryPoints[entryPointName].switcher.UpdateHandler(handler) } for entryPointName, entryPoint := range s.entryPoints { diff --git a/server/server_configuration.go b/server/server_configuration.go index d1492bba8..bebe64c11 100644 --- a/server/server_configuration.go +++ b/server/server_configuration.go @@ -45,7 +45,7 @@ func (s *Server) loadConfiguration(configMsg config.Message) { s.metricsRegistry.LastConfigReloadSuccessGauge().Set(float64(time.Now().Unix())) for entryPointName, handler := range handlers { - s.entryPoints[entryPointName].httpRouter.UpdateHandler(handler) + s.entryPoints[entryPointName].switcher.UpdateHandler(handler) } for entryPointName, entryPoint := range s.entryPoints { diff --git a/server/server_entrypoint.go b/server/server_entrypoint.go index dae1785d2..98a21fbe3 100644 --- a/server/server_entrypoint.go +++ b/server/server_entrypoint.go @@ -17,6 +17,7 @@ import ( "github.com/containous/traefik/ip" "github.com/containous/traefik/log" "github.com/containous/traefik/middlewares" + "github.com/containous/traefik/middlewares/forwardedheaders" "github.com/containous/traefik/old/configuration" traefiktls "github.com/containous/traefik/tls" "github.com/containous/traefik/tls/generate" @@ -30,15 +31,22 @@ type EntryPoints map[string]*EntryPoint // NewEntryPoint creates a new EntryPoint func NewEntryPoint(ctx context.Context, configuration *static.EntryPoint) (*EntryPoint, error) { - logger := log.FromContext(ctx) var err error - router := middlewares.NewHandlerSwitcher(buildDefaultHTTPRouter()) + switcher := middlewares.NewHandlerSwitcher(buildDefaultHTTPRouter()) + handler, err := forwardedheaders.NewXForwarded( + configuration.ForwardedHeaders.Insecure, + configuration.ForwardedHeaders.TrustedIPs, + switcher) + if err != nil { + return nil, err + } + tracker := newHijackConnectionTracker() listener, err := buildListener(ctx, configuration) if err != nil { - logger.Fatalf("Error preparing server: %v", err) + return nil, fmt.Errorf("error preparing server: %v", err) } var tlsConfig *tls.Config @@ -56,11 +64,11 @@ func NewEntryPoint(ctx context.Context, configuration *static.EntryPoint) (*Entr } entryPoint := &EntryPoint{ - httpRouter: router, + switcher: switcher, transportConfiguration: configuration.Transport, hijackConnectionTracker: tracker, listener: listener, - httpServer: buildServer(ctx, configuration, tlsConfig, router, tracker), + httpServer: buildServer(ctx, configuration, tlsConfig, handler, tracker), Certs: certificateStore, } @@ -76,7 +84,7 @@ type EntryPoint struct { RouteAppenderFactory RouteAppenderFactory httpServer *h2c.Server listener net.Listener - httpRouter *middlewares.HandlerSwitcher + switcher *middlewares.HandlerSwitcher Certs *traefiktls.CertificateStore OnDemandListener func(string) (*tls.Certificate, error) TLSALPNGetter func(string) (*tls.Certificate, error)