Use CNAME for SNI check on host header

Co-authored-by: Julien Salleyron <julien.salleyron@gmail.com>
This commit is contained in:
Ludovic Fernandez 2022-02-14 17:18:08 +01:00 committed by GitHub
parent e97aa6515b
commit d9fbb5e25c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 192 additions and 93 deletions

View file

@ -191,7 +191,7 @@ func setupServer(staticConfiguration *static.Configuration) (*server.Server, err
// Entrypoints // Entrypoints
serverEntryPointsTCP, err := server.NewTCPEntryPoints(staticConfiguration.EntryPoints) serverEntryPointsTCP, err := server.NewTCPEntryPoints(staticConfiguration.EntryPoints, staticConfiguration.HostResolver)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -0,0 +1,107 @@
package snicheck
import (
"net"
"net/http"
"strings"
"github.com/traefik/traefik/v2/pkg/log"
"github.com/traefik/traefik/v2/pkg/middlewares/requestdecorator"
traefiktls "github.com/traefik/traefik/v2/pkg/tls"
)
// SNICheck is an HTTP handler that checks whether the TLS configuration for the server name is the same as for the host header.
type SNICheck struct {
next http.Handler
tlsOptionsForHost map[string]string
}
// New creates a new SNICheck.
func New(tlsOptionsForHost map[string]string, next http.Handler) *SNICheck {
return &SNICheck{next: next, tlsOptionsForHost: tlsOptionsForHost}
}
func (s SNICheck) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if req.TLS == nil {
s.next.ServeHTTP(rw, req)
return
}
host := getHost(req)
serverName := strings.TrimSpace(req.TLS.ServerName)
// Domain Fronting
if !strings.EqualFold(host, serverName) {
tlsOptionHeader := findTLSOptionName(s.tlsOptionsForHost, host, true)
tlsOptionSNI := findTLSOptionName(s.tlsOptionsForHost, serverName, false)
if tlsOptionHeader != tlsOptionSNI {
log.WithoutContext().
WithField("host", host).
WithField("req.Host", req.Host).
WithField("req.TLS.ServerName", req.TLS.ServerName).
Debugf("TLS options difference: SNI:%s, Header:%s", tlsOptionSNI, tlsOptionHeader)
http.Error(rw, http.StatusText(http.StatusMisdirectedRequest), http.StatusMisdirectedRequest)
return
}
}
s.next.ServeHTTP(rw, req)
}
func getHost(req *http.Request) string {
h := requestdecorator.GetCNAMEFlatten(req.Context())
if h != "" {
return h
}
h = requestdecorator.GetCanonizedHost(req.Context())
if h != "" {
return h
}
host, _, err := net.SplitHostPort(req.Host)
if err != nil {
host = req.Host
}
return strings.TrimSpace(host)
}
func findTLSOptionName(tlsOptionsForHost map[string]string, host string, fqdn bool) string {
name := findTLSOptName(tlsOptionsForHost, host, fqdn)
if name != "" {
return name
}
name = findTLSOptName(tlsOptionsForHost, strings.ToLower(host), fqdn)
if name != "" {
return name
}
return traefiktls.DefaultTLSConfigName
}
func findTLSOptName(tlsOptionsForHost map[string]string, host string, fqdn bool) string {
if tlsOptions, ok := tlsOptionsForHost[host]; ok {
return tlsOptions
}
if !fqdn {
return ""
}
if last := len(host) - 1; last >= 0 && host[last] == '.' {
if tlsOptions, ok := tlsOptionsForHost[host[:last]]; ok {
return tlsOptions
}
return ""
}
if tlsOptions, ok := tlsOptionsForHost[host+"."]; ok {
return tlsOptions
}
return ""
}

View file

@ -0,0 +1,60 @@
package snicheck
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestSNICheck_ServeHTTP(t *testing.T) {
testCases := []struct {
desc string
tlsOptionsForHost map[string]string
host string
expected int
}{
{
desc: "no TLS options",
expected: http.StatusOK,
},
{
desc: "with TLS options",
tlsOptionsForHost: map[string]string{
"example.com": "foo",
},
expected: http.StatusOK,
},
{
desc: "server name and host doesn't have the same TLS configuration",
tlsOptionsForHost: map[string]string{
"example.com": "foo",
},
host: "example.com",
expected: http.StatusMisdirectedRequest,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})
sniCheck := New(test.tlsOptionsForHost, next)
req := httptest.NewRequest(http.MethodGet, "https://localhost", nil)
if test.host != "" {
req.Host = test.host
}
recorder := httptest.NewRecorder()
sniCheck.ServeHTTP(recorder, req)
assert.Equal(t, test.expected, recorder.Code)
})
}
}

View file

@ -9,7 +9,6 @@ import (
"github.com/traefik/traefik/v2/pkg/metrics" "github.com/traefik/traefik/v2/pkg/metrics"
"github.com/traefik/traefik/v2/pkg/middlewares/accesslog" "github.com/traefik/traefik/v2/pkg/middlewares/accesslog"
metricsmiddleware "github.com/traefik/traefik/v2/pkg/middlewares/metrics" metricsmiddleware "github.com/traefik/traefik/v2/pkg/middlewares/metrics"
"github.com/traefik/traefik/v2/pkg/middlewares/requestdecorator"
mTracing "github.com/traefik/traefik/v2/pkg/middlewares/tracing" mTracing "github.com/traefik/traefik/v2/pkg/middlewares/tracing"
"github.com/traefik/traefik/v2/pkg/tracing" "github.com/traefik/traefik/v2/pkg/tracing"
"github.com/traefik/traefik/v2/pkg/tracing/jaeger" "github.com/traefik/traefik/v2/pkg/tracing/jaeger"
@ -20,7 +19,6 @@ type ChainBuilder struct {
metricsRegistry metrics.Registry metricsRegistry metrics.Registry
accessLoggerMiddleware *accesslog.Handler accessLoggerMiddleware *accesslog.Handler
tracer *tracing.Tracing tracer *tracing.Tracing
requestDecorator *requestdecorator.RequestDecorator
} }
// NewChainBuilder Creates a new ChainBuilder. // NewChainBuilder Creates a new ChainBuilder.
@ -29,7 +27,6 @@ func NewChainBuilder(staticConfiguration static.Configuration, metricsRegistry m
metricsRegistry: metricsRegistry, metricsRegistry: metricsRegistry,
accessLoggerMiddleware: accessLoggerMiddleware, accessLoggerMiddleware: accessLoggerMiddleware,
tracer: setupTracing(staticConfiguration.Tracing), tracer: setupTracing(staticConfiguration.Tracing),
requestDecorator: requestdecorator.New(staticConfiguration.HostResolver),
} }
} }
@ -49,7 +46,7 @@ func (c *ChainBuilder) Build(ctx context.Context, entryPointName string) alice.C
chain = chain.Append(metricsmiddleware.WrapEntryPointHandler(ctx, c.metricsRegistry, entryPointName)) chain = chain.Append(metricsmiddleware.WrapEntryPointHandler(ctx, c.metricsRegistry, entryPointName))
} }
return chain.Append(requestdecorator.WrapHandler(c.requestDecorator)) return chain
} }
// Close accessLogger and tracer. // Close accessLogger and tracer.

View file

@ -5,12 +5,11 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"net"
"net/http" "net/http"
"strings"
"github.com/traefik/traefik/v2/pkg/config/runtime" "github.com/traefik/traefik/v2/pkg/config/runtime"
"github.com/traefik/traefik/v2/pkg/log" "github.com/traefik/traefik/v2/pkg/log"
"github.com/traefik/traefik/v2/pkg/middlewares/snicheck"
"github.com/traefik/traefik/v2/pkg/rules" "github.com/traefik/traefik/v2/pkg/rules"
"github.com/traefik/traefik/v2/pkg/server/provider" "github.com/traefik/traefik/v2/pkg/server/provider"
tcpservice "github.com/traefik/traefik/v2/pkg/server/service/tcp" tcpservice "github.com/traefik/traefik/v2/pkg/server/service/tcp"
@ -161,38 +160,7 @@ func (m *Manager) buildEntryPointHandler(ctx context.Context, configs map[string
} }
} }
sniCheck := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { sniCheck := snicheck.New(tlsOptionsForHost, handlerHTTPS)
if req.TLS == nil {
handlerHTTPS.ServeHTTP(rw, req)
return
}
host, _, err := net.SplitHostPort(req.Host)
if err != nil {
host = req.Host
}
host = strings.TrimSpace(host)
serverName := strings.TrimSpace(req.TLS.ServerName)
// Domain Fronting
if !strings.EqualFold(host, serverName) {
tlsOptionHeader := findTLSOptionName(tlsOptionsForHost, host, true)
tlsOptionSNI := findTLSOptionName(tlsOptionsForHost, serverName, false)
if tlsOptionHeader != tlsOptionSNI {
log.WithoutContext().
WithField("host", host).
WithField("req.Host", req.Host).
WithField("req.TLS.ServerName", req.TLS.ServerName).
Debugf("TLS options difference: SNI=%s, Header:%s", tlsOptionSNI, tlsOptionHeader)
http.Error(rw, http.StatusText(http.StatusMisdirectedRequest), http.StatusMisdirectedRequest)
return
}
}
handlerHTTPS.ServeHTTP(rw, req)
})
router.HTTPSHandler(sniCheck, defaultTLSConf) router.HTTPSHandler(sniCheck, defaultTLSConf)
@ -321,44 +289,3 @@ func (m *Manager) buildTCPHandler(ctx context.Context, router *runtime.TCPRouter
return tcp.NewChain().Extend(*mHandler).Then(sHandler) return tcp.NewChain().Extend(*mHandler).Then(sHandler)
} }
func findTLSOptionName(tlsOptionsForHost map[string]string, host string, fqdn bool) string {
name := findTLSOptName(tlsOptionsForHost, host, fqdn)
if name != "" {
return name
}
name = findTLSOptName(tlsOptionsForHost, strings.ToLower(host), fqdn)
if name != "" {
return name
}
return traefiktls.DefaultTLSConfigName
}
func findTLSOptName(tlsOptionsForHost map[string]string, host string, fqdn bool) string {
tlsOptions, ok := tlsOptionsForHost[host]
if ok {
return tlsOptions
}
if !fqdn {
return ""
}
if last := len(host) - 1; last >= 0 && host[last] == '.' {
tlsOptions, ok = tlsOptionsForHost[host[:last]]
if ok {
return tlsOptions
}
return ""
}
tlsOptions, ok = tlsOptionsForHost[host+"."]
if ok {
return tlsOptions
}
return ""
}

View file

@ -11,6 +11,7 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/containous/alice"
"github.com/pires/go-proxyproto" "github.com/pires/go-proxyproto"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/traefik/traefik/v2/pkg/config/static" "github.com/traefik/traefik/v2/pkg/config/static"
@ -18,9 +19,11 @@ import (
"github.com/traefik/traefik/v2/pkg/log" "github.com/traefik/traefik/v2/pkg/log"
"github.com/traefik/traefik/v2/pkg/middlewares" "github.com/traefik/traefik/v2/pkg/middlewares"
"github.com/traefik/traefik/v2/pkg/middlewares/forwardedheaders" "github.com/traefik/traefik/v2/pkg/middlewares/forwardedheaders"
"github.com/traefik/traefik/v2/pkg/middlewares/requestdecorator"
"github.com/traefik/traefik/v2/pkg/safe" "github.com/traefik/traefik/v2/pkg/safe"
"github.com/traefik/traefik/v2/pkg/server/router" "github.com/traefik/traefik/v2/pkg/server/router"
"github.com/traefik/traefik/v2/pkg/tcp" "github.com/traefik/traefik/v2/pkg/tcp"
"github.com/traefik/traefik/v2/pkg/types"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/h2c" "golang.org/x/net/http2/h2c"
) )
@ -60,7 +63,7 @@ func (h *httpForwarder) Accept() (net.Conn, error) {
type TCPEntryPoints map[string]*TCPEntryPoint type TCPEntryPoints map[string]*TCPEntryPoint
// NewTCPEntryPoints creates a new TCPEntryPoints. // NewTCPEntryPoints creates a new TCPEntryPoints.
func NewTCPEntryPoints(entryPointsConfig static.EntryPoints) (TCPEntryPoints, error) { func NewTCPEntryPoints(entryPointsConfig static.EntryPoints, hostResolverConfig *types.HostResolverConfig) (TCPEntryPoints, error) {
serverEntryPointsTCP := make(TCPEntryPoints) serverEntryPointsTCP := make(TCPEntryPoints)
for entryPointName, config := range entryPointsConfig { for entryPointName, config := range entryPointsConfig {
protocol, err := config.GetProtocol() protocol, err := config.GetProtocol()
@ -74,7 +77,7 @@ func NewTCPEntryPoints(entryPointsConfig static.EntryPoints) (TCPEntryPoints, er
ctx := log.With(context.Background(), log.Str(log.EntryPointName, entryPointName)) ctx := log.With(context.Background(), log.Str(log.EntryPointName, entryPointName))
serverEntryPointsTCP[entryPointName], err = NewTCPEntryPoint(ctx, config) serverEntryPointsTCP[entryPointName], err = NewTCPEntryPoint(ctx, config, hostResolverConfig)
if err != nil { if err != nil {
return nil, fmt.Errorf("error while building entryPoint %s: %w", entryPointName, err) return nil, fmt.Errorf("error while building entryPoint %s: %w", entryPointName, err)
} }
@ -130,7 +133,7 @@ type TCPEntryPoint struct {
} }
// NewTCPEntryPoint creates a new TCPEntryPoint. // NewTCPEntryPoint creates a new TCPEntryPoint.
func NewTCPEntryPoint(ctx context.Context, configuration *static.EntryPoint) (*TCPEntryPoint, error) { func NewTCPEntryPoint(ctx context.Context, configuration *static.EntryPoint, hostResolverConfig *types.HostResolverConfig) (*TCPEntryPoint, error) {
tracker := newConnectionTracker() tracker := newConnectionTracker()
listener, err := buildListener(ctx, configuration) listener, err := buildListener(ctx, configuration)
@ -140,14 +143,16 @@ func NewTCPEntryPoint(ctx context.Context, configuration *static.EntryPoint) (*T
rt := &tcp.Router{} rt := &tcp.Router{}
httpServer, err := createHTTPServer(ctx, listener, configuration, true) reqDecorator := requestdecorator.New(hostResolverConfig)
httpServer, err := createHTTPServer(ctx, listener, configuration, true, reqDecorator)
if err != nil { if err != nil {
return nil, fmt.Errorf("error preparing httpServer: %w", err) return nil, fmt.Errorf("error preparing httpServer: %w", err)
} }
rt.HTTPForwarder(httpServer.Forwarder) rt.HTTPForwarder(httpServer.Forwarder)
httpsServer, err := createHTTPServer(ctx, listener, configuration, false) httpsServer, err := createHTTPServer(ctx, listener, configuration, false, reqDecorator)
if err != nil { if err != nil {
return nil, fmt.Errorf("error preparing httpsServer: %w", err) return nil, fmt.Errorf("error preparing httpsServer: %w", err)
} }
@ -500,16 +505,19 @@ type httpServer struct {
Switcher *middlewares.HTTPHandlerSwitcher Switcher *middlewares.HTTPHandlerSwitcher
} }
func createHTTPServer(ctx context.Context, ln net.Listener, configuration *static.EntryPoint, withH2c bool) (*httpServer, error) { func createHTTPServer(ctx context.Context, ln net.Listener, configuration *static.EntryPoint, withH2c bool, reqDecorator *requestdecorator.RequestDecorator) (*httpServer, error) {
httpSwitcher := middlewares.NewHandlerSwitcher(router.BuildDefaultHTTPRouter()) httpSwitcher := middlewares.NewHandlerSwitcher(router.BuildDefaultHTTPRouter())
next, err := alice.New(requestdecorator.WrapHandler(reqDecorator)).Then(httpSwitcher)
if err != nil {
return nil, err
}
var handler http.Handler var handler http.Handler
var err error
handler, err = forwardedheaders.NewXForwarded( handler, err = forwardedheaders.NewXForwarded(
configuration.ForwardedHeaders.Insecure, configuration.ForwardedHeaders.Insecure,
configuration.ForwardedHeaders.TrustedIPs, configuration.ForwardedHeaders.TrustedIPs,
httpSwitcher) next)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -91,7 +91,7 @@ func TestHTTP3AdvertisedPort(t *testing.T) {
HTTP3: &static.HTTP3Config{ HTTP3: &static.HTTP3Config{
AdvertisedPort: 8080, AdvertisedPort: 8080,
}, },
}) }, nil)
require.NoError(t, err) require.NoError(t, err)
router := &tcp.Router{} router := &tcp.Router{}

View file

@ -79,7 +79,7 @@ func testShutdown(t *testing.T, router *tcp.Router) {
Address: "127.0.0.1:0", Address: "127.0.0.1:0",
Transport: epConfig, Transport: epConfig,
ForwardedHeaders: &static.ForwardedHeaders{}, ForwardedHeaders: &static.ForwardedHeaders{},
}) }, nil)
require.NoError(t, err) require.NoError(t, err)
conn, err := startEntrypoint(entryPoint, router) conn, err := startEntrypoint(entryPoint, router)
@ -162,7 +162,7 @@ func TestReadTimeoutWithoutFirstByte(t *testing.T) {
Address: ":0", Address: ":0",
Transport: epConfig, Transport: epConfig,
ForwardedHeaders: &static.ForwardedHeaders{}, ForwardedHeaders: &static.ForwardedHeaders{},
}) }, nil)
require.NoError(t, err) require.NoError(t, err)
router := &tcp.Router{} router := &tcp.Router{}
@ -198,7 +198,7 @@ func TestReadTimeoutWithFirstByte(t *testing.T) {
Address: ":0", Address: ":0",
Transport: epConfig, Transport: epConfig,
ForwardedHeaders: &static.ForwardedHeaders{}, ForwardedHeaders: &static.ForwardedHeaders{},
}) }, nil)
require.NoError(t, err) require.NoError(t, err)
router := &tcp.Router{} router := &tcp.Router{}