Refactor certificate domains matching func

This commit is contained in:
Romain 2022-07-12 16:16:08 +02:00 committed by GitHub
parent f07fcd3d54
commit ff2911d070
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -73,17 +73,17 @@ func (c *CertificateStore) GetBestCertificate(clientHello *tls.ClientHelloInfo)
if c == nil { if c == nil {
return nil return nil
} }
domainToCheck := strings.ToLower(strings.TrimSpace(clientHello.ServerName)) serverName := strings.ToLower(strings.TrimSpace(clientHello.ServerName))
if len(domainToCheck) == 0 { if len(serverName) == 0 {
// If no ServerName is provided, Check for local IP address matches // If no ServerName is provided, Check for local IP address matches
host, _, err := net.SplitHostPort(clientHello.Conn.LocalAddr().String()) host, _, err := net.SplitHostPort(clientHello.Conn.LocalAddr().String())
if err != nil { if err != nil {
log.Debugf("Could not split host/port: %v", err) log.WithoutContext().Debugf("Could not split host/port: %v", err)
} }
domainToCheck = strings.TrimSpace(host) serverName = strings.TrimSpace(host)
} }
if cert, ok := c.CertCache.Get(domainToCheck); ok { if cert, ok := c.CertCache.Get(serverName); ok {
return cert.(*tls.Certificate) return cert.(*tls.Certificate)
} }
@ -91,7 +91,7 @@ func (c *CertificateStore) GetBestCertificate(clientHello *tls.ClientHelloInfo)
if c.DynamicCerts != nil && c.DynamicCerts.Get() != nil { if c.DynamicCerts != nil && c.DynamicCerts.Get() != nil {
for domains, cert := range c.DynamicCerts.Get().(map[string]*tls.Certificate) { for domains, cert := range c.DynamicCerts.Get().(map[string]*tls.Certificate) {
for _, certDomain := range strings.Split(domains, ",") { for _, certDomain := range strings.Split(domains, ",") {
if MatchDomain(domainToCheck, certDomain) { if matchDomain(serverName, certDomain) {
matchedCerts[certDomain] = cert matchedCerts[certDomain] = cert
} }
} }
@ -107,7 +107,7 @@ func (c *CertificateStore) GetBestCertificate(clientHello *tls.ClientHelloInfo)
sort.Strings(keys) sort.Strings(keys)
// cache best match // cache best match
c.CertCache.SetDefault(domainToCheck, matchedCerts[keys[len(keys)-1]]) c.CertCache.SetDefault(serverName, matchedCerts[keys[len(keys)-1]])
return matchedCerts[keys[len(keys)-1]] return matchedCerts[keys[len(keys)-1]]
} }
@ -121,9 +121,12 @@ func (c CertificateStore) ResetCache() {
} }
} }
// MatchDomain return true if a domain match the cert domain. // matchDomain returns whether the server name matches the cert domain.
func MatchDomain(domain, certDomain string) bool { // The server name, from TLS SNI, must not have trailing dots (https://datatracker.ietf.org/doc/html/rfc6066#section-3).
if domain == certDomain { // This is enforced by https://github.com/golang/go/blob/d3d7998756c33f69706488cade1cd2b9b10a4c7f/src/crypto/tls/handshake_messages.go#L423-L427.
func matchDomain(serverName, certDomain string) bool {
// TODO: assert equality after removing the trailing dots?
if serverName == certDomain {
return true return true
} }
@ -131,7 +134,7 @@ func MatchDomain(domain, certDomain string) bool {
certDomain = certDomain[:len(certDomain)-1] certDomain = certDomain[:len(certDomain)-1]
} }
labels := strings.Split(domain, ".") labels := strings.Split(serverName, ".")
for i := range labels { for i := range labels {
labels[i] = "*" labels[i] = "*"
candidate := strings.Join(labels, ".") candidate := strings.Join(labels, ".")