Add wildcard match to acme domains

This commit is contained in:
Tait Clarridge 2018-03-27 10:18:03 -04:00 committed by Traefiker Bot
parent 4c85a41bfb
commit f1a05ab73c
7 changed files with 219 additions and 34 deletions

View file

@ -219,6 +219,9 @@ func (dc *DomainsCertificates) getCertificateForDomain(domainToFind string) (*Do
for _, domainsCertificate := range dc.Certs { for _, domainsCertificate := range dc.Certs {
for _, domain := range domainsCertificate.Domains.ToStrArray() { for _, domain := range domainsCertificate.Domains.ToStrArray() {
if strings.HasPrefix(domain, "*.") && types.MatchDomain(domainToFind, domain) {
return domainsCertificate, true
}
if domain == domainToFind { if domain == domainToFind {
return domainsCertificate, true return domainsCertificate, true
} }

View file

@ -11,7 +11,6 @@ import (
"net/http" "net/http"
"os" "os"
"reflect" "reflect"
"regexp"
"strings" "strings"
"time" "time"
@ -27,7 +26,7 @@ import (
"github.com/containous/traefik/tls/generate" "github.com/containous/traefik/tls/generate"
"github.com/containous/traefik/types" "github.com/containous/traefik/types"
"github.com/eapache/channels" "github.com/eapache/channels"
acme "github.com/xenolf/lego/acmev2" "github.com/xenolf/lego/acmev2"
"github.com/xenolf/lego/providers/dns" "github.com/xenolf/lego/providers/dns"
) )
@ -555,15 +554,14 @@ func (a *ACME) getProvidedCertificate(domains string) *tls.Certificate {
func searchProvidedCertificateForDomains(domain string, certs map[string]*tls.Certificate) *tls.Certificate { func searchProvidedCertificateForDomains(domain string, certs map[string]*tls.Certificate) *tls.Certificate {
// Use regex to test for provided certs that might have been added into TLSConfig // Use regex to test for provided certs that might have been added into TLSConfig
for certDomains := range certs { for certDomains := range certs {
domainCheck := false domainChecked := false
for _, certDomain := range strings.Split(certDomains, ",") { for _, certDomain := range strings.Split(certDomains, ",") {
selector := "^" + strings.Replace(certDomain, "*.", "[^\\.]*\\.", -1) + "$" domainChecked = types.MatchDomain(domain, certDomain)
domainCheck, _ = regexp.MatchString(selector, domain) if domainChecked {
if domainCheck {
break break
} }
} }
if domainCheck { if domainChecked {
log.Debugf("Domain %q checked by provided certificate %q", domain, certDomains) log.Debugf("Domain %q checked by provided certificate %q", domain, certDomains)
return certs[certDomains] return certs[certDomains]
} }
@ -684,15 +682,7 @@ func (a *ACME) getValidDomains(domains []string, wildcardAllowed bool) ([]string
func isDomainAlreadyChecked(domainToCheck string, existentDomains map[string]*tls.Certificate) bool { func isDomainAlreadyChecked(domainToCheck string, existentDomains map[string]*tls.Certificate) bool {
for certDomains := range existentDomains { for certDomains := range existentDomains {
for _, certDomain := range strings.Split(certDomains, ",") { for _, certDomain := range strings.Split(certDomains, ",") {
// Use regex to test for provided existentDomains that might have been added into TLSConfig if types.MatchDomain(domainToCheck, certDomain) {
selector := "^" + strings.Replace(certDomain, "*.", "[^\\.]*\\.", -1) + "$"
domainCheck, err := regexp.MatchString(selector, domainToCheck)
if err != nil {
log.Errorf("Unable to compare %q and %q : %s", domainToCheck, certDomain, err)
continue
}
if domainCheck {
return true return true
} }
} }

View file

@ -14,7 +14,7 @@ import (
"github.com/containous/traefik/tls/generate" "github.com/containous/traefik/tls/generate"
"github.com/containous/traefik/types" "github.com/containous/traefik/types"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
acme "github.com/xenolf/lego/acmev2" "github.com/xenolf/lego/acmev2"
) )
func TestDomainsSet(t *testing.T) { func TestDomainsSet(t *testing.T) {
@ -444,3 +444,93 @@ func TestAcme_getValidDomain(t *testing.T) {
}) })
} }
} }
func TestAcme_getCertificateForDomain(t *testing.T) {
testCases := []struct {
desc string
domain string
dc *DomainsCertificates
expected *DomainsCertificate
expectedFound bool
}{
{
desc: "non-wildcard exact match",
domain: "foo.traefik.wtf",
dc: &DomainsCertificates{
Certs: []*DomainsCertificate{
{
Domains: types.Domain{
Main: "foo.traefik.wtf",
},
},
},
},
expected: &DomainsCertificate{
Domains: types.Domain{
Main: "foo.traefik.wtf",
},
},
expectedFound: true,
},
{
desc: "non-wildcard no match",
domain: "bar.traefik.wtf",
dc: &DomainsCertificates{
Certs: []*DomainsCertificate{
{
Domains: types.Domain{
Main: "foo.traefik.wtf",
},
},
},
},
expected: nil,
expectedFound: false,
},
{
desc: "wildcard match",
domain: "foo.traefik.wtf",
dc: &DomainsCertificates{
Certs: []*DomainsCertificate{
{
Domains: types.Domain{
Main: "*.traefik.wtf",
},
},
},
},
expected: &DomainsCertificate{
Domains: types.Domain{
Main: "*.traefik.wtf",
},
},
expectedFound: true,
},
{
desc: "wildcard no match",
domain: "foo.traefik.wtf",
dc: &DomainsCertificates{
Certs: []*DomainsCertificate{
{
Domains: types.Domain{
Main: "*.bar.traefik.wtf",
},
},
},
},
expected: nil,
expectedFound: false,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
got, found := test.dc.getCertificateForDomain(test.domain)
assert.Equal(t, test.expectedFound, found)
assert.Equal(t, test.expected, got)
})
}
}

View file

@ -10,7 +10,6 @@ import (
"net/http" "net/http"
"os" "os"
"reflect" "reflect"
"regexp"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -24,7 +23,7 @@ import (
traefikTLS "github.com/containous/traefik/tls" traefikTLS "github.com/containous/traefik/tls"
"github.com/containous/traefik/types" "github.com/containous/traefik/types"
"github.com/pkg/errors" "github.com/pkg/errors"
acme "github.com/xenolf/lego/acmev2" "github.com/xenolf/lego/acmev2"
"github.com/xenolf/lego/providers/dns" "github.com/xenolf/lego/providers/dns"
) )
@ -522,7 +521,7 @@ func (p *Provider) getUncheckedDomains(domainsToCheck []string, checkConfigurati
} }
func searchUncheckedDomains(domainsToCheck []string, existentDomains []string) []string { func searchUncheckedDomains(domainsToCheck []string, existentDomains []string) []string {
uncheckedDomains := []string{} var uncheckedDomains []string
for _, domainToCheck := range domainsToCheck { for _, domainToCheck := range domainsToCheck {
if !isDomainAlreadyChecked(domainToCheck, existentDomains) { if !isDomainAlreadyChecked(domainToCheck, existentDomains) {
uncheckedDomains = append(uncheckedDomains, domainToCheck) uncheckedDomains = append(uncheckedDomains, domainToCheck)
@ -583,14 +582,7 @@ func (p *Provider) getValidDomains(domain types.Domain, wildcardAllowed bool) ([
func isDomainAlreadyChecked(domainToCheck string, existentDomains []string) bool { func isDomainAlreadyChecked(domainToCheck string, existentDomains []string) bool {
for _, certDomains := range existentDomains { for _, certDomains := range existentDomains {
for _, certDomain := range strings.Split(certDomains, ",") { for _, certDomain := range strings.Split(certDomains, ",") {
// Use regex to test for provided existentDomains that might have been added into TLSConfig if types.MatchDomain(domainToCheck, certDomain) {
selector := "^" + strings.Replace(certDomain, "*.", "[^\\.]*\\.", -1) + "$"
domainCheck, err := regexp.MatchString(selector, domainToCheck)
if err != nil {
log.Errorf("Unable to compare %q and %q in ACME provider : %s", domainToCheck, certDomain, err)
continue
}
if domainCheck {
return true return true
} }
} }

View file

@ -15,7 +15,6 @@ import (
"os" "os"
"os/signal" "os/signal"
"reflect" "reflect"
"regexp"
"sort" "sort"
"strings" "strings"
"sync" "sync"
@ -517,15 +516,13 @@ func (s *Server) loadHTTPSConfiguration(configurations types.Configurations, def
return newEPCertificates, nil return newEPCertificates, nil
} }
// getCertificate allows to customize tlsConfig.Getcertificate behaviour to get the certificates inserted dynamically // getCertificate allows to customize tlsConfig.GetCertificate behaviour to get the certificates inserted dynamically
func (s *serverEntryPoint) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { func (s *serverEntryPoint) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
domainToCheck := types.CanonicalDomain(clientHello.ServerName) domainToCheck := types.CanonicalDomain(clientHello.ServerName)
if s.certs.Get() != nil { if s.certs.Get() != nil {
for domains, cert := range s.certs.Get().(map[string]*tls.Certificate) { for domains, cert := range s.certs.Get().(map[string]*tls.Certificate) {
for _, domain := range strings.Split(domains, ",") { for _, certDomain := range strings.Split(domains, ",") {
selector := "^" + strings.Replace(domain, "*.", "[^\\.]*\\.?", -1) + "$" if types.MatchDomain(domainToCheck, certDomain) {
domainCheck, _ := regexp.MatchString(selector, domainToCheck)
if domainCheck {
return cert, nil return cert, nil
} }
} }

View file

@ -88,3 +88,95 @@ func TestDomain_Set(t *testing.T) {
}) })
} }
} }
func TestMatchDomain(t *testing.T) {
testCases := []struct {
desc string
certDomain string
domain string
expected bool
}{
{
desc: "exact match",
certDomain: "traefik.wtf",
domain: "traefik.wtf",
expected: true,
},
{
desc: "wildcard and root domain",
certDomain: "*.traefik.wtf",
domain: "traefik.wtf",
expected: false,
},
{
desc: "wildcard and sub domain",
certDomain: "*.traefik.wtf",
domain: "sub.traefik.wtf",
expected: true,
},
{
desc: "wildcard and sub sub domain",
certDomain: "*.traefik.wtf",
domain: "sub.sub.traefik.wtf",
expected: false,
},
{
desc: "double wildcard and sub sub domain",
certDomain: "*.*.traefik.wtf",
domain: "sub.sub.traefik.wtf",
expected: true,
},
{
desc: "sub sub domain and invalid wildcard",
certDomain: "sub.*.traefik.wtf",
domain: "sub.sub.traefik.wtf",
expected: false,
},
{
desc: "sub sub domain and valid wildcard",
certDomain: "*.sub.traefik.wtf",
domain: "sub.sub.traefik.wtf",
expected: true,
},
{
desc: "dot replaced by a cahr",
certDomain: "sub.sub.traefik.wtf",
domain: "sub.sub.traefikiwtf",
expected: false,
},
{
desc: "*",
certDomain: "*",
domain: "sub.sub.traefik.wtf",
expected: false,
},
{
desc: "?",
certDomain: "?",
domain: "sub.sub.traefik.wtf",
expected: false,
},
{
desc: "...................",
certDomain: "...................",
domain: "sub.sub.traefik.wtf",
expected: false,
},
{
desc: "wildcard and *",
certDomain: "*.traefik.wtf",
domain: "*.*.traefik.wtf",
expected: false,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
domains := MatchDomain(test.domain, test.certDomain)
assert.Equal(t, test.expected, domains)
})
}
}

View file

@ -65,3 +65,24 @@ func (ds *Domains) String() string { return fmt.Sprintf("%+v", *ds) }
func (ds *Domains) SetValue(val interface{}) { func (ds *Domains) SetValue(val interface{}) {
*ds = val.([]Domain) *ds = val.([]Domain)
} }
// MatchDomain return true if a domain match the cert domain
func MatchDomain(domain string, certDomain string) bool {
if domain == certDomain {
return true
}
for len(certDomain) > 0 && certDomain[len(certDomain)-1] == '.' {
certDomain = certDomain[:len(certDomain)-1]
}
labels := strings.Split(domain, ".")
for i := range labels {
labels[i] = "*"
candidate := strings.Join(labels, ".")
if certDomain == candidate {
return true
}
}
return false
}