Check all the C/N and SANs of provided certificates before to generat…

This commit is contained in:
NicoMen 2018-02-26 11:38:03 +01:00 committed by Traefiker Bot
parent 700b7a1b51
commit db483e9d34
3 changed files with 137 additions and 36 deletions

View file

@ -222,6 +222,24 @@ func (dc *DomainsCertificates) exists(domainToFind Domain) (*DomainsCertificate,
return nil, false return nil, false
} }
func (dc *DomainsCertificates) toDomainsMap() map[string]*tls.Certificate {
domainsCertificatesMap := make(map[string]*tls.Certificate)
for _, domainCertificate := range dc.Certs {
certKey := domainCertificate.Domains.Main
if domainCertificate.Domains.SANs != nil {
sort.Strings(domainCertificate.Domains.SANs)
for _, dnsName := range domainCertificate.Domains.SANs {
if dnsName != domainCertificate.Domains.Main {
certKey += fmt.Sprintf(",%s", dnsName)
}
}
}
domainsCertificatesMap[certKey] = domainCertificate.tlsCert
}
return domainsCertificatesMap
}
// DomainsCertificate contains a certificate for multiple domains // DomainsCertificate contains a certificate for multiple domains
type DomainsCertificate struct { type DomainsCertificate struct {
Domains Domain Domains Domain

View file

@ -391,7 +391,7 @@ func (a *ACME) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificat
domain := types.CanonicalDomain(clientHello.ServerName) domain := types.CanonicalDomain(clientHello.ServerName)
account := a.store.Get().(*Account) account := a.store.Get().(*Account)
if providedCertificate := a.getProvidedCertificate([]string{domain}); providedCertificate != nil { if providedCertificate := a.getProvidedCertificate(domain); providedCertificate != nil {
return providedCertificate, nil return providedCertificate, nil
} }
@ -625,11 +625,6 @@ func (a *ACME) LoadCertificateForDomains(domains []string) {
domains = fun.Map(types.CanonicalDomain, domains).([]string) domains = fun.Map(types.CanonicalDomain, domains).([]string)
// Check provided certificates
if a.getProvidedCertificate(domains) != nil {
return
}
operation := func() error { operation := func() error {
if a.client == nil { if a.client == nil {
return errors.New("ACME client still not built") return errors.New("ACME client still not built")
@ -647,32 +642,34 @@ func (a *ACME) LoadCertificateForDomains(domains []string) {
return return
} }
account := a.store.Get().(*Account) account := a.store.Get().(*Account)
var domain Domain
if len(domains) > 1 { // Check provided certificates
domain = Domain{Main: domains[0], SANs: domains[1:]} uncheckedDomains := a.getUncheckedDomains(domains, account)
} else { if len(uncheckedDomains) == 0 {
domain = Domain{Main: domains[0]}
}
if _, exists := account.DomainsCertificate.exists(domain); exists {
// domain already exists
return return
} }
certificate, err := a.getDomainsCertificates(domains) certificate, err := a.getDomainsCertificates(uncheckedDomains)
if err != nil { if err != nil {
log.Errorf("Error getting ACME certificates %+v : %v", domains, err) log.Errorf("Error getting ACME certificates %+v : %v", uncheckedDomains, err)
return return
} }
log.Debugf("Got certificate for domains %+v", domains) log.Debugf("Got certificate for domains %+v", uncheckedDomains)
transaction, object, err := a.store.Begin() transaction, object, err := a.store.Begin()
if err != nil { if err != nil {
log.Errorf("Error creating transaction %+v : %v", domains, err) log.Errorf("Error creating transaction %+v : %v", uncheckedDomains, err)
return return
} }
var domain Domain
if len(uncheckedDomains) > 1 {
domain = Domain{Main: uncheckedDomains[0], SANs: uncheckedDomains[1:]}
} else {
domain = Domain{Main: uncheckedDomains[0]}
}
account = object.(*Account) account = object.(*Account)
_, err = account.DomainsCertificate.addCertificateForDomains(certificate, domain) _, err = account.DomainsCertificate.addCertificateForDomains(certificate, domain)
if err != nil { if err != nil {
log.Errorf("Error adding ACME certificates %+v : %v", domains, err) log.Errorf("Error adding ACME certificates %+v : %v", uncheckedDomains, err)
return return
} }
if err = transaction.Commit(account); err != nil { if err = transaction.Commit(account); err != nil {
@ -684,36 +681,95 @@ func (a *ACME) LoadCertificateForDomains(domains []string) {
// Get provided certificate which check a domains list (Main and SANs) // Get provided certificate which check a domains list (Main and SANs)
// from static and dynamic provided certificates // from static and dynamic provided certificates
func (a *ACME) getProvidedCertificate(domains []string) *tls.Certificate { func (a *ACME) getProvidedCertificate(domains string) *tls.Certificate {
log.Debugf("Looking for provided certificate to validate %s...", domains) log.Debugf("Looking for provided certificate to validate %s...", domains)
cert := searchProvidedCertificateForDomains(domains, a.TLSConfig.NameToCertificate) cert := searchProvidedCertificateForDomains(domains, a.TLSConfig.NameToCertificate)
if cert == nil && a.dynamicCerts != nil && a.dynamicCerts.Get() != nil { if cert == nil && a.dynamicCerts != nil && a.dynamicCerts.Get() != nil {
cert = searchProvidedCertificateForDomains(domains, a.dynamicCerts.Get().(*traefikTls.DomainsCertificates).Get().(map[string]*tls.Certificate)) cert = searchProvidedCertificateForDomains(domains, a.dynamicCerts.Get().(*traefikTls.DomainsCertificates).Get().(map[string]*tls.Certificate))
} }
if cert == nil {
log.Debugf("No provided certificate found for domains %s, get ACME certificate.", domains) log.Debugf("No provided certificate found for domains %s, get ACME certificate.", domains)
}
return cert return cert
} }
func searchProvidedCertificateForDomains(domains []string, certs map[string]*tls.Certificate) *tls.Certificate { func searchProvidedCertificateForDomains(domain string, certs map[string]*tls.Certificate) *tls.Certificate {
// Use regex to test for provided certs that might have been added into TLSConfig // Use regex to test for provided certs that might have been added into TLSConfig
providedCertMatch := false for certDomains := range certs {
for k := range certs { domainCheck := false
selector := "^" + strings.Replace(k, "*.", "[^\\.]*\\.?", -1) + "$" for _, certDomain := range strings.Split(certDomains, ",") {
for _, domainToCheck := range domains { selector := "^" + strings.Replace(certDomain, "*.", "[^\\.]*\\.?", -1) + "$"
providedCertMatch, _ = regexp.MatchString(selector, domainToCheck) domainCheck, _ = regexp.MatchString(selector, domain)
if !providedCertMatch { if domainCheck {
break break
} }
} }
if providedCertMatch { if domainCheck {
log.Debugf("Got provided certificate for domains %s", domains) log.Debugf("Domain %q checked by provided certificate %q", domain, certDomains)
return certs[k] return certs[certDomains]
} }
} }
return nil return nil
} }
// Get provided certificate which check a domains list (Main and SANs)
// from static and dynamic provided certificates
func (a *ACME) getUncheckedDomains(domains []string, account *Account) []string {
log.Debugf("Looking for provided certificate to validate %s...", domains)
allCerts := make(map[string]*tls.Certificate)
// Get static certificates
for domains, certificate := range a.TLSConfig.NameToCertificate {
allCerts[domains] = certificate
}
// Get dynamic certificates
if a.dynamicCerts != nil && a.dynamicCerts.Get() != nil {
for domains, certificate := range a.dynamicCerts.Get().(*traefikTls.DomainsCertificates).Get().(map[string]*tls.Certificate) {
allCerts[domains] = certificate
}
}
// Get ACME certificates
if account != nil {
for domains, certificate := range account.DomainsCertificate.toDomainsMap() {
allCerts[domains] = certificate
}
}
return searchUncheckedDomains(domains, allCerts)
}
func searchUncheckedDomains(domains []string, certs map[string]*tls.Certificate) []string {
uncheckedDomains := []string{}
for _, domainToCheck := range domains {
domainCheck := false
for certDomains := range certs {
domainCheck = false
for _, certDomain := range strings.Split(certDomains, ",") {
// Use regex to test for provided certs that might have been added into TLSConfig
selector := "^" + strings.Replace(certDomain, "*.", "[^\\.]*\\.?", -1) + "$"
domainCheck, _ = regexp.MatchString(selector, domainToCheck)
if domainCheck {
break
}
}
if domainCheck {
break
}
}
if !domainCheck {
uncheckedDomains = append(uncheckedDomains, domainToCheck)
}
}
if len(uncheckedDomains) == 0 {
log.Debugf("No ACME certificate to generate for domains %q.", domains)
} else {
log.Debugf("Domains %q need ACME certificates generation for domains %q.", domains, strings.Join(uncheckedDomains, ","))
}
return uncheckedDomains
}
func (a *ACME) getDomainsCertificates(domains []string) (*Certificate, error) { func (a *ACME) getDomainsCertificates(domains []string) (*Certificate, error) {
domains = fun.Map(types.CanonicalDomain, domains).([]string) domains = fun.Map(types.CanonicalDomain, domains).([]string)
log.Debugf("Loading ACME certificates %s...", domains) log.Debugf("Loading ACME certificates %s...", domains)

View file

@ -281,7 +281,7 @@ cijFkALeQp/qyeXdFld2v9gUN3eCgljgcl0QweRoIc=---`)
} }
} }
func TestAcme_getProvidedCertificate(t *testing.T) { func TestAcme_getUncheckedCertificates(t *testing.T) {
mm := make(map[string]*tls.Certificate) mm := make(map[string]*tls.Certificate)
mm["*.containo.us"] = &tls.Certificate{} mm["*.containo.us"] = &tls.Certificate{}
mm["traefik.acme.io"] = &tls.Certificate{} mm["traefik.acme.io"] = &tls.Certificate{}
@ -289,9 +289,36 @@ func TestAcme_getProvidedCertificate(t *testing.T) {
a := ACME{TLSConfig: &tls.Config{NameToCertificate: mm}} a := ACME{TLSConfig: &tls.Config{NameToCertificate: mm}}
domains := []string{"traefik.containo.us", "trae.containo.us"} domains := []string{"traefik.containo.us", "trae.containo.us"}
certificate := a.getProvidedCertificate(domains) uncheckedDomains := a.getUncheckedDomains(domains, nil)
assert.NotNil(t, certificate) assert.Empty(t, uncheckedDomains)
domains = []string{"traefik.acme.io", "trae.acme.io"} domains = []string{"traefik.acme.io", "trae.acme.io"}
certificate = a.getProvidedCertificate(domains) uncheckedDomains = a.getUncheckedDomains(domains, nil)
assert.Len(t, uncheckedDomains, 1)
domainsCertificates := DomainsCertificates{Certs: []*DomainsCertificate{
{
tlsCert: &tls.Certificate{},
Domains: Domain{
Main: "*.acme.wtf",
SANs: []string{"trae.acme.io"},
},
},
}}
account := Account{DomainsCertificate: domainsCertificates}
uncheckedDomains = a.getUncheckedDomains(domains, &account)
assert.Empty(t, uncheckedDomains)
}
func TestAcme_getProvidedCertificate(t *testing.T) {
mm := make(map[string]*tls.Certificate)
mm["*.containo.us"] = &tls.Certificate{}
mm["traefik.acme.io"] = &tls.Certificate{}
a := ACME{TLSConfig: &tls.Config{NameToCertificate: mm}}
domain := "traefik.containo.us"
certificate := a.getProvidedCertificate(domain)
assert.NotNil(t, certificate)
domain = "trae.acme.io"
certificate = a.getProvidedCertificate(domain)
assert.Nil(t, certificate) assert.Nil(t, certificate)
} }