Avoid duplicated ACME resolution
This commit is contained in:
parent
60b4095c75
commit
d81c4e6d1a
5 changed files with 126 additions and 7 deletions
39
acme/acme.go
39
acme/acme.go
|
@ -11,6 +11,7 @@ import (
|
|||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/BurntSushi/ty/fun"
|
||||
|
@ -61,6 +62,8 @@ type ACME struct {
|
|||
jobs *channels.InfiniteChannel
|
||||
TLSConfig *tls.Config `description:"TLS config in case wildcard certs are used"`
|
||||
dynamicCerts *safe.Safe
|
||||
resolvingDomains map[string]struct{}
|
||||
resolvingDomainsMutex sync.RWMutex
|
||||
}
|
||||
|
||||
func (a *ACME) init() error {
|
||||
|
@ -81,6 +84,10 @@ func (a *ACME) init() error {
|
|||
a.defaultCertificate = cert
|
||||
|
||||
a.jobs = channels.NewInfiniteChannel()
|
||||
|
||||
// Init the currently resolved domain map
|
||||
a.resolvingDomains = make(map[string]struct{})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -502,6 +509,10 @@ func (a *ACME) LoadCertificateForDomains(domains []string) {
|
|||
if len(uncheckedDomains) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
a.addResolvingDomains(uncheckedDomains)
|
||||
defer a.removeResolvingDomains(uncheckedDomains)
|
||||
|
||||
certificate, err := a.getDomainsCertificates(uncheckedDomains)
|
||||
if err != nil {
|
||||
log.Errorf("Error getting ACME certificates %+v : %v", uncheckedDomains, err)
|
||||
|
@ -533,6 +544,24 @@ func (a *ACME) LoadCertificateForDomains(domains []string) {
|
|||
}
|
||||
}
|
||||
|
||||
func (a *ACME) addResolvingDomains(resolvingDomains []string) {
|
||||
a.resolvingDomainsMutex.Lock()
|
||||
defer a.resolvingDomainsMutex.Unlock()
|
||||
|
||||
for _, domain := range resolvingDomains {
|
||||
a.resolvingDomains[domain] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *ACME) removeResolvingDomains(resolvingDomains []string) {
|
||||
a.resolvingDomainsMutex.Lock()
|
||||
defer a.resolvingDomainsMutex.Unlock()
|
||||
|
||||
for _, domain := range resolvingDomains {
|
||||
delete(a.resolvingDomains, domain)
|
||||
}
|
||||
}
|
||||
|
||||
// Get provided certificate which check a domains list (Main and SANs)
|
||||
// from static and dynamic provided certificates
|
||||
func (a *ACME) getProvidedCertificate(domains string) *tls.Certificate {
|
||||
|
@ -568,6 +597,9 @@ func searchProvidedCertificateForDomains(domain string, certs map[string]*tls.Ce
|
|||
// Get provided certificate which check a domains list (Main and SANs)
|
||||
// from static and dynamic provided certificates
|
||||
func (a *ACME) getUncheckedDomains(domains []string, account *Account) []string {
|
||||
a.resolvingDomainsMutex.RLock()
|
||||
defer a.resolvingDomainsMutex.RUnlock()
|
||||
|
||||
log.Debugf("Looking for provided certificate to validate %s...", domains)
|
||||
allCerts := make(map[string]*tls.Certificate)
|
||||
|
||||
|
@ -590,6 +622,13 @@ func (a *ACME) getUncheckedDomains(domains []string, account *Account) []string
|
|||
}
|
||||
}
|
||||
|
||||
// Get currently resolved domains
|
||||
for domain := range a.resolvingDomains {
|
||||
if _, ok := allCerts[domain]; !ok {
|
||||
allCerts[domain] = &tls.Certificate{}
|
||||
}
|
||||
}
|
||||
|
||||
// Get Configuration Domains
|
||||
for i := 0; i < len(a.Domains); i++ {
|
||||
allCerts[a.Domains[i].Main] = &tls.Certificate{}
|
||||
|
|
|
@ -331,9 +331,12 @@ func TestAcme_getUncheckedCertificates(t *testing.T) {
|
|||
mm["*.containo.us"] = &tls.Certificate{}
|
||||
mm["traefik.acme.io"] = &tls.Certificate{}
|
||||
|
||||
a := ACME{TLSConfig: &tls.Config{NameToCertificate: mm}}
|
||||
dm := make(map[string]struct{})
|
||||
dm["*.traefik.wtf"] = struct{}{}
|
||||
|
||||
domains := []string{"traefik.containo.us", "trae.containo.us"}
|
||||
a := ACME{TLSConfig: &tls.Config{NameToCertificate: mm}, resolvingDomains: dm}
|
||||
|
||||
domains := []string{"traefik.containo.us", "trae.containo.us", "foo.traefik.wtf"}
|
||||
uncheckedDomains := a.getUncheckedDomains(domains, nil)
|
||||
assert.Empty(t, uncheckedDomains)
|
||||
domains = []string{"traefik.acme.io", "trae.acme.io"}
|
||||
|
@ -351,6 +354,9 @@ func TestAcme_getUncheckedCertificates(t *testing.T) {
|
|||
account := Account{DomainsCertificate: domainsCertificates}
|
||||
uncheckedDomains = a.getUncheckedDomains(domains, &account)
|
||||
assert.Empty(t, uncheckedDomains)
|
||||
domains = []string{"traefik.containo.us", "trae.containo.us", "traefik.wtf"}
|
||||
uncheckedDomains = a.getUncheckedDomains(domains, nil)
|
||||
assert.Len(t, uncheckedDomains, 1)
|
||||
}
|
||||
|
||||
func TestAcme_getProvidedCertificate(t *testing.T) {
|
||||
|
|
|
@ -50,7 +50,7 @@ start_boulder() {
|
|||
# Script usage
|
||||
show_usage() {
|
||||
echo
|
||||
echo "USAGE : manage_acme_docker_environment.sh [--start|--stop|--restart]"
|
||||
echo "USAGE : manage_acme_docker_environment.sh [--dev|--start|--stop|--restart]"
|
||||
echo
|
||||
}
|
||||
|
||||
|
|
|
@ -63,6 +63,8 @@ type Provider struct {
|
|||
clientMutex sync.Mutex
|
||||
configFromListenerChan chan types.Configuration
|
||||
pool *safe.Pool
|
||||
resolvingDomains map[string]struct{}
|
||||
resolvingDomainsMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// Certificate is a struct which contains all data needed from an ACME certificate
|
||||
|
@ -127,6 +129,9 @@ func (p *Provider) init() error {
|
|||
return fmt.Errorf("unable to get ACME certificates : %v", err)
|
||||
}
|
||||
|
||||
// Init the currently resolved domain map
|
||||
p.resolvingDomains = make(map[string]struct{})
|
||||
|
||||
p.watchCertificate()
|
||||
p.watchNewDomains()
|
||||
|
||||
|
@ -226,6 +231,9 @@ func (p *Provider) resolveCertificate(domain types.Domain, domainFromConfigurati
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
p.addResolvingDomains(uncheckedDomains)
|
||||
defer p.removeResolvingDomains(uncheckedDomains)
|
||||
|
||||
log.Debugf("Loading ACME certificates %+v...", uncheckedDomains)
|
||||
client, err := p.getClient()
|
||||
if err != nil {
|
||||
|
@ -254,6 +262,24 @@ func (p *Provider) resolveCertificate(domain types.Domain, domainFromConfigurati
|
|||
return certificate, nil
|
||||
}
|
||||
|
||||
func (p *Provider) removeResolvingDomains(resolvingDomains []string) {
|
||||
p.resolvingDomainsMutex.Lock()
|
||||
defer p.resolvingDomainsMutex.Unlock()
|
||||
|
||||
for _, domain := range resolvingDomains {
|
||||
delete(p.resolvingDomains, domain)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Provider) addResolvingDomains(resolvingDomains []string) {
|
||||
p.resolvingDomainsMutex.Lock()
|
||||
defer p.resolvingDomainsMutex.Unlock()
|
||||
|
||||
for _, domain := range resolvingDomains {
|
||||
p.resolvingDomains[domain] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Provider) getClient() (*acme.Client, error) {
|
||||
p.clientMutex.Lock()
|
||||
defer p.clientMutex.Unlock()
|
||||
|
@ -503,6 +529,9 @@ func (p *Provider) AddRoutes(router *mux.Router) {
|
|||
// Get provided certificate which check a domains list (Main and SANs)
|
||||
// from static and dynamic provided certificates
|
||||
func (p *Provider) getUncheckedDomains(domainsToCheck []string, checkConfigurationDomains bool) []string {
|
||||
p.resolvingDomainsMutex.RLock()
|
||||
defer p.resolvingDomainsMutex.RUnlock()
|
||||
|
||||
log.Debugf("Looking for provided certificate(s) to validate %q...", domainsToCheck)
|
||||
var allCerts []string
|
||||
|
||||
|
@ -523,6 +552,11 @@ func (p *Provider) getUncheckedDomains(domainsToCheck []string, checkConfigurati
|
|||
allCerts = append(allCerts, strings.Join(certificate.Domain.ToStrArray(), ","))
|
||||
}
|
||||
|
||||
// Get currently resolved domains
|
||||
for domain := range p.resolvingDomains {
|
||||
allCerts = append(allCerts, domain)
|
||||
}
|
||||
|
||||
// Get Configuration Domains
|
||||
if checkConfigurationDomains {
|
||||
for i := 0; i < len(p.Domains); i++ {
|
||||
|
@ -540,8 +574,9 @@ func searchUncheckedDomains(domainsToCheck []string, existentDomains []string) [
|
|||
uncheckedDomains = append(uncheckedDomains, domainToCheck)
|
||||
}
|
||||
}
|
||||
|
||||
if len(uncheckedDomains) == 0 {
|
||||
log.Debugf("No ACME certificate to generate for domains %q.", domainsToCheck)
|
||||
log.Debugf("No ACME certificate generation required for domains %q.", domainsToCheck)
|
||||
} else {
|
||||
log.Debugf("Domains %q need ACME certificates generation for domains %q.", domainsToCheck, strings.Join(uncheckedDomains, ","))
|
||||
}
|
||||
|
|
|
@ -26,6 +26,7 @@ func TestGetUncheckedCertificates(t *testing.T) {
|
|||
desc string
|
||||
dynamicCerts *safe.Safe
|
||||
staticCerts map[string]*tls.Certificate
|
||||
resolvingDomains map[string]struct{}
|
||||
acmeCertificates []*Certificate
|
||||
domains []string
|
||||
expectedDomains []string
|
||||
|
@ -138,17 +139,55 @@ func TestGetUncheckedCertificates(t *testing.T) {
|
|||
},
|
||||
expectedDomains: []string{"traefik.wtf"},
|
||||
},
|
||||
{
|
||||
desc: "all domains already managed by ACME",
|
||||
domains: []string{"traefik.wtf", "foo.traefik.wtf"},
|
||||
resolvingDomains: map[string]struct{}{
|
||||
"traefik.wtf": {},
|
||||
"foo.traefik.wtf": {},
|
||||
},
|
||||
expectedDomains: []string{},
|
||||
},
|
||||
{
|
||||
desc: "one domain already managed by ACME",
|
||||
domains: []string{"traefik.wtf", "foo.traefik.wtf"},
|
||||
resolvingDomains: map[string]struct{}{
|
||||
"traefik.wtf": {},
|
||||
},
|
||||
expectedDomains: []string{"foo.traefik.wtf"},
|
||||
},
|
||||
{
|
||||
desc: "wildcard domain already managed by ACME checks the domains",
|
||||
domains: []string{"bar.traefik.wtf", "foo.traefik.wtf"},
|
||||
resolvingDomains: map[string]struct{}{
|
||||
"*.traefik.wtf": {},
|
||||
},
|
||||
expectedDomains: []string{},
|
||||
},
|
||||
{
|
||||
desc: "wildcard domain already managed by ACME checks domains and another domain checks one other domain, one domain still unchecked",
|
||||
domains: []string{"traefik.wtf", "bar.traefik.wtf", "foo.traefik.wtf", "acme.wtf"},
|
||||
resolvingDomains: map[string]struct{}{
|
||||
"*.traefik.wtf": {},
|
||||
"traefik.wtf": {},
|
||||
},
|
||||
expectedDomains: []string{"acme.wtf"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
if test.resolvingDomains == nil {
|
||||
test.resolvingDomains = make(map[string]struct{})
|
||||
}
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
acmeProvider := Provider{
|
||||
dynamicCerts: test.dynamicCerts,
|
||||
staticCerts: test.staticCerts,
|
||||
certificates: test.acmeCertificates,
|
||||
dynamicCerts: test.dynamicCerts,
|
||||
staticCerts: test.staticCerts,
|
||||
certificates: test.acmeCertificates,
|
||||
resolvingDomains: test.resolvingDomains,
|
||||
}
|
||||
|
||||
domains := acmeProvider.getUncheckedDomains(test.domains, false)
|
||||
|
|
Loading…
Add table
Reference in a new issue