From ff2911d070c8d8fc7a5de3d4618f7c2e2540b049 Mon Sep 17 00:00:00 2001 From: Romain Date: Tue, 12 Jul 2022 16:16:08 +0200 Subject: [PATCH] Refactor certificate domains matching func --- pkg/tls/certificate_store.go | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/pkg/tls/certificate_store.go b/pkg/tls/certificate_store.go index 36922f1ce..8017c981a 100644 --- a/pkg/tls/certificate_store.go +++ b/pkg/tls/certificate_store.go @@ -73,17 +73,17 @@ func (c *CertificateStore) GetBestCertificate(clientHello *tls.ClientHelloInfo) if c == nil { return nil } - domainToCheck := strings.ToLower(strings.TrimSpace(clientHello.ServerName)) - if len(domainToCheck) == 0 { + serverName := strings.ToLower(strings.TrimSpace(clientHello.ServerName)) + if len(serverName) == 0 { // If no ServerName is provided, Check for local IP address matches host, _, err := net.SplitHostPort(clientHello.Conn.LocalAddr().String()) if err != nil { - log.Debugf("Could not split host/port: %v", err) + log.WithoutContext().Debugf("Could not split host/port: %v", err) } - domainToCheck = strings.TrimSpace(host) + serverName = strings.TrimSpace(host) } - if cert, ok := c.CertCache.Get(domainToCheck); ok { + if cert, ok := c.CertCache.Get(serverName); ok { return cert.(*tls.Certificate) } @@ -91,7 +91,7 @@ func (c *CertificateStore) GetBestCertificate(clientHello *tls.ClientHelloInfo) if c.DynamicCerts != nil && c.DynamicCerts.Get() != nil { for domains, cert := range c.DynamicCerts.Get().(map[string]*tls.Certificate) { for _, certDomain := range strings.Split(domains, ",") { - if MatchDomain(domainToCheck, certDomain) { + if matchDomain(serverName, certDomain) { matchedCerts[certDomain] = cert } } @@ -107,7 +107,7 @@ func (c *CertificateStore) GetBestCertificate(clientHello *tls.ClientHelloInfo) sort.Strings(keys) // cache best match - c.CertCache.SetDefault(domainToCheck, matchedCerts[keys[len(keys)-1]]) + c.CertCache.SetDefault(serverName, matchedCerts[keys[len(keys)-1]]) return matchedCerts[keys[len(keys)-1]] } @@ -121,9 +121,12 @@ func (c CertificateStore) ResetCache() { } } -// MatchDomain return true if a domain match the cert domain. -func MatchDomain(domain, certDomain string) bool { - if domain == certDomain { +// matchDomain returns whether the server name matches the cert domain. +// The server name, from TLS SNI, must not have trailing dots (https://datatracker.ietf.org/doc/html/rfc6066#section-3). +// This is enforced by https://github.com/golang/go/blob/d3d7998756c33f69706488cade1cd2b9b10a4c7f/src/crypto/tls/handshake_messages.go#L423-L427. +func matchDomain(serverName, certDomain string) bool { + // TODO: assert equality after removing the trailing dots? + if serverName == certDomain { return true } @@ -131,7 +134,7 @@ func MatchDomain(domain, certDomain string) bool { certDomain = certDomain[:len(certDomain)-1] } - labels := strings.Split(domain, ".") + labels := strings.Split(serverName, ".") for i := range labels { labels[i] = "*" candidate := strings.Join(labels, ".")