Avoid duplicated ACME resolution
This commit is contained in:
parent
60b4095c75
commit
d81c4e6d1a
5 changed files with 126 additions and 7 deletions
39
acme/acme.go
39
acme/acme.go
|
@ -11,6 +11,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/BurntSushi/ty/fun"
|
"github.com/BurntSushi/ty/fun"
|
||||||
|
@ -61,6 +62,8 @@ type ACME struct {
|
||||||
jobs *channels.InfiniteChannel
|
jobs *channels.InfiniteChannel
|
||||||
TLSConfig *tls.Config `description:"TLS config in case wildcard certs are used"`
|
TLSConfig *tls.Config `description:"TLS config in case wildcard certs are used"`
|
||||||
dynamicCerts *safe.Safe
|
dynamicCerts *safe.Safe
|
||||||
|
resolvingDomains map[string]struct{}
|
||||||
|
resolvingDomainsMutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *ACME) init() error {
|
func (a *ACME) init() error {
|
||||||
|
@ -81,6 +84,10 @@ func (a *ACME) init() error {
|
||||||
a.defaultCertificate = cert
|
a.defaultCertificate = cert
|
||||||
|
|
||||||
a.jobs = channels.NewInfiniteChannel()
|
a.jobs = channels.NewInfiniteChannel()
|
||||||
|
|
||||||
|
// Init the currently resolved domain map
|
||||||
|
a.resolvingDomains = make(map[string]struct{})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -502,6 +509,10 @@ func (a *ACME) LoadCertificateForDomains(domains []string) {
|
||||||
if len(uncheckedDomains) == 0 {
|
if len(uncheckedDomains) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
a.addResolvingDomains(uncheckedDomains)
|
||||||
|
defer a.removeResolvingDomains(uncheckedDomains)
|
||||||
|
|
||||||
certificate, err := a.getDomainsCertificates(uncheckedDomains)
|
certificate, err := a.getDomainsCertificates(uncheckedDomains)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Error getting ACME certificates %+v : %v", uncheckedDomains, err)
|
log.Errorf("Error getting ACME certificates %+v : %v", uncheckedDomains, err)
|
||||||
|
@ -533,6 +544,24 @@ func (a *ACME) LoadCertificateForDomains(domains []string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *ACME) addResolvingDomains(resolvingDomains []string) {
|
||||||
|
a.resolvingDomainsMutex.Lock()
|
||||||
|
defer a.resolvingDomainsMutex.Unlock()
|
||||||
|
|
||||||
|
for _, domain := range resolvingDomains {
|
||||||
|
a.resolvingDomains[domain] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *ACME) removeResolvingDomains(resolvingDomains []string) {
|
||||||
|
a.resolvingDomainsMutex.Lock()
|
||||||
|
defer a.resolvingDomainsMutex.Unlock()
|
||||||
|
|
||||||
|
for _, domain := range resolvingDomains {
|
||||||
|
delete(a.resolvingDomains, domain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Get provided certificate which check a domains list (Main and SANs)
|
// Get provided certificate which check a domains list (Main and SANs)
|
||||||
// from static and dynamic provided certificates
|
// from static and dynamic provided certificates
|
||||||
func (a *ACME) getProvidedCertificate(domains string) *tls.Certificate {
|
func (a *ACME) getProvidedCertificate(domains string) *tls.Certificate {
|
||||||
|
@ -568,6 +597,9 @@ func searchProvidedCertificateForDomains(domain string, certs map[string]*tls.Ce
|
||||||
// Get provided certificate which check a domains list (Main and SANs)
|
// Get provided certificate which check a domains list (Main and SANs)
|
||||||
// from static and dynamic provided certificates
|
// from static and dynamic provided certificates
|
||||||
func (a *ACME) getUncheckedDomains(domains []string, account *Account) []string {
|
func (a *ACME) getUncheckedDomains(domains []string, account *Account) []string {
|
||||||
|
a.resolvingDomainsMutex.RLock()
|
||||||
|
defer a.resolvingDomainsMutex.RUnlock()
|
||||||
|
|
||||||
log.Debugf("Looking for provided certificate to validate %s...", domains)
|
log.Debugf("Looking for provided certificate to validate %s...", domains)
|
||||||
allCerts := make(map[string]*tls.Certificate)
|
allCerts := make(map[string]*tls.Certificate)
|
||||||
|
|
||||||
|
@ -590,6 +622,13 @@ func (a *ACME) getUncheckedDomains(domains []string, account *Account) []string
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get currently resolved domains
|
||||||
|
for domain := range a.resolvingDomains {
|
||||||
|
if _, ok := allCerts[domain]; !ok {
|
||||||
|
allCerts[domain] = &tls.Certificate{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Get Configuration Domains
|
// Get Configuration Domains
|
||||||
for i := 0; i < len(a.Domains); i++ {
|
for i := 0; i < len(a.Domains); i++ {
|
||||||
allCerts[a.Domains[i].Main] = &tls.Certificate{}
|
allCerts[a.Domains[i].Main] = &tls.Certificate{}
|
||||||
|
|
|
@ -331,9 +331,12 @@ func TestAcme_getUncheckedCertificates(t *testing.T) {
|
||||||
mm["*.containo.us"] = &tls.Certificate{}
|
mm["*.containo.us"] = &tls.Certificate{}
|
||||||
mm["traefik.acme.io"] = &tls.Certificate{}
|
mm["traefik.acme.io"] = &tls.Certificate{}
|
||||||
|
|
||||||
a := ACME{TLSConfig: &tls.Config{NameToCertificate: mm}}
|
dm := make(map[string]struct{})
|
||||||
|
dm["*.traefik.wtf"] = struct{}{}
|
||||||
|
|
||||||
domains := []string{"traefik.containo.us", "trae.containo.us"}
|
a := ACME{TLSConfig: &tls.Config{NameToCertificate: mm}, resolvingDomains: dm}
|
||||||
|
|
||||||
|
domains := []string{"traefik.containo.us", "trae.containo.us", "foo.traefik.wtf"}
|
||||||
uncheckedDomains := a.getUncheckedDomains(domains, nil)
|
uncheckedDomains := a.getUncheckedDomains(domains, nil)
|
||||||
assert.Empty(t, uncheckedDomains)
|
assert.Empty(t, uncheckedDomains)
|
||||||
domains = []string{"traefik.acme.io", "trae.acme.io"}
|
domains = []string{"traefik.acme.io", "trae.acme.io"}
|
||||||
|
@ -351,6 +354,9 @@ func TestAcme_getUncheckedCertificates(t *testing.T) {
|
||||||
account := Account{DomainsCertificate: domainsCertificates}
|
account := Account{DomainsCertificate: domainsCertificates}
|
||||||
uncheckedDomains = a.getUncheckedDomains(domains, &account)
|
uncheckedDomains = a.getUncheckedDomains(domains, &account)
|
||||||
assert.Empty(t, uncheckedDomains)
|
assert.Empty(t, uncheckedDomains)
|
||||||
|
domains = []string{"traefik.containo.us", "trae.containo.us", "traefik.wtf"}
|
||||||
|
uncheckedDomains = a.getUncheckedDomains(domains, nil)
|
||||||
|
assert.Len(t, uncheckedDomains, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAcme_getProvidedCertificate(t *testing.T) {
|
func TestAcme_getProvidedCertificate(t *testing.T) {
|
||||||
|
|
|
@ -50,7 +50,7 @@ start_boulder() {
|
||||||
# Script usage
|
# Script usage
|
||||||
show_usage() {
|
show_usage() {
|
||||||
echo
|
echo
|
||||||
echo "USAGE : manage_acme_docker_environment.sh [--start|--stop|--restart]"
|
echo "USAGE : manage_acme_docker_environment.sh [--dev|--start|--stop|--restart]"
|
||||||
echo
|
echo
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -63,6 +63,8 @@ type Provider struct {
|
||||||
clientMutex sync.Mutex
|
clientMutex sync.Mutex
|
||||||
configFromListenerChan chan types.Configuration
|
configFromListenerChan chan types.Configuration
|
||||||
pool *safe.Pool
|
pool *safe.Pool
|
||||||
|
resolvingDomains map[string]struct{}
|
||||||
|
resolvingDomainsMutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// Certificate is a struct which contains all data needed from an ACME certificate
|
// Certificate is a struct which contains all data needed from an ACME certificate
|
||||||
|
@ -127,6 +129,9 @@ func (p *Provider) init() error {
|
||||||
return fmt.Errorf("unable to get ACME certificates : %v", err)
|
return fmt.Errorf("unable to get ACME certificates : %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Init the currently resolved domain map
|
||||||
|
p.resolvingDomains = make(map[string]struct{})
|
||||||
|
|
||||||
p.watchCertificate()
|
p.watchCertificate()
|
||||||
p.watchNewDomains()
|
p.watchNewDomains()
|
||||||
|
|
||||||
|
@ -226,6 +231,9 @@ func (p *Provider) resolveCertificate(domain types.Domain, domainFromConfigurati
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
p.addResolvingDomains(uncheckedDomains)
|
||||||
|
defer p.removeResolvingDomains(uncheckedDomains)
|
||||||
|
|
||||||
log.Debugf("Loading ACME certificates %+v...", uncheckedDomains)
|
log.Debugf("Loading ACME certificates %+v...", uncheckedDomains)
|
||||||
client, err := p.getClient()
|
client, err := p.getClient()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -254,6 +262,24 @@ func (p *Provider) resolveCertificate(domain types.Domain, domainFromConfigurati
|
||||||
return certificate, nil
|
return certificate, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Provider) removeResolvingDomains(resolvingDomains []string) {
|
||||||
|
p.resolvingDomainsMutex.Lock()
|
||||||
|
defer p.resolvingDomainsMutex.Unlock()
|
||||||
|
|
||||||
|
for _, domain := range resolvingDomains {
|
||||||
|
delete(p.resolvingDomains, domain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) addResolvingDomains(resolvingDomains []string) {
|
||||||
|
p.resolvingDomainsMutex.Lock()
|
||||||
|
defer p.resolvingDomainsMutex.Unlock()
|
||||||
|
|
||||||
|
for _, domain := range resolvingDomains {
|
||||||
|
p.resolvingDomains[domain] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (p *Provider) getClient() (*acme.Client, error) {
|
func (p *Provider) getClient() (*acme.Client, error) {
|
||||||
p.clientMutex.Lock()
|
p.clientMutex.Lock()
|
||||||
defer p.clientMutex.Unlock()
|
defer p.clientMutex.Unlock()
|
||||||
|
@ -503,6 +529,9 @@ func (p *Provider) AddRoutes(router *mux.Router) {
|
||||||
// Get provided certificate which check a domains list (Main and SANs)
|
// Get provided certificate which check a domains list (Main and SANs)
|
||||||
// from static and dynamic provided certificates
|
// from static and dynamic provided certificates
|
||||||
func (p *Provider) getUncheckedDomains(domainsToCheck []string, checkConfigurationDomains bool) []string {
|
func (p *Provider) getUncheckedDomains(domainsToCheck []string, checkConfigurationDomains bool) []string {
|
||||||
|
p.resolvingDomainsMutex.RLock()
|
||||||
|
defer p.resolvingDomainsMutex.RUnlock()
|
||||||
|
|
||||||
log.Debugf("Looking for provided certificate(s) to validate %q...", domainsToCheck)
|
log.Debugf("Looking for provided certificate(s) to validate %q...", domainsToCheck)
|
||||||
var allCerts []string
|
var allCerts []string
|
||||||
|
|
||||||
|
@ -523,6 +552,11 @@ func (p *Provider) getUncheckedDomains(domainsToCheck []string, checkConfigurati
|
||||||
allCerts = append(allCerts, strings.Join(certificate.Domain.ToStrArray(), ","))
|
allCerts = append(allCerts, strings.Join(certificate.Domain.ToStrArray(), ","))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get currently resolved domains
|
||||||
|
for domain := range p.resolvingDomains {
|
||||||
|
allCerts = append(allCerts, domain)
|
||||||
|
}
|
||||||
|
|
||||||
// Get Configuration Domains
|
// Get Configuration Domains
|
||||||
if checkConfigurationDomains {
|
if checkConfigurationDomains {
|
||||||
for i := 0; i < len(p.Domains); i++ {
|
for i := 0; i < len(p.Domains); i++ {
|
||||||
|
@ -540,8 +574,9 @@ func searchUncheckedDomains(domainsToCheck []string, existentDomains []string) [
|
||||||
uncheckedDomains = append(uncheckedDomains, domainToCheck)
|
uncheckedDomains = append(uncheckedDomains, domainToCheck)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(uncheckedDomains) == 0 {
|
if len(uncheckedDomains) == 0 {
|
||||||
log.Debugf("No ACME certificate to generate for domains %q.", domainsToCheck)
|
log.Debugf("No ACME certificate generation required for domains %q.", domainsToCheck)
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("Domains %q need ACME certificates generation for domains %q.", domainsToCheck, strings.Join(uncheckedDomains, ","))
|
log.Debugf("Domains %q need ACME certificates generation for domains %q.", domainsToCheck, strings.Join(uncheckedDomains, ","))
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,6 +26,7 @@ func TestGetUncheckedCertificates(t *testing.T) {
|
||||||
desc string
|
desc string
|
||||||
dynamicCerts *safe.Safe
|
dynamicCerts *safe.Safe
|
||||||
staticCerts map[string]*tls.Certificate
|
staticCerts map[string]*tls.Certificate
|
||||||
|
resolvingDomains map[string]struct{}
|
||||||
acmeCertificates []*Certificate
|
acmeCertificates []*Certificate
|
||||||
domains []string
|
domains []string
|
||||||
expectedDomains []string
|
expectedDomains []string
|
||||||
|
@ -138,17 +139,55 @@ func TestGetUncheckedCertificates(t *testing.T) {
|
||||||
},
|
},
|
||||||
expectedDomains: []string{"traefik.wtf"},
|
expectedDomains: []string{"traefik.wtf"},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
desc: "all domains already managed by ACME",
|
||||||
|
domains: []string{"traefik.wtf", "foo.traefik.wtf"},
|
||||||
|
resolvingDomains: map[string]struct{}{
|
||||||
|
"traefik.wtf": {},
|
||||||
|
"foo.traefik.wtf": {},
|
||||||
|
},
|
||||||
|
expectedDomains: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "one domain already managed by ACME",
|
||||||
|
domains: []string{"traefik.wtf", "foo.traefik.wtf"},
|
||||||
|
resolvingDomains: map[string]struct{}{
|
||||||
|
"traefik.wtf": {},
|
||||||
|
},
|
||||||
|
expectedDomains: []string{"foo.traefik.wtf"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "wildcard domain already managed by ACME checks the domains",
|
||||||
|
domains: []string{"bar.traefik.wtf", "foo.traefik.wtf"},
|
||||||
|
resolvingDomains: map[string]struct{}{
|
||||||
|
"*.traefik.wtf": {},
|
||||||
|
},
|
||||||
|
expectedDomains: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "wildcard domain already managed by ACME checks domains and another domain checks one other domain, one domain still unchecked",
|
||||||
|
domains: []string{"traefik.wtf", "bar.traefik.wtf", "foo.traefik.wtf", "acme.wtf"},
|
||||||
|
resolvingDomains: map[string]struct{}{
|
||||||
|
"*.traefik.wtf": {},
|
||||||
|
"traefik.wtf": {},
|
||||||
|
},
|
||||||
|
expectedDomains: []string{"acme.wtf"},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range testCases {
|
for _, test := range testCases {
|
||||||
test := test
|
test := test
|
||||||
|
if test.resolvingDomains == nil {
|
||||||
|
test.resolvingDomains = make(map[string]struct{})
|
||||||
|
}
|
||||||
t.Run(test.desc, func(t *testing.T) {
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
acmeProvider := Provider{
|
acmeProvider := Provider{
|
||||||
dynamicCerts: test.dynamicCerts,
|
dynamicCerts: test.dynamicCerts,
|
||||||
staticCerts: test.staticCerts,
|
staticCerts: test.staticCerts,
|
||||||
certificates: test.acmeCertificates,
|
certificates: test.acmeCertificates,
|
||||||
|
resolvingDomains: test.resolvingDomains,
|
||||||
}
|
}
|
||||||
|
|
||||||
domains := acmeProvider.getUncheckedDomains(test.domains, false)
|
domains := acmeProvider.getUncheckedDomains(test.domains, false)
|
||||||
|
|
Loading…
Reference in a new issue