Support ALPN for TCP + TLS routers

This commit is contained in:
Dmitry Sharshakov 2022-07-07 17:58:09 +03:00 committed by GitHub
parent aff334ffb4
commit 4dc379c601
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 228 additions and 34 deletions

View file

@ -840,10 +840,11 @@ If the rule is verified, the router becomes active, calls middlewares, and then
The table below lists all the available matchers: The table below lists all the available matchers:
| Rule | Description | | Rule | Description |
|---------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------| |---------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------|
| ```HostSNI(`domain-1`, ...)``` | Check if the Server Name Indication corresponds to the given `domains`. | | ```HostSNI(`domain-1`, ...)``` | Checks if the Server Name Indication corresponds to the given `domains`. |
| ```HostSNIRegexp(`example.com`, `{subdomain:[a-z]+}.example.com`, ...)``` | Check if the Server Name Indication matches the given regular expressions. See "Regexp Syntax" below. | | ```HostSNIRegexp(`example.com`, `{subdomain:[a-z]+}.example.com`, ...)``` | Checks if the Server Name Indication matches the given regular expressions. See "Regexp Syntax" below. |
| ```ClientIP(`10.0.0.0/16`, `::1`)``` | Check if the request client IP is one of the given IP/CIDR. It accepts IPv4, IPv6 and CIDR formats. | | ```ClientIP(`10.0.0.0/16`, `::1`)``` | Checks if the connection client IP is one of the given IP/CIDR. It accepts IPv4, IPv6 and CIDR formats. |
| ```ALPN(`mqtt`, `h2c`)``` | Checks if any of the connection ALPN protocols is one of the given protocols. |
!!! important "Non-ASCII Domain Names" !!! important "Non-ASCII Domain Names"
@ -879,6 +880,13 @@ The table below lists all the available matchers:
The rule is evaluated "before" any middleware has the opportunity to work, and "before" the request is forwarded to the service. The rule is evaluated "before" any middleware has the opportunity to work, and "before" the request is forwarded to the service.
!!! important "ALPN ACME-TLS/1"
It would be a security issue to let a user-defined router catch the response to
an ACME TLS challenge previously initiated by Traefik.
For this reason, the `ALPN` matcher is not allowed to match the `ACME-TLS/1`
protocol, and Traefik returns an error if this is attempted.
### Priority ### Priority
To avoid path overlap, routes are sorted, by default, in descending order using rules length. To avoid path overlap, routes are sorted, by default, in descending order using rules length.

View file

@ -10,6 +10,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/go-acme/lego/v4/challenge/tlsalpn01"
"github.com/traefik/traefik/v2/pkg/ip" "github.com/traefik/traefik/v2/pkg/ip"
"github.com/traefik/traefik/v2/pkg/log" "github.com/traefik/traefik/v2/pkg/log"
"github.com/traefik/traefik/v2/pkg/rules" "github.com/traefik/traefik/v2/pkg/rules"
@ -22,6 +23,7 @@ var tcpFuncs = map[string]func(*matchersTree, ...string) error{
"HostSNI": hostSNI, "HostSNI": hostSNI,
"HostSNIRegexp": hostSNIRegexp, "HostSNIRegexp": hostSNIRegexp,
"ClientIP": clientIP, "ClientIP": clientIP,
"ALPN": alpn,
} }
// ParseHostSNI extracts the HostSNIs declared in a rule. // ParseHostSNI extracts the HostSNIs declared in a rule.
@ -54,10 +56,11 @@ func ParseHostSNI(rule string) ([]string, error) {
type ConnData struct { type ConnData struct {
serverName string serverName string
remoteIP string remoteIP string
alpnProtos []string
} }
// NewConnData builds a connData struct from the given parameters. // NewConnData builds a connData struct from the given parameters.
func NewConnData(serverName string, conn tcp.WriteCloser) (ConnData, error) { func NewConnData(serverName string, conn tcp.WriteCloser, alpnProtos []string) (ConnData, error) {
remoteIP, _, err := net.SplitHostPort(conn.RemoteAddr().String()) remoteIP, _, err := net.SplitHostPort(conn.RemoteAddr().String())
if err != nil { if err != nil {
return ConnData{}, fmt.Errorf("error while parsing remote address %q: %w", conn.RemoteAddr().String(), err) return ConnData{}, fmt.Errorf("error while parsing remote address %q: %w", conn.RemoteAddr().String(), err)
@ -71,6 +74,7 @@ func NewConnData(serverName string, conn tcp.WriteCloser) (ConnData, error) {
return ConnData{ return ConnData{
serverName: types.CanonicalDomain(serverName), serverName: types.CanonicalDomain(serverName),
remoteIP: remoteIP, remoteIP: remoteIP,
alpnProtos: alpnProtos,
}, nil }, nil
} }
@ -284,6 +288,33 @@ func clientIP(tree *matchersTree, clientIPs ...string) error {
return nil return nil
} }
// alpn checks if any of the connection ALPN protocols matches one of the matcher protocols.
func alpn(tree *matchersTree, protos ...string) error {
if len(protos) == 0 {
return errors.New("empty value for \"ALPN\" matcher is not allowed")
}
for _, proto := range protos {
if proto == tlsalpn01.ACMETLS1Protocol {
return fmt.Errorf("invalid protocol value for \"ALPN\" matcher, %q is not allowed", proto)
}
}
tree.matcher = func(meta ConnData) bool {
for _, proto := range meta.alpnProtos {
for _, filter := range protos {
if proto == filter {
return true
}
}
}
return false
}
return nil
}
var almostFQDN = regexp.MustCompile(`^[[:alnum:]\.-]+$`) var almostFQDN = regexp.MustCompile(`^[[:alnum:]\.-]+$`)
// hostSNI checks if the SNI Host of the connection match the matcher host. // hostSNI checks if the SNI Host of the connection match the matcher host.

View file

@ -1,10 +1,12 @@
package tcp package tcp
import ( import (
"fmt"
"net" "net"
"testing" "testing"
"time" "time"
"github.com/go-acme/lego/v4/challenge/tlsalpn01"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/traefik/traefik/v2/pkg/tcp" "github.com/traefik/traefik/v2/pkg/tcp"
@ -58,6 +60,7 @@ func Test_addTCPRoute(t *testing.T) {
rule string rule string
serverName string serverName string
remoteAddr string remoteAddr string
protos []string
routeErr bool routeErr bool
matchErr bool matchErr bool
}{ }{
@ -436,6 +439,66 @@ func Test_addTCPRoute(t *testing.T) {
serverName: "bar", serverName: "bar",
remoteAddr: "10.0.0.1:80", remoteAddr: "10.0.0.1:80",
}, },
{
desc: "Invalid ALPN rule matching ACME-TLS/1",
rule: fmt.Sprintf("ALPN(`%s`)", tlsalpn01.ACMETLS1Protocol),
protos: []string{"foo"},
routeErr: true,
},
{
desc: "Valid ALPN rule matching single protocol",
rule: "ALPN(`foo`)",
protos: []string{"foo"},
},
{
desc: "Valid ALPN rule matching ACME-TLS/1 protocol",
rule: "ALPN(`foo`)",
protos: []string{tlsalpn01.ACMETLS1Protocol},
matchErr: true,
},
{
desc: "Valid ALPN rule not matching single protocol",
rule: "ALPN(`foo`)",
protos: []string{"bar"},
matchErr: true,
},
{
desc: "Valid alternative case ALPN rule matching single protocol without another being supported",
rule: "ALPN(`foo`) && !alpn(`h2`)",
protos: []string{"foo", "bar"},
},
{
desc: "Valid alternative case ALPN rule not matching single protocol because of another being supported",
rule: "ALPN(`foo`) && !alpn(`h2`)",
protos: []string{"foo", "h2", "bar"},
matchErr: true,
},
{
desc: "Valid complex alternative case ALPN and HostSNI rule",
rule: "ALPN(`foo`) && (!alpn(`h2`) || hostsni(`foo`))",
protos: []string{"foo", "bar"},
serverName: "foo",
},
{
desc: "Valid complex alternative case ALPN and HostSNI rule not matching by SNI",
rule: "ALPN(`foo`) && (!alpn(`h2`) || hostsni(`foo`))",
protos: []string{"foo", "bar", "h2"},
serverName: "bar",
matchErr: true,
},
{
desc: "Valid complex alternative case ALPN and HostSNI rule matching by ALPN",
rule: "ALPN(`foo`) && (!alpn(`h2`) || hostsni(`foo`))",
protos: []string{"foo", "bar"},
serverName: "bar",
},
{
desc: "Valid complex alternative case ALPN and HostSNI rule not matching by protos",
rule: "ALPN(`foo`) && (!alpn(`h2`) || hostsni(`foo`))",
protos: []string{"h2", "bar"},
serverName: "bar",
matchErr: true,
},
} }
for _, test := range testCases { for _, test := range testCases {
@ -471,7 +534,7 @@ func Test_addTCPRoute(t *testing.T) {
remoteAddr: fakeAddr{addr: addr}, remoteAddr: fakeAddr{addr: addr},
} }
connData, err := NewConnData(test.serverName, conn) connData, err := NewConnData(test.serverName, conn, test.protos)
require.NoError(t, err) require.NoError(t, err)
matchingHandler, _ := router.Match(connData) matchingHandler, _ := router.Match(connData)
@ -918,6 +981,75 @@ func Test_ClientIP(t *testing.T) {
} }
} }
func Test_ALPN(t *testing.T) {
testCases := []struct {
desc string
ruleALPNProtos []string
connProto string
buildErr bool
matchErr bool
}{
{
desc: "Empty",
buildErr: true,
},
{
desc: "ACME TLS proto",
ruleALPNProtos: []string{tlsalpn01.ACMETLS1Protocol},
buildErr: true,
},
{
desc: "Not matching empty proto",
ruleALPNProtos: []string{"h2"},
matchErr: true,
},
{
desc: "Not matching ALPN",
ruleALPNProtos: []string{"h2"},
connProto: "mqtt",
matchErr: true,
},
{
desc: "Matching ALPN",
ruleALPNProtos: []string{"h2"},
connProto: "h2",
},
{
desc: "Not matching multiple ALPNs",
ruleALPNProtos: []string{"h2", "mqtt"},
connProto: "h2c",
matchErr: true,
},
{
desc: "Matching multiple ALPNs",
ruleALPNProtos: []string{"h2", "h2c", "mqtt"},
connProto: "h2c",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
matchersTree := &matchersTree{}
err := alpn(matchersTree, test.ruleALPNProtos...)
if test.buildErr {
require.Error(t, err)
return
}
require.NoError(t, err)
meta := ConnData{
alpnProtos: []string{test.connProto},
}
assert.Equal(t, test.matchErr, !matchersTree.match(meta))
})
}
}
func Test_Priority(t *testing.T) { func Test_Priority(t *testing.T) {
testCases := []struct { testCases := []struct {
desc string desc string

View file

@ -83,10 +83,10 @@ func (r *Router) ServeTCP(conn tcp.WriteCloser) {
// Handling Non-TLS TCP connection early if there is neither HTTP(S) nor TLS // Handling Non-TLS TCP connection early if there is neither HTTP(S) nor TLS
// routers on the entryPoint, and if there is at least one non-TLS TCP router. // routers on the entryPoint, and if there is at least one non-TLS TCP router.
// In the case of a non-TLS TCP client (that does not "send" first), we would // In the case of a non-TLS TCP client (that does not "send" first), we would
// block forever on clientHelloServerName, which is why we want to detect and // block forever on clientHelloInfo, which is why we want to detect and
// handle that case first and foremost. // handle that case first and foremost.
if r.muxerTCP.HasRoutes() && !r.muxerTCPTLS.HasRoutes() && !r.muxerHTTPS.HasRoutes() { if r.muxerTCP.HasRoutes() && !r.muxerTCPTLS.HasRoutes() && !r.muxerHTTPS.HasRoutes() {
connData, err := tcpmuxer.NewConnData("", conn) connData, err := tcpmuxer.NewConnData("", conn, nil)
if err != nil { if err != nil {
log.WithoutContext().Errorf("Error while reading TCP connection data: %v", err) log.WithoutContext().Errorf("Error while reading TCP connection data: %v", err)
conn.Close() conn.Close()
@ -108,7 +108,7 @@ func (r *Router) ServeTCP(conn tcp.WriteCloser) {
// FIXME -- Check if ProxyProtocol changes the first bytes of the request // FIXME -- Check if ProxyProtocol changes the first bytes of the request
br := bufio.NewReader(conn) br := bufio.NewReader(conn)
serverName, tls, peeked, err := clientHelloServerName(br) hello, err := clientHelloInfo(br)
if err != nil { if err != nil {
conn.Close() conn.Close()
return return
@ -125,20 +125,20 @@ func (r *Router) ServeTCP(conn tcp.WriteCloser) {
log.WithoutContext().Errorf("Error while setting write deadline: %v", err) log.WithoutContext().Errorf("Error while setting write deadline: %v", err)
} }
connData, err := tcpmuxer.NewConnData(serverName, conn) connData, err := tcpmuxer.NewConnData(hello.serverName, conn, hello.protos)
if err != nil { if err != nil {
log.WithoutContext().Errorf("Error while reading TCP connection data: %v", err) log.WithoutContext().Errorf("Error while reading TCP connection data: %v", err)
conn.Close() conn.Close()
return return
} }
if !tls { if !hello.isTLS {
handler, _ := r.muxerTCP.Match(connData) handler, _ := r.muxerTCP.Match(connData)
switch { switch {
case handler != nil: case handler != nil:
handler.ServeTCP(r.GetConn(conn, peeked)) handler.ServeTCP(r.GetConn(conn, hello.peeked))
case r.httpForwarder != nil: case r.httpForwarder != nil:
r.httpForwarder.ServeTCP(r.GetConn(conn, peeked)) r.httpForwarder.ServeTCP(r.GetConn(conn, hello.peeked))
default: default:
conn.Close() conn.Close()
} }
@ -155,14 +155,14 @@ func (r *Router) ServeTCP(conn tcp.WriteCloser) {
// In order not to depart from the behavior in 2.6, we only allow an HTTPS router // In order not to depart from the behavior in 2.6, we only allow an HTTPS router
// to take precedence over a TCP-TLS router if it is _not_ an HostSNI(*) router (so // to take precedence over a TCP-TLS router if it is _not_ an HostSNI(*) router (so
// basically any router that has a specific HostSNI based rule). // basically any router that has a specific HostSNI based rule).
handlerHTTPS.ServeTCP(r.GetConn(conn, peeked)) handlerHTTPS.ServeTCP(r.GetConn(conn, hello.peeked))
return return
} }
// Contains also TCP TLS passthrough routes. // Contains also TCP TLS passthrough routes.
handlerTCPTLS, catchAllTCPTLS := r.muxerTCPTLS.Match(connData) handlerTCPTLS, catchAllTCPTLS := r.muxerTCPTLS.Match(connData)
if handlerTCPTLS != nil && !catchAllTCPTLS { if handlerTCPTLS != nil && !catchAllTCPTLS {
handlerTCPTLS.ServeTCP(r.GetConn(conn, peeked)) handlerTCPTLS.ServeTCP(r.GetConn(conn, hello.peeked))
return return
} }
@ -170,19 +170,19 @@ func (r *Router) ServeTCP(conn tcp.WriteCloser) {
// We end up here for e.g. an HTTPS router that only has a PathPrefix rule, // We end up here for e.g. an HTTPS router that only has a PathPrefix rule,
// which under the scenes is counted as an HostSNI(*) rule. // which under the scenes is counted as an HostSNI(*) rule.
if handlerHTTPS != nil { if handlerHTTPS != nil {
handlerHTTPS.ServeTCP(r.GetConn(conn, peeked)) handlerHTTPS.ServeTCP(r.GetConn(conn, hello.peeked))
return return
} }
// Fallback on TCP TLS catchAll. // Fallback on TCP TLS catchAll.
if handlerTCPTLS != nil { if handlerTCPTLS != nil {
handlerTCPTLS.ServeTCP(r.GetConn(conn, peeked)) handlerTCPTLS.ServeTCP(r.GetConn(conn, hello.peeked))
return return
} }
// needed to handle 404s for HTTPS, as well as all non-Host (e.g. PathPrefix) matches. // needed to handle 404s for HTTPS, as well as all non-Host (e.g. PathPrefix) matches.
if r.httpsForwarder != nil { if r.httpsForwarder != nil {
r.httpsForwarder.ServeTCP(r.GetConn(conn, peeked)) r.httpsForwarder.ServeTCP(r.GetConn(conn, hello.peeked))
return return
} }
@ -300,18 +300,24 @@ func (c *Conn) Read(p []byte) (n int, err error) {
return c.WriteCloser.Read(p) return c.WriteCloser.Read(p)
} }
// clientHelloServerName returns the SNI server name inside the TLS ClientHello, type clientHello struct {
serverName string // SNI server name
protos []string // ALPN protocols list
isTLS bool // whether we are a TLS handshake
peeked string // the bytes peeked from the hello while getting the info
}
// clientHelloInfo returns various data from the clientHello handshake,
// without consuming any bytes from br. // without consuming any bytes from br.
// On any error, the empty string is returned. // It returns an error if it can't peek the first byte from the connection.
func clientHelloServerName(br *bufio.Reader) (string, bool, string, error) { func clientHelloInfo(br *bufio.Reader) (*clientHello, error) {
hdr, err := br.Peek(1) hdr, err := br.Peek(1)
if err != nil { if err != nil {
var opErr *net.OpError var opErr *net.OpError
if !errors.Is(err, io.EOF) && (!errors.As(err, &opErr) || opErr.Timeout()) { if !errors.Is(err, io.EOF) && (!errors.As(err, &opErr) || opErr.Timeout()) {
log.WithoutContext().Errorf("Error while Peeking first byte: %s", err) log.WithoutContext().Errorf("Error while Peeking first byte: %s", err)
} }
return nil, err
return "", false, "", err
} }
// No valid TLS record has a type of 0x80, however SSLv2 handshakes // No valid TLS record has a type of 0x80, however SSLv2 handshakes
@ -323,16 +329,23 @@ func clientHelloServerName(br *bufio.Reader) (string, bool, string, error) {
if hdr[0] != recordTypeHandshake { if hdr[0] != recordTypeHandshake {
if hdr[0] == recordTypeSSLv2 { if hdr[0] == recordTypeSSLv2 {
// we consider SSLv2 as TLS and it will be refused by real TLS handshake. // we consider SSLv2 as TLS and it will be refused by real TLS handshake.
return "", true, getPeeked(br), nil return &clientHello{
isTLS: true,
peeked: getPeeked(br),
}, nil
} }
return "", false, getPeeked(br), nil // Not TLS. return &clientHello{
peeked: getPeeked(br),
}, nil // Not TLS.
} }
const recordHeaderLen = 5 const recordHeaderLen = 5
hdr, err = br.Peek(recordHeaderLen) hdr, err = br.Peek(recordHeaderLen)
if err != nil { if err != nil {
log.Errorf("Error while Peeking hello: %s", err) log.Errorf("Error while Peeking hello: %s", err)
return "", false, getPeeked(br), nil return &clientHello{
peeked: getPeeked(br),
}, nil
} }
recLen := int(hdr[3])<<8 | int(hdr[4]) // ignoring version in hdr[1:3] recLen := int(hdr[3])<<8 | int(hdr[4]) // ignoring version in hdr[1:3]
@ -344,19 +357,29 @@ func clientHelloServerName(br *bufio.Reader) (string, bool, string, error) {
helloBytes, err := br.Peek(recordHeaderLen + recLen) helloBytes, err := br.Peek(recordHeaderLen + recLen)
if err != nil { if err != nil {
log.Errorf("Error while Hello: %s", err) log.Errorf("Error while Hello: %s", err)
return "", true, getPeeked(br), nil return &clientHello{
isTLS: true,
peeked: getPeeked(br),
}, nil
} }
sni := "" sni := ""
server := tls.Server(sniSniffConn{r: bytes.NewReader(helloBytes)}, &tls.Config{ var protos []string
server := tls.Server(helloSniffConn{r: bytes.NewReader(helloBytes)}, &tls.Config{
GetConfigForClient: func(hello *tls.ClientHelloInfo) (*tls.Config, error) { GetConfigForClient: func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
sni = hello.ServerName sni = hello.ServerName
protos = hello.SupportedProtos
return nil, nil return nil, nil
}, },
}) })
_ = server.Handshake() _ = server.Handshake()
return sni, true, getPeeked(br), nil return &clientHello{
serverName: sni,
isTLS: true,
peeked: getPeeked(br),
protos: protos,
}, nil
} }
func getPeeked(br *bufio.Reader) string { func getPeeked(br *bufio.Reader) string {
@ -368,15 +391,15 @@ func getPeeked(br *bufio.Reader) string {
return string(peeked) return string(peeked)
} }
// sniSniffConn is a net.Conn that reads from r, fails on Writes, // helloSniffConn is a net.Conn that reads from r, fails on Writes,
// and crashes otherwise. // and crashes otherwise.
type sniSniffConn struct { type helloSniffConn struct {
r io.Reader r io.Reader
net.Conn // nil; crash on any unexpected use net.Conn // nil; crash on any unexpected use
} }
// Read reads from the underlying reader. // Read reads from the underlying reader.
func (c sniSniffConn) Read(p []byte) (int, error) { return c.r.Read(p) } func (c helloSniffConn) Read(p []byte) (int, error) { return c.r.Read(p) }
// Write crashes all the time. // Write crashes all the time.
func (sniSniffConn) Write(p []byte) (int, error) { return 0, io.EOF } func (helloSniffConn) Write(p []byte) (int, error) { return 0, io.EOF }