From f1a05ab73cc29b16016e5ac83f3c39722394f4f0 Mon Sep 17 00:00:00 2001 From: Tait Clarridge Date: Tue, 27 Mar 2018 10:18:03 -0400 Subject: [PATCH] Add wildcard match to acme domains --- acme/account.go | 3 ++ acme/acme.go | 22 +++------- acme/acme_test.go | 92 ++++++++++++++++++++++++++++++++++++++- provider/acme/provider.go | 14 ++---- server/server.go | 9 ++-- types/domain_test.go | 92 +++++++++++++++++++++++++++++++++++++++ types/domains.go | 21 +++++++++ 7 files changed, 219 insertions(+), 34 deletions(-) diff --git a/acme/account.go b/acme/account.go index 3215257c8..e32e68948 100644 --- a/acme/account.go +++ b/acme/account.go @@ -219,6 +219,9 @@ func (dc *DomainsCertificates) getCertificateForDomain(domainToFind string) (*Do for _, domainsCertificate := range dc.Certs { for _, domain := range domainsCertificate.Domains.ToStrArray() { + if strings.HasPrefix(domain, "*.") && types.MatchDomain(domainToFind, domain) { + return domainsCertificate, true + } if domain == domainToFind { return domainsCertificate, true } diff --git a/acme/acme.go b/acme/acme.go index ec4da2ba6..8e8634654 100644 --- a/acme/acme.go +++ b/acme/acme.go @@ -11,7 +11,6 @@ import ( "net/http" "os" "reflect" - "regexp" "strings" "time" @@ -27,7 +26,7 @@ import ( "github.com/containous/traefik/tls/generate" "github.com/containous/traefik/types" "github.com/eapache/channels" - acme "github.com/xenolf/lego/acmev2" + "github.com/xenolf/lego/acmev2" "github.com/xenolf/lego/providers/dns" ) @@ -555,15 +554,14 @@ func (a *ACME) getProvidedCertificate(domains string) *tls.Certificate { func searchProvidedCertificateForDomains(domain string, certs map[string]*tls.Certificate) *tls.Certificate { // Use regex to test for provided certs that might have been added into TLSConfig for certDomains := range certs { - domainCheck := false + domainChecked := false for _, certDomain := range strings.Split(certDomains, ",") { - selector := "^" + strings.Replace(certDomain, "*.", "[^\\.]*\\.", -1) + "$" - domainCheck, _ = regexp.MatchString(selector, domain) - if domainCheck { + domainChecked = types.MatchDomain(domain, certDomain) + if domainChecked { break } } - if domainCheck { + if domainChecked { log.Debugf("Domain %q checked by provided certificate %q", domain, certDomains) return certs[certDomains] } @@ -684,15 +682,7 @@ func (a *ACME) getValidDomains(domains []string, wildcardAllowed bool) ([]string func isDomainAlreadyChecked(domainToCheck string, existentDomains map[string]*tls.Certificate) bool { for certDomains := range existentDomains { for _, certDomain := range strings.Split(certDomains, ",") { - // Use regex to test for provided existentDomains that might have been added into TLSConfig - selector := "^" + strings.Replace(certDomain, "*.", "[^\\.]*\\.", -1) + "$" - domainCheck, err := regexp.MatchString(selector, domainToCheck) - if err != nil { - log.Errorf("Unable to compare %q and %q : %s", domainToCheck, certDomain, err) - continue - } - - if domainCheck { + if types.MatchDomain(domainToCheck, certDomain) { return true } } diff --git a/acme/acme_test.go b/acme/acme_test.go index 39ef373cb..076308c82 100644 --- a/acme/acme_test.go +++ b/acme/acme_test.go @@ -14,7 +14,7 @@ import ( "github.com/containous/traefik/tls/generate" "github.com/containous/traefik/types" "github.com/stretchr/testify/assert" - acme "github.com/xenolf/lego/acmev2" + "github.com/xenolf/lego/acmev2" ) func TestDomainsSet(t *testing.T) { @@ -444,3 +444,93 @@ func TestAcme_getValidDomain(t *testing.T) { }) } } + +func TestAcme_getCertificateForDomain(t *testing.T) { + testCases := []struct { + desc string + domain string + dc *DomainsCertificates + expected *DomainsCertificate + expectedFound bool + }{ + { + desc: "non-wildcard exact match", + domain: "foo.traefik.wtf", + dc: &DomainsCertificates{ + Certs: []*DomainsCertificate{ + { + Domains: types.Domain{ + Main: "foo.traefik.wtf", + }, + }, + }, + }, + expected: &DomainsCertificate{ + Domains: types.Domain{ + Main: "foo.traefik.wtf", + }, + }, + expectedFound: true, + }, + { + desc: "non-wildcard no match", + domain: "bar.traefik.wtf", + dc: &DomainsCertificates{ + Certs: []*DomainsCertificate{ + { + Domains: types.Domain{ + Main: "foo.traefik.wtf", + }, + }, + }, + }, + expected: nil, + expectedFound: false, + }, + { + desc: "wildcard match", + domain: "foo.traefik.wtf", + dc: &DomainsCertificates{ + Certs: []*DomainsCertificate{ + { + Domains: types.Domain{ + Main: "*.traefik.wtf", + }, + }, + }, + }, + expected: &DomainsCertificate{ + Domains: types.Domain{ + Main: "*.traefik.wtf", + }, + }, + expectedFound: true, + }, + { + desc: "wildcard no match", + domain: "foo.traefik.wtf", + dc: &DomainsCertificates{ + Certs: []*DomainsCertificate{ + { + Domains: types.Domain{ + Main: "*.bar.traefik.wtf", + }, + }, + }, + }, + expected: nil, + expectedFound: false, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + got, found := test.dc.getCertificateForDomain(test.domain) + assert.Equal(t, test.expectedFound, found) + assert.Equal(t, test.expected, got) + }) + } +} diff --git a/provider/acme/provider.go b/provider/acme/provider.go index 3be1f46b9..c45db2d77 100644 --- a/provider/acme/provider.go +++ b/provider/acme/provider.go @@ -10,7 +10,6 @@ import ( "net/http" "os" "reflect" - "regexp" "strings" "sync" "time" @@ -24,7 +23,7 @@ import ( traefikTLS "github.com/containous/traefik/tls" "github.com/containous/traefik/types" "github.com/pkg/errors" - acme "github.com/xenolf/lego/acmev2" + "github.com/xenolf/lego/acmev2" "github.com/xenolf/lego/providers/dns" ) @@ -522,7 +521,7 @@ func (p *Provider) getUncheckedDomains(domainsToCheck []string, checkConfigurati } func searchUncheckedDomains(domainsToCheck []string, existentDomains []string) []string { - uncheckedDomains := []string{} + var uncheckedDomains []string for _, domainToCheck := range domainsToCheck { if !isDomainAlreadyChecked(domainToCheck, existentDomains) { uncheckedDomains = append(uncheckedDomains, domainToCheck) @@ -583,14 +582,7 @@ func (p *Provider) getValidDomains(domain types.Domain, wildcardAllowed bool) ([ func isDomainAlreadyChecked(domainToCheck string, existentDomains []string) bool { for _, certDomains := range existentDomains { for _, certDomain := range strings.Split(certDomains, ",") { - // Use regex to test for provided existentDomains that might have been added into TLSConfig - selector := "^" + strings.Replace(certDomain, "*.", "[^\\.]*\\.", -1) + "$" - domainCheck, err := regexp.MatchString(selector, domainToCheck) - if err != nil { - log.Errorf("Unable to compare %q and %q in ACME provider : %s", domainToCheck, certDomain, err) - continue - } - if domainCheck { + if types.MatchDomain(domainToCheck, certDomain) { return true } } diff --git a/server/server.go b/server/server.go index 34b128bd9..b32b0cb2a 100644 --- a/server/server.go +++ b/server/server.go @@ -15,7 +15,6 @@ import ( "os" "os/signal" "reflect" - "regexp" "sort" "strings" "sync" @@ -517,15 +516,13 @@ func (s *Server) loadHTTPSConfiguration(configurations types.Configurations, def return newEPCertificates, nil } -// getCertificate allows to customize tlsConfig.Getcertificate behaviour to get the certificates inserted dynamically +// getCertificate allows to customize tlsConfig.GetCertificate behaviour to get the certificates inserted dynamically func (s *serverEntryPoint) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { domainToCheck := types.CanonicalDomain(clientHello.ServerName) if s.certs.Get() != nil { for domains, cert := range s.certs.Get().(map[string]*tls.Certificate) { - for _, domain := range strings.Split(domains, ",") { - selector := "^" + strings.Replace(domain, "*.", "[^\\.]*\\.?", -1) + "$" - domainCheck, _ := regexp.MatchString(selector, domainToCheck) - if domainCheck { + for _, certDomain := range strings.Split(domains, ",") { + if types.MatchDomain(domainToCheck, certDomain) { return cert, nil } } diff --git a/types/domain_test.go b/types/domain_test.go index 911064a9c..dc97c7971 100644 --- a/types/domain_test.go +++ b/types/domain_test.go @@ -88,3 +88,95 @@ func TestDomain_Set(t *testing.T) { }) } } + +func TestMatchDomain(t *testing.T) { + testCases := []struct { + desc string + certDomain string + domain string + expected bool + }{ + { + desc: "exact match", + certDomain: "traefik.wtf", + domain: "traefik.wtf", + expected: true, + }, + { + desc: "wildcard and root domain", + certDomain: "*.traefik.wtf", + domain: "traefik.wtf", + expected: false, + }, + { + desc: "wildcard and sub domain", + certDomain: "*.traefik.wtf", + domain: "sub.traefik.wtf", + expected: true, + }, + { + desc: "wildcard and sub sub domain", + certDomain: "*.traefik.wtf", + domain: "sub.sub.traefik.wtf", + expected: false, + }, + { + desc: "double wildcard and sub sub domain", + certDomain: "*.*.traefik.wtf", + domain: "sub.sub.traefik.wtf", + expected: true, + }, + { + desc: "sub sub domain and invalid wildcard", + certDomain: "sub.*.traefik.wtf", + domain: "sub.sub.traefik.wtf", + expected: false, + }, + { + desc: "sub sub domain and valid wildcard", + certDomain: "*.sub.traefik.wtf", + domain: "sub.sub.traefik.wtf", + expected: true, + }, + { + desc: "dot replaced by a cahr", + certDomain: "sub.sub.traefik.wtf", + domain: "sub.sub.traefikiwtf", + expected: false, + }, + { + desc: "*", + certDomain: "*", + domain: "sub.sub.traefik.wtf", + expected: false, + }, + { + desc: "?", + certDomain: "?", + domain: "sub.sub.traefik.wtf", + expected: false, + }, + { + desc: "...................", + certDomain: "...................", + domain: "sub.sub.traefik.wtf", + expected: false, + }, + { + desc: "wildcard and *", + certDomain: "*.traefik.wtf", + domain: "*.*.traefik.wtf", + expected: false, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + domains := MatchDomain(test.domain, test.certDomain) + assert.Equal(t, test.expected, domains) + }) + } +} diff --git a/types/domains.go b/types/domains.go index 47bae5468..2cace3f64 100644 --- a/types/domains.go +++ b/types/domains.go @@ -65,3 +65,24 @@ func (ds *Domains) String() string { return fmt.Sprintf("%+v", *ds) } func (ds *Domains) SetValue(val interface{}) { *ds = val.([]Domain) } + +// MatchDomain return true if a domain match the cert domain +func MatchDomain(domain string, certDomain string) bool { + if domain == certDomain { + return true + } + + for len(certDomain) > 0 && certDomain[len(certDomain)-1] == '.' { + certDomain = certDomain[:len(certDomain)-1] + } + + labels := strings.Split(domain, ".") + for i := range labels { + labels[i] = "*" + candidate := strings.Join(labels, ".") + if certDomain == candidate { + return true + } + } + return false +}