diff --git a/acme/acme.go b/acme/acme.go index e31bbbdc3..8a06b823d 100644 --- a/acme/acme.go +++ b/acme/acme.go @@ -11,6 +11,7 @@ import ( "net/http" "reflect" "strings" + "sync" "time" "github.com/BurntSushi/ty/fun" @@ -61,6 +62,8 @@ type ACME struct { jobs *channels.InfiniteChannel TLSConfig *tls.Config `description:"TLS config in case wildcard certs are used"` dynamicCerts *safe.Safe + resolvingDomains map[string]struct{} + resolvingDomainsMutex sync.RWMutex } func (a *ACME) init() error { @@ -81,6 +84,10 @@ func (a *ACME) init() error { a.defaultCertificate = cert a.jobs = channels.NewInfiniteChannel() + + // Init the currently resolved domain map + a.resolvingDomains = make(map[string]struct{}) + return nil } @@ -502,6 +509,10 @@ func (a *ACME) LoadCertificateForDomains(domains []string) { if len(uncheckedDomains) == 0 { return } + + a.addResolvingDomains(uncheckedDomains) + defer a.removeResolvingDomains(uncheckedDomains) + certificate, err := a.getDomainsCertificates(uncheckedDomains) if err != nil { log.Errorf("Error getting ACME certificates %+v : %v", uncheckedDomains, err) @@ -533,6 +544,24 @@ func (a *ACME) LoadCertificateForDomains(domains []string) { } } +func (a *ACME) addResolvingDomains(resolvingDomains []string) { + a.resolvingDomainsMutex.Lock() + defer a.resolvingDomainsMutex.Unlock() + + for _, domain := range resolvingDomains { + a.resolvingDomains[domain] = struct{}{} + } +} + +func (a *ACME) removeResolvingDomains(resolvingDomains []string) { + a.resolvingDomainsMutex.Lock() + defer a.resolvingDomainsMutex.Unlock() + + for _, domain := range resolvingDomains { + delete(a.resolvingDomains, domain) + } +} + // Get provided certificate which check a domains list (Main and SANs) // from static and dynamic provided certificates func (a *ACME) getProvidedCertificate(domains string) *tls.Certificate { @@ -568,6 +597,9 @@ func searchProvidedCertificateForDomains(domain string, certs map[string]*tls.Ce // Get provided certificate which check a domains list (Main and SANs) // from static and dynamic provided certificates func (a *ACME) getUncheckedDomains(domains []string, account *Account) []string { + a.resolvingDomainsMutex.RLock() + defer a.resolvingDomainsMutex.RUnlock() + log.Debugf("Looking for provided certificate to validate %s...", domains) allCerts := make(map[string]*tls.Certificate) @@ -590,6 +622,13 @@ func (a *ACME) getUncheckedDomains(domains []string, account *Account) []string } } + // Get currently resolved domains + for domain := range a.resolvingDomains { + if _, ok := allCerts[domain]; !ok { + allCerts[domain] = &tls.Certificate{} + } + } + // Get Configuration Domains for i := 0; i < len(a.Domains); i++ { allCerts[a.Domains[i].Main] = &tls.Certificate{} diff --git a/acme/acme_test.go b/acme/acme_test.go index 9e3d2ace4..aadfa17b6 100644 --- a/acme/acme_test.go +++ b/acme/acme_test.go @@ -331,9 +331,12 @@ func TestAcme_getUncheckedCertificates(t *testing.T) { mm["*.containo.us"] = &tls.Certificate{} mm["traefik.acme.io"] = &tls.Certificate{} - a := ACME{TLSConfig: &tls.Config{NameToCertificate: mm}} + dm := make(map[string]struct{}) + dm["*.traefik.wtf"] = struct{}{} - domains := []string{"traefik.containo.us", "trae.containo.us"} + a := ACME{TLSConfig: &tls.Config{NameToCertificate: mm}, resolvingDomains: dm} + + domains := []string{"traefik.containo.us", "trae.containo.us", "foo.traefik.wtf"} uncheckedDomains := a.getUncheckedDomains(domains, nil) assert.Empty(t, uncheckedDomains) domains = []string{"traefik.acme.io", "trae.acme.io"} @@ -351,6 +354,9 @@ func TestAcme_getUncheckedCertificates(t *testing.T) { account := Account{DomainsCertificate: domainsCertificates} uncheckedDomains = a.getUncheckedDomains(domains, &account) assert.Empty(t, uncheckedDomains) + domains = []string{"traefik.containo.us", "trae.containo.us", "traefik.wtf"} + uncheckedDomains = a.getUncheckedDomains(domains, nil) + assert.Len(t, uncheckedDomains, 1) } func TestAcme_getProvidedCertificate(t *testing.T) { diff --git a/examples/acme/manage_acme_docker_environment.sh b/examples/acme/manage_acme_docker_environment.sh index a95483c9e..58cf73362 100755 --- a/examples/acme/manage_acme_docker_environment.sh +++ b/examples/acme/manage_acme_docker_environment.sh @@ -50,7 +50,7 @@ start_boulder() { # Script usage show_usage() { echo - echo "USAGE : manage_acme_docker_environment.sh [--start|--stop|--restart]" + echo "USAGE : manage_acme_docker_environment.sh [--dev|--start|--stop|--restart]" echo } diff --git a/provider/acme/provider.go b/provider/acme/provider.go index 958a77707..74adf62d4 100644 --- a/provider/acme/provider.go +++ b/provider/acme/provider.go @@ -63,6 +63,8 @@ type Provider struct { clientMutex sync.Mutex configFromListenerChan chan types.Configuration pool *safe.Pool + resolvingDomains map[string]struct{} + resolvingDomainsMutex sync.RWMutex } // Certificate is a struct which contains all data needed from an ACME certificate @@ -127,6 +129,9 @@ func (p *Provider) init() error { return fmt.Errorf("unable to get ACME certificates : %v", err) } + // Init the currently resolved domain map + p.resolvingDomains = make(map[string]struct{}) + p.watchCertificate() p.watchNewDomains() @@ -226,6 +231,9 @@ func (p *Provider) resolveCertificate(domain types.Domain, domainFromConfigurati return nil, nil } + p.addResolvingDomains(uncheckedDomains) + defer p.removeResolvingDomains(uncheckedDomains) + log.Debugf("Loading ACME certificates %+v...", uncheckedDomains) client, err := p.getClient() if err != nil { @@ -254,6 +262,24 @@ func (p *Provider) resolveCertificate(domain types.Domain, domainFromConfigurati return certificate, nil } +func (p *Provider) removeResolvingDomains(resolvingDomains []string) { + p.resolvingDomainsMutex.Lock() + defer p.resolvingDomainsMutex.Unlock() + + for _, domain := range resolvingDomains { + delete(p.resolvingDomains, domain) + } +} + +func (p *Provider) addResolvingDomains(resolvingDomains []string) { + p.resolvingDomainsMutex.Lock() + defer p.resolvingDomainsMutex.Unlock() + + for _, domain := range resolvingDomains { + p.resolvingDomains[domain] = struct{}{} + } +} + func (p *Provider) getClient() (*acme.Client, error) { p.clientMutex.Lock() defer p.clientMutex.Unlock() @@ -503,6 +529,9 @@ func (p *Provider) AddRoutes(router *mux.Router) { // Get provided certificate which check a domains list (Main and SANs) // from static and dynamic provided certificates func (p *Provider) getUncheckedDomains(domainsToCheck []string, checkConfigurationDomains bool) []string { + p.resolvingDomainsMutex.RLock() + defer p.resolvingDomainsMutex.RUnlock() + log.Debugf("Looking for provided certificate(s) to validate %q...", domainsToCheck) var allCerts []string @@ -523,6 +552,11 @@ func (p *Provider) getUncheckedDomains(domainsToCheck []string, checkConfigurati allCerts = append(allCerts, strings.Join(certificate.Domain.ToStrArray(), ",")) } + // Get currently resolved domains + for domain := range p.resolvingDomains { + allCerts = append(allCerts, domain) + } + // Get Configuration Domains if checkConfigurationDomains { for i := 0; i < len(p.Domains); i++ { @@ -540,8 +574,9 @@ func searchUncheckedDomains(domainsToCheck []string, existentDomains []string) [ uncheckedDomains = append(uncheckedDomains, domainToCheck) } } + if len(uncheckedDomains) == 0 { - log.Debugf("No ACME certificate to generate for domains %q.", domainsToCheck) + log.Debugf("No ACME certificate generation required for domains %q.", domainsToCheck) } else { log.Debugf("Domains %q need ACME certificates generation for domains %q.", domainsToCheck, strings.Join(uncheckedDomains, ",")) } diff --git a/provider/acme/provider_test.go b/provider/acme/provider_test.go index 2f6a3db96..5e79d60b8 100644 --- a/provider/acme/provider_test.go +++ b/provider/acme/provider_test.go @@ -26,6 +26,7 @@ func TestGetUncheckedCertificates(t *testing.T) { desc string dynamicCerts *safe.Safe staticCerts map[string]*tls.Certificate + resolvingDomains map[string]struct{} acmeCertificates []*Certificate domains []string expectedDomains []string @@ -138,17 +139,55 @@ func TestGetUncheckedCertificates(t *testing.T) { }, expectedDomains: []string{"traefik.wtf"}, }, + { + desc: "all domains already managed by ACME", + domains: []string{"traefik.wtf", "foo.traefik.wtf"}, + resolvingDomains: map[string]struct{}{ + "traefik.wtf": {}, + "foo.traefik.wtf": {}, + }, + expectedDomains: []string{}, + }, + { + desc: "one domain already managed by ACME", + domains: []string{"traefik.wtf", "foo.traefik.wtf"}, + resolvingDomains: map[string]struct{}{ + "traefik.wtf": {}, + }, + expectedDomains: []string{"foo.traefik.wtf"}, + }, + { + desc: "wildcard domain already managed by ACME checks the domains", + domains: []string{"bar.traefik.wtf", "foo.traefik.wtf"}, + resolvingDomains: map[string]struct{}{ + "*.traefik.wtf": {}, + }, + expectedDomains: []string{}, + }, + { + desc: "wildcard domain already managed by ACME checks domains and another domain checks one other domain, one domain still unchecked", + domains: []string{"traefik.wtf", "bar.traefik.wtf", "foo.traefik.wtf", "acme.wtf"}, + resolvingDomains: map[string]struct{}{ + "*.traefik.wtf": {}, + "traefik.wtf": {}, + }, + expectedDomains: []string{"acme.wtf"}, + }, } for _, test := range testCases { test := test + if test.resolvingDomains == nil { + test.resolvingDomains = make(map[string]struct{}) + } t.Run(test.desc, func(t *testing.T) { t.Parallel() acmeProvider := Provider{ - dynamicCerts: test.dynamicCerts, - staticCerts: test.staticCerts, - certificates: test.acmeCertificates, + dynamicCerts: test.dynamicCerts, + staticCerts: test.staticCerts, + certificates: test.acmeCertificates, + resolvingDomains: test.resolvingDomains, } domains := acmeProvider.getUncheckedDomains(test.domains, false)