108 lines
2.5 KiB
Go
108 lines
2.5 KiB
Go
|
package snicheck
|
||
|
|
||
|
import (
|
||
|
"net"
|
||
|
"net/http"
|
||
|
"strings"
|
||
|
|
||
|
"github.com/traefik/traefik/v2/pkg/log"
|
||
|
"github.com/traefik/traefik/v2/pkg/middlewares/requestdecorator"
|
||
|
traefiktls "github.com/traefik/traefik/v2/pkg/tls"
|
||
|
)
|
||
|
|
||
|
// SNICheck is an HTTP handler that checks whether the TLS configuration for the server name is the same as for the host header.
|
||
|
type SNICheck struct {
|
||
|
next http.Handler
|
||
|
tlsOptionsForHost map[string]string
|
||
|
}
|
||
|
|
||
|
// New creates a new SNICheck.
|
||
|
func New(tlsOptionsForHost map[string]string, next http.Handler) *SNICheck {
|
||
|
return &SNICheck{next: next, tlsOptionsForHost: tlsOptionsForHost}
|
||
|
}
|
||
|
|
||
|
func (s SNICheck) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||
|
if req.TLS == nil {
|
||
|
s.next.ServeHTTP(rw, req)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
host := getHost(req)
|
||
|
serverName := strings.TrimSpace(req.TLS.ServerName)
|
||
|
|
||
|
// Domain Fronting
|
||
|
if !strings.EqualFold(host, serverName) {
|
||
|
tlsOptionHeader := findTLSOptionName(s.tlsOptionsForHost, host, true)
|
||
|
tlsOptionSNI := findTLSOptionName(s.tlsOptionsForHost, serverName, false)
|
||
|
|
||
|
if tlsOptionHeader != tlsOptionSNI {
|
||
|
log.WithoutContext().
|
||
|
WithField("host", host).
|
||
|
WithField("req.Host", req.Host).
|
||
|
WithField("req.TLS.ServerName", req.TLS.ServerName).
|
||
|
Debugf("TLS options difference: SNI:%s, Header:%s", tlsOptionSNI, tlsOptionHeader)
|
||
|
http.Error(rw, http.StatusText(http.StatusMisdirectedRequest), http.StatusMisdirectedRequest)
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
|
||
|
s.next.ServeHTTP(rw, req)
|
||
|
}
|
||
|
|
||
|
func getHost(req *http.Request) string {
|
||
|
h := requestdecorator.GetCNAMEFlatten(req.Context())
|
||
|
if h != "" {
|
||
|
return h
|
||
|
}
|
||
|
|
||
|
h = requestdecorator.GetCanonizedHost(req.Context())
|
||
|
if h != "" {
|
||
|
return h
|
||
|
}
|
||
|
|
||
|
host, _, err := net.SplitHostPort(req.Host)
|
||
|
if err != nil {
|
||
|
host = req.Host
|
||
|
}
|
||
|
|
||
|
return strings.TrimSpace(host)
|
||
|
}
|
||
|
|
||
|
func findTLSOptionName(tlsOptionsForHost map[string]string, host string, fqdn bool) string {
|
||
|
name := findTLSOptName(tlsOptionsForHost, host, fqdn)
|
||
|
if name != "" {
|
||
|
return name
|
||
|
}
|
||
|
|
||
|
name = findTLSOptName(tlsOptionsForHost, strings.ToLower(host), fqdn)
|
||
|
if name != "" {
|
||
|
return name
|
||
|
}
|
||
|
|
||
|
return traefiktls.DefaultTLSConfigName
|
||
|
}
|
||
|
|
||
|
// findTLSOptName returns the TLS option name stored for host in
// tlsOptionsForHost, or "" when there is none.
// When the exact lookup misses and fqdn is true, it retries with the other
// spelling of the name: one trailing dot removed if host ends with '.',
// or a trailing dot appended otherwise.
func findTLSOptName(tlsOptionsForHost map[string]string, host string, fqdn bool) string {
	if opts, found := tlsOptionsForHost[host]; found {
		return opts
	}

	if !fqdn {
		return ""
	}

	alternate := host + "."
	if strings.HasSuffix(host, ".") {
		alternate = strings.TrimSuffix(host, ".")
	}

	// A missing key yields the zero value "", which is the "not found" result.
	return tlsOptionsForHost[alternate]
}
|