diff --git a/cmd/traefik/traefik.go b/cmd/traefik/traefik.go index 91a6d4d74..b0db4aabc 100644 --- a/cmd/traefik/traefik.go +++ b/cmd/traefik/traefik.go @@ -191,7 +191,7 @@ func setupServer(staticConfiguration *static.Configuration) (*server.Server, err // Entrypoints - serverEntryPointsTCP, err := server.NewTCPEntryPoints(staticConfiguration.EntryPoints) + serverEntryPointsTCP, err := server.NewTCPEntryPoints(staticConfiguration.EntryPoints, staticConfiguration.HostResolver) if err != nil { return nil, err } diff --git a/pkg/middlewares/snicheck/snicheck.go b/pkg/middlewares/snicheck/snicheck.go new file mode 100644 index 000000000..e18b605cb --- /dev/null +++ b/pkg/middlewares/snicheck/snicheck.go @@ -0,0 +1,107 @@ +package snicheck + +import ( + "net" + "net/http" + "strings" + + "github.com/traefik/traefik/v2/pkg/log" + "github.com/traefik/traefik/v2/pkg/middlewares/requestdecorator" + traefiktls "github.com/traefik/traefik/v2/pkg/tls" +) + +// SNICheck is an HTTP handler that checks whether the TLS configuration for the server name is the same as for the host header. +type SNICheck struct { + next http.Handler + tlsOptionsForHost map[string]string +} + +// New creates a new SNICheck. +func New(tlsOptionsForHost map[string]string, next http.Handler) *SNICheck { + return &SNICheck{next: next, tlsOptionsForHost: tlsOptionsForHost} +} + +func (s SNICheck) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + if req.TLS == nil { + s.next.ServeHTTP(rw, req) + return + } + + host := getHost(req) + serverName := strings.TrimSpace(req.TLS.ServerName) + + // Domain Fronting + if !strings.EqualFold(host, serverName) { + tlsOptionHeader := findTLSOptionName(s.tlsOptionsForHost, host, true) + tlsOptionSNI := findTLSOptionName(s.tlsOptionsForHost, serverName, false) + + if tlsOptionHeader != tlsOptionSNI { + log.WithoutContext(). + WithField("host", host). + WithField("req.Host", req.Host). + WithField("req.TLS.ServerName", req.TLS.ServerName). + Debugf("TLS options difference: SNI:%s, Header:%s", tlsOptionSNI, tlsOptionHeader) + http.Error(rw, http.StatusText(http.StatusMisdirectedRequest), http.StatusMisdirectedRequest) + return + } + } + + s.next.ServeHTTP(rw, req) +} + +func getHost(req *http.Request) string { + h := requestdecorator.GetCNAMEFlatten(req.Context()) + if h != "" { + return h + } + + h = requestdecorator.GetCanonizedHost(req.Context()) + if h != "" { + return h + } + + host, _, err := net.SplitHostPort(req.Host) + if err != nil { + host = req.Host + } + + return strings.TrimSpace(host) +} + +func findTLSOptionName(tlsOptionsForHost map[string]string, host string, fqdn bool) string { + name := findTLSOptName(tlsOptionsForHost, host, fqdn) + if name != "" { + return name + } + + name = findTLSOptName(tlsOptionsForHost, strings.ToLower(host), fqdn) + if name != "" { + return name + } + + return traefiktls.DefaultTLSConfigName +} + +func findTLSOptName(tlsOptionsForHost map[string]string, host string, fqdn bool) string { + if tlsOptions, ok := tlsOptionsForHost[host]; ok { + return tlsOptions + } + + if !fqdn { + return "" + } + + if last := len(host) - 1; last >= 0 && host[last] == '.' { + if tlsOptions, ok := tlsOptionsForHost[host[:last]]; ok { + return tlsOptions + } + + return "" + } + + if tlsOptions, ok := tlsOptionsForHost[host+"."]; ok { + return tlsOptions + } + + return "" +} diff --git a/pkg/middlewares/snicheck/snicheck_test.go b/pkg/middlewares/snicheck/snicheck_test.go new file mode 100644 index 000000000..a9d57bdf6 --- /dev/null +++ b/pkg/middlewares/snicheck/snicheck_test.go @@ -0,0 +1,60 @@ +package snicheck + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSNICheck_ServeHTTP(t *testing.T) { + testCases := []struct { + desc string + tlsOptionsForHost map[string]string + host string + expected int + }{ + { + desc: "no TLS options", + expected: http.StatusOK, + }, + { + desc: "with TLS options", + tlsOptionsForHost: map[string]string{ + "example.com": "foo", + }, + expected: http.StatusOK, + }, + { + desc: "server name and host doesn't have the same TLS configuration", + tlsOptionsForHost: map[string]string{ + "example.com": "foo", + }, + host: "example.com", + expected: http.StatusMisdirectedRequest, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}) + + sniCheck := New(test.tlsOptionsForHost, next) + + req := httptest.NewRequest(http.MethodGet, "https://localhost", nil) + if test.host != "" { + req.Host = test.host + } + + recorder := httptest.NewRecorder() + + sniCheck.ServeHTTP(recorder, req) + + assert.Equal(t, test.expected, recorder.Code) + }) + } +} diff --git a/pkg/server/middleware/chainbuilder.go b/pkg/server/middleware/chainbuilder.go index 1f7a70ef3..418bdfd14 100644 --- a/pkg/server/middleware/chainbuilder.go +++ b/pkg/server/middleware/chainbuilder.go @@ -9,7 +9,6 @@ import ( "github.com/traefik/traefik/v2/pkg/metrics" "github.com/traefik/traefik/v2/pkg/middlewares/accesslog" metricsmiddleware "github.com/traefik/traefik/v2/pkg/middlewares/metrics" - "github.com/traefik/traefik/v2/pkg/middlewares/requestdecorator" mTracing "github.com/traefik/traefik/v2/pkg/middlewares/tracing" "github.com/traefik/traefik/v2/pkg/tracing" "github.com/traefik/traefik/v2/pkg/tracing/jaeger" @@ -20,7 +19,6 @@ type ChainBuilder struct { metricsRegistry metrics.Registry accessLoggerMiddleware *accesslog.Handler tracer *tracing.Tracing - requestDecorator *requestdecorator.RequestDecorator } // NewChainBuilder Creates a new ChainBuilder. @@ -29,7 +27,6 @@ func NewChainBuilder(staticConfiguration static.Configuration, metricsRegistry m metricsRegistry: metricsRegistry, accessLoggerMiddleware: accessLoggerMiddleware, tracer: setupTracing(staticConfiguration.Tracing), - requestDecorator: requestdecorator.New(staticConfiguration.HostResolver), } } @@ -49,7 +46,7 @@ func (c *ChainBuilder) Build(ctx context.Context, entryPointName string) alice.C chain = chain.Append(metricsmiddleware.WrapEntryPointHandler(ctx, c.metricsRegistry, entryPointName)) } - return chain.Append(requestdecorator.WrapHandler(c.requestDecorator)) + return chain } // Close accessLogger and tracer. diff --git a/pkg/server/router/tcp/router.go b/pkg/server/router/tcp/router.go index 277080835..0a97a10b3 100644 --- a/pkg/server/router/tcp/router.go +++ b/pkg/server/router/tcp/router.go @@ -5,12 +5,11 @@ import ( "crypto/tls" "errors" "fmt" - "net" "net/http" - "strings" "github.com/traefik/traefik/v2/pkg/config/runtime" "github.com/traefik/traefik/v2/pkg/log" + "github.com/traefik/traefik/v2/pkg/middlewares/snicheck" "github.com/traefik/traefik/v2/pkg/rules" "github.com/traefik/traefik/v2/pkg/server/provider" tcpservice "github.com/traefik/traefik/v2/pkg/server/service/tcp" @@ -161,38 +160,7 @@ func (m *Manager) buildEntryPointHandler(ctx context.Context, configs map[string } } - sniCheck := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - if req.TLS == nil { - handlerHTTPS.ServeHTTP(rw, req) - return - } - - host, _, err := net.SplitHostPort(req.Host) - if err != nil { - host = req.Host - } - - host = strings.TrimSpace(host) - serverName := strings.TrimSpace(req.TLS.ServerName) - - // Domain Fronting - if !strings.EqualFold(host, serverName) { - tlsOptionHeader := findTLSOptionName(tlsOptionsForHost, host, true) - tlsOptionSNI := findTLSOptionName(tlsOptionsForHost, serverName, false) - - if tlsOptionHeader != tlsOptionSNI { - log.WithoutContext(). - WithField("host", host). - WithField("req.Host", req.Host). - WithField("req.TLS.ServerName", req.TLS.ServerName). - Debugf("TLS options difference: SNI=%s, Header:%s", tlsOptionSNI, tlsOptionHeader) - http.Error(rw, http.StatusText(http.StatusMisdirectedRequest), http.StatusMisdirectedRequest) - return - } - } - - handlerHTTPS.ServeHTTP(rw, req) - }) + sniCheck := snicheck.New(tlsOptionsForHost, handlerHTTPS) router.HTTPSHandler(sniCheck, defaultTLSConf) @@ -321,44 +289,3 @@ func (m *Manager) buildTCPHandler(ctx context.Context, router *runtime.TCPRouter return tcp.NewChain().Extend(*mHandler).Then(sHandler) } - -func findTLSOptionName(tlsOptionsForHost map[string]string, host string, fqdn bool) string { - name := findTLSOptName(tlsOptionsForHost, host, fqdn) - if name != "" { - return name - } - - name = findTLSOptName(tlsOptionsForHost, strings.ToLower(host), fqdn) - if name != "" { - return name - } - - return traefiktls.DefaultTLSConfigName -} - -func findTLSOptName(tlsOptionsForHost map[string]string, host string, fqdn bool) string { - tlsOptions, ok := tlsOptionsForHost[host] - if ok { - return tlsOptions - } - - if !fqdn { - return "" - } - - if last := len(host) - 1; last >= 0 && host[last] == '.' { - tlsOptions, ok = tlsOptionsForHost[host[:last]] - if ok { - return tlsOptions - } - - return "" - } - - tlsOptions, ok = tlsOptionsForHost[host+"."] - if ok { - return tlsOptions - } - - return "" -} diff --git a/pkg/server/server_entrypoint_tcp.go b/pkg/server/server_entrypoint_tcp.go index bf62b47eb..395720555 100644 --- a/pkg/server/server_entrypoint_tcp.go +++ b/pkg/server/server_entrypoint_tcp.go @@ -11,6 +11,7 @@ import ( "syscall" "time" + "github.com/containous/alice" "github.com/pires/go-proxyproto" "github.com/sirupsen/logrus" "github.com/traefik/traefik/v2/pkg/config/static" @@ -18,9 +19,11 @@ import ( "github.com/traefik/traefik/v2/pkg/log" "github.com/traefik/traefik/v2/pkg/middlewares" "github.com/traefik/traefik/v2/pkg/middlewares/forwardedheaders" + "github.com/traefik/traefik/v2/pkg/middlewares/requestdecorator" "github.com/traefik/traefik/v2/pkg/safe" "github.com/traefik/traefik/v2/pkg/server/router" "github.com/traefik/traefik/v2/pkg/tcp" + "github.com/traefik/traefik/v2/pkg/types" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" ) @@ -60,7 +63,7 @@ func (h *httpForwarder) Accept() (net.Conn, error) { type TCPEntryPoints map[string]*TCPEntryPoint // NewTCPEntryPoints creates a new TCPEntryPoints. -func NewTCPEntryPoints(entryPointsConfig static.EntryPoints) (TCPEntryPoints, error) { +func NewTCPEntryPoints(entryPointsConfig static.EntryPoints, hostResolverConfig *types.HostResolverConfig) (TCPEntryPoints, error) { serverEntryPointsTCP := make(TCPEntryPoints) for entryPointName, config := range entryPointsConfig { protocol, err := config.GetProtocol() @@ -74,7 +77,7 @@ func NewTCPEntryPoints(entryPointsConfig static.EntryPoints) (TCPEntryPoints, er ctx := log.With(context.Background(), log.Str(log.EntryPointName, entryPointName)) - serverEntryPointsTCP[entryPointName], err = NewTCPEntryPoint(ctx, config) + serverEntryPointsTCP[entryPointName], err = NewTCPEntryPoint(ctx, config, hostResolverConfig) if err != nil { return nil, fmt.Errorf("error while building entryPoint %s: %w", entryPointName, err) } @@ -130,7 +133,7 @@ type TCPEntryPoint struct { } // NewTCPEntryPoint creates a new TCPEntryPoint. -func NewTCPEntryPoint(ctx context.Context, configuration *static.EntryPoint) (*TCPEntryPoint, error) { +func NewTCPEntryPoint(ctx context.Context, configuration *static.EntryPoint, hostResolverConfig *types.HostResolverConfig) (*TCPEntryPoint, error) { tracker := newConnectionTracker() listener, err := buildListener(ctx, configuration) @@ -140,14 +143,16 @@ func NewTCPEntryPoint(ctx context.Context, configuration *static.EntryPoint) (*T rt := &tcp.Router{} - httpServer, err := createHTTPServer(ctx, listener, configuration, true) + reqDecorator := requestdecorator.New(hostResolverConfig) + + httpServer, err := createHTTPServer(ctx, listener, configuration, true, reqDecorator) if err != nil { return nil, fmt.Errorf("error preparing httpServer: %w", err) } rt.HTTPForwarder(httpServer.Forwarder) - httpsServer, err := createHTTPServer(ctx, listener, configuration, false) + httpsServer, err := createHTTPServer(ctx, listener, configuration, false, reqDecorator) if err != nil { return nil, fmt.Errorf("error preparing httpsServer: %w", err) } @@ -500,16 +505,19 @@ type httpServer struct { Switcher *middlewares.HTTPHandlerSwitcher } -func createHTTPServer(ctx context.Context, ln net.Listener, configuration *static.EntryPoint, withH2c bool) (*httpServer, error) { +func createHTTPServer(ctx context.Context, ln net.Listener, configuration *static.EntryPoint, withH2c bool, reqDecorator *requestdecorator.RequestDecorator) (*httpServer, error) { httpSwitcher := middlewares.NewHandlerSwitcher(router.BuildDefaultHTTPRouter()) + next, err := alice.New(requestdecorator.WrapHandler(reqDecorator)).Then(httpSwitcher) + if err != nil { + return nil, err + } + var handler http.Handler - var err error handler, err = forwardedheaders.NewXForwarded( configuration.ForwardedHeaders.Insecure, configuration.ForwardedHeaders.TrustedIPs, - httpSwitcher) - + next) if err != nil { return nil, err } diff --git a/pkg/server/server_entrypoint_tcp_http3_test.go b/pkg/server/server_entrypoint_tcp_http3_test.go index bdd44e0cb..b0759ef48 100644 --- a/pkg/server/server_entrypoint_tcp_http3_test.go +++ b/pkg/server/server_entrypoint_tcp_http3_test.go @@ -91,7 +91,7 @@ func TestHTTP3AdvertisedPort(t *testing.T) { HTTP3: &static.HTTP3Config{ AdvertisedPort: 8080, }, - }) + }, nil) require.NoError(t, err) router := &tcp.Router{} diff --git a/pkg/server/server_entrypoint_tcp_test.go b/pkg/server/server_entrypoint_tcp_test.go index 6b461632c..ab49cd5a9 100644 --- a/pkg/server/server_entrypoint_tcp_test.go +++ b/pkg/server/server_entrypoint_tcp_test.go @@ -79,7 +79,7 @@ func testShutdown(t *testing.T, router *tcp.Router) { Address: "127.0.0.1:0", Transport: epConfig, ForwardedHeaders: &static.ForwardedHeaders{}, - }) + }, nil) require.NoError(t, err) conn, err := startEntrypoint(entryPoint, router) @@ -162,7 +162,7 @@ func TestReadTimeoutWithoutFirstByte(t *testing.T) { Address: ":0", Transport: epConfig, ForwardedHeaders: &static.ForwardedHeaders{}, - }) + }, nil) require.NoError(t, err) router := &tcp.Router{} @@ -198,7 +198,7 @@ func TestReadTimeoutWithFirstByte(t *testing.T) { Address: ":0", Transport: epConfig, ForwardedHeaders: &static.ForwardedHeaders{}, - }) + }, nil) require.NoError(t, err) router := &tcp.Router{}