diff --git a/acme/account.go b/acme/account.go index 2e4526d75..cced2c0a4 100644 --- a/acme/account.go +++ b/acme/account.go @@ -222,6 +222,24 @@ func (dc *DomainsCertificates) exists(domainToFind Domain) (*DomainsCertificate, return nil, false } +func (dc *DomainsCertificates) toDomainsMap() map[string]*tls.Certificate { + domainsCertificatesMap := make(map[string]*tls.Certificate) + for _, domainCertificate := range dc.Certs { + certKey := domainCertificate.Domains.Main + if domainCertificate.Domains.SANs != nil { + sort.Strings(domainCertificate.Domains.SANs) + for _, dnsName := range domainCertificate.Domains.SANs { + if dnsName != domainCertificate.Domains.Main { + certKey += fmt.Sprintf(",%s", dnsName) + } + } + + } + domainsCertificatesMap[certKey] = domainCertificate.tlsCert + } + return domainsCertificatesMap +} + // DomainsCertificate contains a certificate for multiple domains type DomainsCertificate struct { Domains Domain diff --git a/acme/acme.go b/acme/acme.go index 6dacdb90c..bffe03542 100644 --- a/acme/acme.go +++ b/acme/acme.go @@ -391,7 +391,7 @@ func (a *ACME) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificat domain := types.CanonicalDomain(clientHello.ServerName) account := a.store.Get().(*Account) - if providedCertificate := a.getProvidedCertificate([]string{domain}); providedCertificate != nil { + if providedCertificate := a.getProvidedCertificate(domain); providedCertificate != nil { return providedCertificate, nil } @@ -625,11 +625,6 @@ func (a *ACME) LoadCertificateForDomains(domains []string) { domains = fun.Map(types.CanonicalDomain, domains).([]string) - // Check provided certificates - if a.getProvidedCertificate(domains) != nil { - return - } - operation := func() error { if a.client == nil { return errors.New("ACME client still not built") @@ -647,32 +642,34 @@ func (a *ACME) LoadCertificateForDomains(domains []string) { return } account := a.store.Get().(*Account) - var domain Domain - if len(domains) > 1 { - domain = Domain{Main: domains[0], SANs: domains[1:]} - } else { - domain = Domain{Main: domains[0]} - } - if _, exists := account.DomainsCertificate.exists(domain); exists { - // domain already exists + + // Check provided certificates + uncheckedDomains := a.getUncheckedDomains(domains, account) + if len(uncheckedDomains) == 0 { return } - certificate, err := a.getDomainsCertificates(domains) + certificate, err := a.getDomainsCertificates(uncheckedDomains) if err != nil { - log.Errorf("Error getting ACME certificates %+v : %v", domains, err) + log.Errorf("Error getting ACME certificates %+v : %v", uncheckedDomains, err) return } - log.Debugf("Got certificate for domains %+v", domains) + log.Debugf("Got certificate for domains %+v", uncheckedDomains) transaction, object, err := a.store.Begin() if err != nil { - log.Errorf("Error creating transaction %+v : %v", domains, err) + log.Errorf("Error creating transaction %+v : %v", uncheckedDomains, err) return } + var domain Domain + if len(uncheckedDomains) > 1 { + domain = Domain{Main: uncheckedDomains[0], SANs: uncheckedDomains[1:]} + } else { + domain = Domain{Main: uncheckedDomains[0]} + } account = object.(*Account) _, err = account.DomainsCertificate.addCertificateForDomains(certificate, domain) if err != nil { - log.Errorf("Error adding ACME certificates %+v : %v", domains, err) + log.Errorf("Error adding ACME certificates %+v : %v", uncheckedDomains, err) return } if err = transaction.Commit(account); err != nil { @@ -684,36 +681,95 @@ func (a *ACME) LoadCertificateForDomains(domains []string) { // Get provided certificate which check a domains list (Main and SANs) // from static and dynamic provided certificates -func (a *ACME) getProvidedCertificate(domains []string) *tls.Certificate { +func (a *ACME) getProvidedCertificate(domains string) *tls.Certificate { log.Debugf("Looking for provided certificate to validate %s...", domains) cert := searchProvidedCertificateForDomains(domains, a.TLSConfig.NameToCertificate) if cert == nil && a.dynamicCerts != nil && a.dynamicCerts.Get() != nil { cert = searchProvidedCertificateForDomains(domains, a.dynamicCerts.Get().(*traefikTls.DomainsCertificates).Get().(map[string]*tls.Certificate)) } - log.Debugf("No provided certificate found for domains %s, get ACME certificate.", domains) + if cert == nil { + log.Debugf("No provided certificate found for domains %s, get ACME certificate.", domains) + } return cert } -func searchProvidedCertificateForDomains(domains []string, certs map[string]*tls.Certificate) *tls.Certificate { +func searchProvidedCertificateForDomains(domain string, certs map[string]*tls.Certificate) *tls.Certificate { // Use regex to test for provided certs that might have been added into TLSConfig - providedCertMatch := false - for k := range certs { - selector := "^" + strings.Replace(k, "*.", "[^\\.]*\\.?", -1) + "$" - for _, domainToCheck := range domains { - providedCertMatch, _ = regexp.MatchString(selector, domainToCheck) - if !providedCertMatch { + for certDomains := range certs { + domainCheck := false + for _, certDomain := range strings.Split(certDomains, ",") { + selector := "^" + strings.Replace(certDomain, "*.", "[^\\.]*\\.?", -1) + "$" + domainCheck, _ = regexp.MatchString(selector, domain) + if domainCheck { break } } - if providedCertMatch { - log.Debugf("Got provided certificate for domains %s", domains) - return certs[k] - + if domainCheck { + log.Debugf("Domain %q checked by provided certificate %q", domain, certDomains) + return certs[certDomains] } } return nil } +// Get provided certificate which check a domains list (Main and SANs) +// from static and dynamic provided certificates +func (a *ACME) getUncheckedDomains(domains []string, account *Account) []string { + log.Debugf("Looking for provided certificate to validate %s...", domains) + allCerts := make(map[string]*tls.Certificate) + + // Get static certificates + for domains, certificate := range a.TLSConfig.NameToCertificate { + allCerts[domains] = certificate + } + + // Get dynamic certificates + if a.dynamicCerts != nil && a.dynamicCerts.Get() != nil { + for domains, certificate := range a.dynamicCerts.Get().(*traefikTls.DomainsCertificates).Get().(map[string]*tls.Certificate) { + allCerts[domains] = certificate + } + } + + // Get ACME certificates + if account != nil { + for domains, certificate := range account.DomainsCertificate.toDomainsMap() { + allCerts[domains] = certificate + } + } + + return searchUncheckedDomains(domains, allCerts) +} + +func searchUncheckedDomains(domains []string, certs map[string]*tls.Certificate) []string { + uncheckedDomains := []string{} + for _, domainToCheck := range domains { + domainCheck := false + for certDomains := range certs { + domainCheck = false + for _, certDomain := range strings.Split(certDomains, ",") { + // Use regex to test for provided certs that might have been added into TLSConfig + selector := "^" + strings.Replace(certDomain, "*.", "[^\\.]*\\.?", -1) + "$" + domainCheck, _ = regexp.MatchString(selector, domainToCheck) + if domainCheck { + break + } + } + if domainCheck { + break + } + } + if !domainCheck { + uncheckedDomains = append(uncheckedDomains, domainToCheck) + } + } + if len(uncheckedDomains) == 0 { + log.Debugf("No ACME certificate to generate for domains %q.", domains) + } else { + log.Debugf("Domains %q need ACME certificates generation for domains %q.", domains, strings.Join(uncheckedDomains, ",")) + } + return uncheckedDomains +} + func (a *ACME) getDomainsCertificates(domains []string) (*Certificate, error) { domains = fun.Map(types.CanonicalDomain, domains).([]string) log.Debugf("Loading ACME certificates %s...", domains) diff --git a/acme/acme_test.go b/acme/acme_test.go index 3498e7b9f..db828e708 100644 --- a/acme/acme_test.go +++ b/acme/acme_test.go @@ -281,7 +281,7 @@ cijFkALeQp/qyeXdFld2v9gUN3eCgljgcl0QweRoIc=---`) } } -func TestAcme_getProvidedCertificate(t *testing.T) { +func TestAcme_getUncheckedCertificates(t *testing.T) { mm := make(map[string]*tls.Certificate) mm["*.containo.us"] = &tls.Certificate{} mm["traefik.acme.io"] = &tls.Certificate{} @@ -289,9 +289,36 @@ func TestAcme_getProvidedCertificate(t *testing.T) { a := ACME{TLSConfig: &tls.Config{NameToCertificate: mm}} domains := []string{"traefik.containo.us", "trae.containo.us"} - certificate := a.getProvidedCertificate(domains) - assert.NotNil(t, certificate) + uncheckedDomains := a.getUncheckedDomains(domains, nil) + assert.Empty(t, uncheckedDomains) domains = []string{"traefik.acme.io", "trae.acme.io"} - certificate = a.getProvidedCertificate(domains) + uncheckedDomains = a.getUncheckedDomains(domains, nil) + assert.Len(t, uncheckedDomains, 1) + domainsCertificates := DomainsCertificates{Certs: []*DomainsCertificate{ + { + tlsCert: &tls.Certificate{}, + Domains: Domain{ + Main: "*.acme.wtf", + SANs: []string{"trae.acme.io"}, + }, + }, + }} + account := Account{DomainsCertificate: domainsCertificates} + uncheckedDomains = a.getUncheckedDomains(domains, &account) + assert.Empty(t, uncheckedDomains) +} + +func TestAcme_getProvidedCertificate(t *testing.T) { + mm := make(map[string]*tls.Certificate) + mm["*.containo.us"] = &tls.Certificate{} + mm["traefik.acme.io"] = &tls.Certificate{} + + a := ACME{TLSConfig: &tls.Config{NameToCertificate: mm}} + + domain := "traefik.containo.us" + certificate := a.getProvidedCertificate(domain) + assert.NotNil(t, certificate) + domain = "trae.acme.io" + certificate = a.getProvidedCertificate(domain) assert.Nil(t, certificate) }