Avoid duplicated ACME resolution

This commit is contained in:
NicoMen 2018-08-20 09:40:03 +02:00 committed by Traefiker Bot
parent 60b4095c75
commit d81c4e6d1a
5 changed files with 126 additions and 7 deletions

View file

@ -11,6 +11,7 @@ import (
"net/http"
"reflect"
"strings"
"sync"
"time"
"github.com/BurntSushi/ty/fun"
@ -61,6 +62,8 @@ type ACME struct {
jobs *channels.InfiniteChannel
TLSConfig *tls.Config `description:"TLS config in case wildcard certs are used"`
dynamicCerts *safe.Safe
resolvingDomains map[string]struct{}
resolvingDomainsMutex sync.RWMutex
}
func (a *ACME) init() error {
@ -81,6 +84,10 @@ func (a *ACME) init() error {
a.defaultCertificate = cert
a.jobs = channels.NewInfiniteChannel()
// Init the currently resolved domain map
a.resolvingDomains = make(map[string]struct{})
return nil
}
@ -502,6 +509,10 @@ func (a *ACME) LoadCertificateForDomains(domains []string) {
if len(uncheckedDomains) == 0 {
return
}
a.addResolvingDomains(uncheckedDomains)
defer a.removeResolvingDomains(uncheckedDomains)
certificate, err := a.getDomainsCertificates(uncheckedDomains)
if err != nil {
log.Errorf("Error getting ACME certificates %+v : %v", uncheckedDomains, err)
@ -533,6 +544,24 @@ func (a *ACME) LoadCertificateForDomains(domains []string) {
}
}
func (a *ACME) addResolvingDomains(resolvingDomains []string) {
a.resolvingDomainsMutex.Lock()
defer a.resolvingDomainsMutex.Unlock()
for _, domain := range resolvingDomains {
a.resolvingDomains[domain] = struct{}{}
}
}
func (a *ACME) removeResolvingDomains(resolvingDomains []string) {
a.resolvingDomainsMutex.Lock()
defer a.resolvingDomainsMutex.Unlock()
for _, domain := range resolvingDomains {
delete(a.resolvingDomains, domain)
}
}
// Get provided certificate which check a domains list (Main and SANs)
// from static and dynamic provided certificates
func (a *ACME) getProvidedCertificate(domains string) *tls.Certificate {
@ -568,6 +597,9 @@ func searchProvidedCertificateForDomains(domain string, certs map[string]*tls.Ce
// Get provided certificate which check a domains list (Main and SANs)
// from static and dynamic provided certificates
func (a *ACME) getUncheckedDomains(domains []string, account *Account) []string {
a.resolvingDomainsMutex.RLock()
defer a.resolvingDomainsMutex.RUnlock()
log.Debugf("Looking for provided certificate to validate %s...", domains)
allCerts := make(map[string]*tls.Certificate)
@ -590,6 +622,13 @@ func (a *ACME) getUncheckedDomains(domains []string, account *Account) []string
}
}
// Get currently resolved domains
for domain := range a.resolvingDomains {
if _, ok := allCerts[domain]; !ok {
allCerts[domain] = &tls.Certificate{}
}
}
// Get Configuration Domains
for i := 0; i < len(a.Domains); i++ {
allCerts[a.Domains[i].Main] = &tls.Certificate{}

View file

@ -331,9 +331,12 @@ func TestAcme_getUncheckedCertificates(t *testing.T) {
mm["*.containo.us"] = &tls.Certificate{}
mm["traefik.acme.io"] = &tls.Certificate{}
a := ACME{TLSConfig: &tls.Config{NameToCertificate: mm}}
dm := make(map[string]struct{})
dm["*.traefik.wtf"] = struct{}{}
domains := []string{"traefik.containo.us", "trae.containo.us"}
a := ACME{TLSConfig: &tls.Config{NameToCertificate: mm}, resolvingDomains: dm}
domains := []string{"traefik.containo.us", "trae.containo.us", "foo.traefik.wtf"}
uncheckedDomains := a.getUncheckedDomains(domains, nil)
assert.Empty(t, uncheckedDomains)
domains = []string{"traefik.acme.io", "trae.acme.io"}
@ -351,6 +354,9 @@ func TestAcme_getUncheckedCertificates(t *testing.T) {
account := Account{DomainsCertificate: domainsCertificates}
uncheckedDomains = a.getUncheckedDomains(domains, &account)
assert.Empty(t, uncheckedDomains)
domains = []string{"traefik.containo.us", "trae.containo.us", "traefik.wtf"}
uncheckedDomains = a.getUncheckedDomains(domains, nil)
assert.Len(t, uncheckedDomains, 1)
}
func TestAcme_getProvidedCertificate(t *testing.T) {

View file

@ -50,7 +50,7 @@ start_boulder() {
# Script usage
show_usage() {
echo
echo "USAGE : manage_acme_docker_environment.sh [--start|--stop|--restart]"
echo "USAGE : manage_acme_docker_environment.sh [--dev|--start|--stop|--restart]"
echo
}

View file

@ -63,6 +63,8 @@ type Provider struct {
clientMutex sync.Mutex
configFromListenerChan chan types.Configuration
pool *safe.Pool
resolvingDomains map[string]struct{}
resolvingDomainsMutex sync.RWMutex
}
// Certificate is a struct which contains all data needed from an ACME certificate
@ -127,6 +129,9 @@ func (p *Provider) init() error {
return fmt.Errorf("unable to get ACME certificates : %v", err)
}
// Init the currently resolved domain map
p.resolvingDomains = make(map[string]struct{})
p.watchCertificate()
p.watchNewDomains()
@ -226,6 +231,9 @@ func (p *Provider) resolveCertificate(domain types.Domain, domainFromConfigurati
return nil, nil
}
p.addResolvingDomains(uncheckedDomains)
defer p.removeResolvingDomains(uncheckedDomains)
log.Debugf("Loading ACME certificates %+v...", uncheckedDomains)
client, err := p.getClient()
if err != nil {
@ -254,6 +262,24 @@ func (p *Provider) resolveCertificate(domain types.Domain, domainFromConfigurati
return certificate, nil
}
func (p *Provider) removeResolvingDomains(resolvingDomains []string) {
p.resolvingDomainsMutex.Lock()
defer p.resolvingDomainsMutex.Unlock()
for _, domain := range resolvingDomains {
delete(p.resolvingDomains, domain)
}
}
func (p *Provider) addResolvingDomains(resolvingDomains []string) {
p.resolvingDomainsMutex.Lock()
defer p.resolvingDomainsMutex.Unlock()
for _, domain := range resolvingDomains {
p.resolvingDomains[domain] = struct{}{}
}
}
func (p *Provider) getClient() (*acme.Client, error) {
p.clientMutex.Lock()
defer p.clientMutex.Unlock()
@ -503,6 +529,9 @@ func (p *Provider) AddRoutes(router *mux.Router) {
// Get provided certificate which check a domains list (Main and SANs)
// from static and dynamic provided certificates
func (p *Provider) getUncheckedDomains(domainsToCheck []string, checkConfigurationDomains bool) []string {
p.resolvingDomainsMutex.RLock()
defer p.resolvingDomainsMutex.RUnlock()
log.Debugf("Looking for provided certificate(s) to validate %q...", domainsToCheck)
var allCerts []string
@ -523,6 +552,11 @@ func (p *Provider) getUncheckedDomains(domainsToCheck []string, checkConfigurati
allCerts = append(allCerts, strings.Join(certificate.Domain.ToStrArray(), ","))
}
// Get currently resolved domains
for domain := range p.resolvingDomains {
allCerts = append(allCerts, domain)
}
// Get Configuration Domains
if checkConfigurationDomains {
for i := 0; i < len(p.Domains); i++ {
@ -540,8 +574,9 @@ func searchUncheckedDomains(domainsToCheck []string, existentDomains []string) [
uncheckedDomains = append(uncheckedDomains, domainToCheck)
}
}
if len(uncheckedDomains) == 0 {
log.Debugf("No ACME certificate to generate for domains %q.", domainsToCheck)
log.Debugf("No ACME certificate generation required for domains %q.", domainsToCheck)
} else {
log.Debugf("Domains %q need ACME certificates generation for domains %q.", domainsToCheck, strings.Join(uncheckedDomains, ","))
}

View file

@ -26,6 +26,7 @@ func TestGetUncheckedCertificates(t *testing.T) {
desc string
dynamicCerts *safe.Safe
staticCerts map[string]*tls.Certificate
resolvingDomains map[string]struct{}
acmeCertificates []*Certificate
domains []string
expectedDomains []string
@ -138,17 +139,55 @@ func TestGetUncheckedCertificates(t *testing.T) {
},
expectedDomains: []string{"traefik.wtf"},
},
{
desc: "all domains already managed by ACME",
domains: []string{"traefik.wtf", "foo.traefik.wtf"},
resolvingDomains: map[string]struct{}{
"traefik.wtf": {},
"foo.traefik.wtf": {},
},
expectedDomains: []string{},
},
{
desc: "one domain already managed by ACME",
domains: []string{"traefik.wtf", "foo.traefik.wtf"},
resolvingDomains: map[string]struct{}{
"traefik.wtf": {},
},
expectedDomains: []string{"foo.traefik.wtf"},
},
{
desc: "wildcard domain already managed by ACME checks the domains",
domains: []string{"bar.traefik.wtf", "foo.traefik.wtf"},
resolvingDomains: map[string]struct{}{
"*.traefik.wtf": {},
},
expectedDomains: []string{},
},
{
desc: "wildcard domain already managed by ACME checks domains and another domain checks one other domain, one domain still unchecked",
domains: []string{"traefik.wtf", "bar.traefik.wtf", "foo.traefik.wtf", "acme.wtf"},
resolvingDomains: map[string]struct{}{
"*.traefik.wtf": {},
"traefik.wtf": {},
},
expectedDomains: []string{"acme.wtf"},
},
}
for _, test := range testCases {
test := test
if test.resolvingDomains == nil {
test.resolvingDomains = make(map[string]struct{})
}
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
acmeProvider := Provider{
dynamicCerts: test.dynamicCerts,
staticCerts: test.staticCerts,
certificates: test.acmeCertificates,
dynamicCerts: test.dynamicCerts,
staticCerts: test.staticCerts,
certificates: test.acmeCertificates,
resolvingDomains: test.resolvingDomains,
}
domains := acmeProvider.getUncheckedDomains(test.domains, false)