From e12ddca1a5755363f2071b2ca796e48cea64a29b Mon Sep 17 00:00:00 2001 From: Emile Vauge Date: Fri, 14 Oct 2016 16:04:09 +0200 Subject: [PATCH] Fix case sensitive host --- acme/acme.go | 26 ++++++++++++++++---------- rules.go | 17 +++++++++-------- rules_test.go | 4 +++- types/types.go | 5 +++++ 4 files changed, 33 insertions(+), 19 deletions(-) diff --git a/acme/acme.go b/acme/acme.go index 32bebf37a..0c88caae3 100644 --- a/acme/acme.go +++ b/acme/acme.go @@ -4,11 +4,13 @@ import ( "crypto/tls" "errors" "fmt" + "github.com/BurntSushi/ty/fun" "github.com/cenk/backoff" "github.com/containous/staert" "github.com/containous/traefik/cluster" "github.com/containous/traefik/log" "github.com/containous/traefik/safe" + "github.com/containous/traefik/types" "github.com/xenolf/lego/acme" "golang.org/x/net/context" "io/ioutil" @@ -311,22 +313,23 @@ func (a *ACME) CreateLocalConfig(tlsConfig *tls.Config, checkOnDemandDomain func } func (a *ACME) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + domain := types.CanonicalDomain(clientHello.ServerName) account := a.store.Get().(*Account) - if challengeCert, ok := a.challengeProvider.getCertificate(clientHello.ServerName); ok { - log.Debugf("ACME got challenge %s", clientHello.ServerName) + if challengeCert, ok := a.challengeProvider.getCertificate(domain); ok { + log.Debugf("ACME got challenge %s", domain) return challengeCert, nil } - if domainCert, ok := account.DomainsCertificate.getCertificateForDomain(clientHello.ServerName); ok { - log.Debugf("ACME got domain cert %s", clientHello.ServerName) + if domainCert, ok := account.DomainsCertificate.getCertificateForDomain(domain); ok { + log.Debugf("ACME got domain cert %s", domain) return domainCert.tlsCert, nil } if a.OnDemand { - if a.checkOnDemandDomain != nil && !a.checkOnDemandDomain(clientHello.ServerName) { + if a.checkOnDemandDomain != nil && !a.checkOnDemandDomain(domain) { return nil, nil } return a.loadCertificateOnDemand(clientHello) } - log.Debugf("ACME got nothing %s", clientHello.ServerName) + log.Debugf("ACME got nothing %s", domain) return nil, nil } @@ -429,22 +432,23 @@ func (a *ACME) buildACMEClient(account *Account) (*acme.Client, error) { } func (a *ACME) loadCertificateOnDemand(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + domain := types.CanonicalDomain(clientHello.ServerName) account := a.store.Get().(*Account) - if certificateResource, ok := account.DomainsCertificate.getCertificateForDomain(clientHello.ServerName); ok { + if certificateResource, ok := account.DomainsCertificate.getCertificateForDomain(domain); ok { return certificateResource.tlsCert, nil } - certificate, err := a.getDomainsCertificates([]string{clientHello.ServerName}) + certificate, err := a.getDomainsCertificates([]string{domain}) if err != nil { return nil, err } - log.Debugf("Got certificate on demand for domain %s", clientHello.ServerName) + log.Debugf("Got certificate on demand for domain %s", domain) transaction, object, err := a.store.Begin() if err != nil { return nil, err } account = object.(*Account) - cert, err := account.DomainsCertificate.addCertificateForDomains(certificate, Domain{Main: clientHello.ServerName}) + cert, err := account.DomainsCertificate.addCertificateForDomains(certificate, Domain{Main: domain}) if err != nil { return nil, err } @@ -456,6 +460,7 @@ func (a *ACME) loadCertificateOnDemand(clientHello *tls.ClientHelloInfo) (*tls.C // LoadCertificateForDomains loads certificates from ACME for given domains func (a *ACME) LoadCertificateForDomains(domains []string) { + domains = fun.Map(types.CanonicalDomain, domains).([]string) safe.Go(func() { operation := func() error { if a.client == nil { @@ -514,6 +519,7 @@ func (a *ACME) LoadCertificateForDomains(domains []string) { } func (a *ACME) getDomainsCertificates(domains []string) (*Certificate, error) { + domains = fun.Map(types.CanonicalDomain, domains).([]string) log.Debugf("Loading ACME certificates %s...", domains) bundle := true certificate, failures := a.client.ObtainCertificate(domains, bundle, nil) diff --git a/rules.go b/rules.go index c091e0fa6..96cc73b6e 100644 --- a/rules.go +++ b/rules.go @@ -3,7 +3,9 @@ package main import ( "errors" "fmt" + "github.com/BurntSushi/ty/fun" "github.com/containous/mux" + "github.com/containous/traefik/types" "net" "net/http" "reflect" @@ -24,7 +26,7 @@ func (r *Rules) host(hosts ...string) *mux.Route { reqHost = req.Host } for _, host := range hosts { - if reqHost == strings.TrimSpace(host) { + if types.CanonicalDomain(reqHost) == types.CanonicalDomain(host) { return true } } @@ -35,7 +37,7 @@ func (r *Rules) host(hosts ...string) *mux.Route { func (r *Rules) hostRegexp(hosts ...string) *mux.Route { router := r.route.route.Subrouter() for _, host := range hosts { - router.Host(strings.TrimSpace(host)) + router.Host(types.CanonicalDomain(host)) } return r.route.route } @@ -43,7 +45,7 @@ func (r *Rules) hostRegexp(hosts ...string) *mux.Route { func (r *Rules) path(paths ...string) *mux.Route { router := r.route.route.Subrouter() for _, path := range paths { - router.Path(strings.TrimSpace(path)) + router.Path(types.CanonicalDomain(path)) } return r.route.route } @@ -51,7 +53,7 @@ func (r *Rules) path(paths ...string) *mux.Route { func (r *Rules) pathPrefix(paths ...string) *mux.Route { router := r.route.route.Subrouter() for _, path := range paths { - router.PathPrefix(strings.TrimSpace(path)) + router.PathPrefix(types.CanonicalDomain(path)) } return r.route.route } @@ -67,7 +69,7 @@ func (r *Rules) pathStrip(paths ...string) *mux.Route { r.route.stripPrefixes = paths router := r.route.route.Subrouter() for _, path := range paths { - router.Path(strings.TrimSpace(path)) + router.Path(types.CanonicalDomain(path)) } return r.route.route } @@ -77,7 +79,7 @@ func (r *Rules) pathPrefixStrip(paths ...string) *mux.Route { r.route.stripPrefixes = paths router := r.route.route.Subrouter() for _, path := range paths { - router.PathPrefix(strings.TrimSpace(path)) + router.PathPrefix(types.CanonicalDomain(path)) } return r.route.route } @@ -153,7 +155,6 @@ func (r *Rules) parseRules(expression string, onRule func(functionName string, f } } return nil - } // Parse parses rules expressions @@ -197,5 +198,5 @@ func (r *Rules) ParseDomains(expression string) ([]string, error) { if err != nil { return nil, fmt.Errorf("Error parsing domains: %v", err) } - return domains, nil + return fun.Map(types.CanonicalDomain, domains).([]string), nil } diff --git a/rules_test.go b/rules_test.go index 2bb89a347..694fde089 100644 --- a/rules_test.go +++ b/rules_test.go @@ -36,7 +36,7 @@ func TestParseTwoRules(t *testing.T) { serverRoute := &serverRoute{route: route} rules := &Rules{route: serverRoute} - expression := "Host:foo.bar;Path:/foobar" + expression := "Host: Foo.Bar ; Path:/FOObar" routeResult, err := rules.Parse(expression) if err != nil { @@ -58,11 +58,13 @@ func TestParseDomains(t *testing.T) { "Host:foo.bar,test.bar", "Path:/test", "Host:foo.bar;Path:/test", + "Host: Foo.Bar ;Path:/test", } domainsSlice := [][]string{ {"foo.bar", "test.bar"}, {}, {"foo.bar"}, + {"foo.bar"}, } for i, expression := range expressionsSlice { domains, err := rules.ParseDomains(expression) diff --git a/types/types.go b/types/types.go index b9d8298f8..76fa96bd4 100644 --- a/types/types.go +++ b/types/types.go @@ -223,3 +223,8 @@ type Basic struct { type Digest struct { Users } + +// CanonicalDomain returns a lower case domain with trim space +func CanonicalDomain(domain string) string { + return strings.ToLower(strings.TrimSpace(domain)) +}