Refactor certificate domains matching func
This commit is contained in:
parent
f07fcd3d54
commit
ff2911d070
1 changed files with 14 additions and 11 deletions
|
@ -73,17 +73,17 @@ func (c *CertificateStore) GetBestCertificate(clientHello *tls.ClientHelloInfo)
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
domainToCheck := strings.ToLower(strings.TrimSpace(clientHello.ServerName))
|
serverName := strings.ToLower(strings.TrimSpace(clientHello.ServerName))
|
||||||
if len(domainToCheck) == 0 {
|
if len(serverName) == 0 {
|
||||||
// If no ServerName is provided, Check for local IP address matches
|
// If no ServerName is provided, Check for local IP address matches
|
||||||
host, _, err := net.SplitHostPort(clientHello.Conn.LocalAddr().String())
|
host, _, err := net.SplitHostPort(clientHello.Conn.LocalAddr().String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("Could not split host/port: %v", err)
|
log.WithoutContext().Debugf("Could not split host/port: %v", err)
|
||||||
}
|
}
|
||||||
domainToCheck = strings.TrimSpace(host)
|
serverName = strings.TrimSpace(host)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cert, ok := c.CertCache.Get(domainToCheck); ok {
|
if cert, ok := c.CertCache.Get(serverName); ok {
|
||||||
return cert.(*tls.Certificate)
|
return cert.(*tls.Certificate)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -91,7 +91,7 @@ func (c *CertificateStore) GetBestCertificate(clientHello *tls.ClientHelloInfo)
|
||||||
if c.DynamicCerts != nil && c.DynamicCerts.Get() != nil {
|
if c.DynamicCerts != nil && c.DynamicCerts.Get() != nil {
|
||||||
for domains, cert := range c.DynamicCerts.Get().(map[string]*tls.Certificate) {
|
for domains, cert := range c.DynamicCerts.Get().(map[string]*tls.Certificate) {
|
||||||
for _, certDomain := range strings.Split(domains, ",") {
|
for _, certDomain := range strings.Split(domains, ",") {
|
||||||
if MatchDomain(domainToCheck, certDomain) {
|
if matchDomain(serverName, certDomain) {
|
||||||
matchedCerts[certDomain] = cert
|
matchedCerts[certDomain] = cert
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -107,7 +107,7 @@ func (c *CertificateStore) GetBestCertificate(clientHello *tls.ClientHelloInfo)
|
||||||
sort.Strings(keys)
|
sort.Strings(keys)
|
||||||
|
|
||||||
// cache best match
|
// cache best match
|
||||||
c.CertCache.SetDefault(domainToCheck, matchedCerts[keys[len(keys)-1]])
|
c.CertCache.SetDefault(serverName, matchedCerts[keys[len(keys)-1]])
|
||||||
return matchedCerts[keys[len(keys)-1]]
|
return matchedCerts[keys[len(keys)-1]]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -121,9 +121,12 @@ func (c CertificateStore) ResetCache() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MatchDomain return true if a domain match the cert domain.
|
// matchDomain returns whether the server name matches the cert domain.
|
||||||
func MatchDomain(domain, certDomain string) bool {
|
// The server name, from TLS SNI, must not have trailing dots (https://datatracker.ietf.org/doc/html/rfc6066#section-3).
|
||||||
if domain == certDomain {
|
// This is enforced by https://github.com/golang/go/blob/d3d7998756c33f69706488cade1cd2b9b10a4c7f/src/crypto/tls/handshake_messages.go#L423-L427.
|
||||||
|
func matchDomain(serverName, certDomain string) bool {
|
||||||
|
// TODO: assert equality after removing the trailing dots?
|
||||||
|
if serverName == certDomain {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -131,7 +134,7 @@ func MatchDomain(domain, certDomain string) bool {
|
||||||
certDomain = certDomain[:len(certDomain)-1]
|
certDomain = certDomain[:len(certDomain)-1]
|
||||||
}
|
}
|
||||||
|
|
||||||
labels := strings.Split(domain, ".")
|
labels := strings.Split(serverName, ".")
|
||||||
for i := range labels {
|
for i := range labels {
|
||||||
labels[i] = "*"
|
labels[i] = "*"
|
||||||
candidate := strings.Join(labels, ".")
|
candidate := strings.Join(labels, ".")
|
||||||
|
|
Loading…
Reference in a new issue