Merge pull request #733 from containous/fix-case-sensitive-hosts

Fix case sensitive host
This commit is contained in:
Vincent Demeester 2016-10-17 15:44:09 +02:00 committed by GitHub
commit 4476861d9f
4 changed files with 33 additions and 19 deletions

View file

@ -4,11 +4,13 @@ import (
"crypto/tls"
"errors"
"fmt"
"github.com/BurntSushi/ty/fun"
"github.com/cenk/backoff"
"github.com/containous/staert"
"github.com/containous/traefik/cluster"
"github.com/containous/traefik/log"
"github.com/containous/traefik/safe"
"github.com/containous/traefik/types"
"github.com/xenolf/lego/acme"
"golang.org/x/net/context"
"io/ioutil"
@ -311,22 +313,23 @@ func (a *ACME) CreateLocalConfig(tlsConfig *tls.Config, checkOnDemandDomain func
}
func (a *ACME) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
domain := types.CanonicalDomain(clientHello.ServerName)
account := a.store.Get().(*Account)
if challengeCert, ok := a.challengeProvider.getCertificate(clientHello.ServerName); ok {
log.Debugf("ACME got challenge %s", clientHello.ServerName)
if challengeCert, ok := a.challengeProvider.getCertificate(domain); ok {
log.Debugf("ACME got challenge %s", domain)
return challengeCert, nil
}
if domainCert, ok := account.DomainsCertificate.getCertificateForDomain(clientHello.ServerName); ok {
log.Debugf("ACME got domain cert %s", clientHello.ServerName)
if domainCert, ok := account.DomainsCertificate.getCertificateForDomain(domain); ok {
log.Debugf("ACME got domain cert %s", domain)
return domainCert.tlsCert, nil
}
if a.OnDemand {
if a.checkOnDemandDomain != nil && !a.checkOnDemandDomain(clientHello.ServerName) {
if a.checkOnDemandDomain != nil && !a.checkOnDemandDomain(domain) {
return nil, nil
}
return a.loadCertificateOnDemand(clientHello)
}
log.Debugf("ACME got nothing %s", clientHello.ServerName)
log.Debugf("ACME got nothing %s", domain)
return nil, nil
}
@ -429,22 +432,23 @@ func (a *ACME) buildACMEClient(account *Account) (*acme.Client, error) {
}
func (a *ACME) loadCertificateOnDemand(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
domain := types.CanonicalDomain(clientHello.ServerName)
account := a.store.Get().(*Account)
if certificateResource, ok := account.DomainsCertificate.getCertificateForDomain(clientHello.ServerName); ok {
if certificateResource, ok := account.DomainsCertificate.getCertificateForDomain(domain); ok {
return certificateResource.tlsCert, nil
}
certificate, err := a.getDomainsCertificates([]string{clientHello.ServerName})
certificate, err := a.getDomainsCertificates([]string{domain})
if err != nil {
return nil, err
}
log.Debugf("Got certificate on demand for domain %s", clientHello.ServerName)
log.Debugf("Got certificate on demand for domain %s", domain)
transaction, object, err := a.store.Begin()
if err != nil {
return nil, err
}
account = object.(*Account)
cert, err := account.DomainsCertificate.addCertificateForDomains(certificate, Domain{Main: clientHello.ServerName})
cert, err := account.DomainsCertificate.addCertificateForDomains(certificate, Domain{Main: domain})
if err != nil {
return nil, err
}
@ -456,6 +460,7 @@ func (a *ACME) loadCertificateOnDemand(clientHello *tls.ClientHelloInfo) (*tls.C
// LoadCertificateForDomains loads certificates from ACME for given domains
func (a *ACME) LoadCertificateForDomains(domains []string) {
domains = fun.Map(types.CanonicalDomain, domains).([]string)
safe.Go(func() {
operation := func() error {
if a.client == nil {
@ -514,6 +519,7 @@ func (a *ACME) LoadCertificateForDomains(domains []string) {
}
func (a *ACME) getDomainsCertificates(domains []string) (*Certificate, error) {
domains = fun.Map(types.CanonicalDomain, domains).([]string)
log.Debugf("Loading ACME certificates %s...", domains)
bundle := true
certificate, failures := a.client.ObtainCertificate(domains, bundle, nil)

View file

@ -3,7 +3,9 @@ package main
import (
"errors"
"fmt"
"github.com/BurntSushi/ty/fun"
"github.com/containous/mux"
"github.com/containous/traefik/types"
"net"
"net/http"
"reflect"
@ -24,7 +26,7 @@ func (r *Rules) host(hosts ...string) *mux.Route {
reqHost = req.Host
}
for _, host := range hosts {
if reqHost == strings.TrimSpace(host) {
if types.CanonicalDomain(reqHost) == types.CanonicalDomain(host) {
return true
}
}
@ -35,7 +37,7 @@ func (r *Rules) host(hosts ...string) *mux.Route {
func (r *Rules) hostRegexp(hosts ...string) *mux.Route {
router := r.route.route.Subrouter()
for _, host := range hosts {
router.Host(strings.TrimSpace(host))
router.Host(types.CanonicalDomain(host))
}
return r.route.route
}
@ -43,7 +45,7 @@ func (r *Rules) hostRegexp(hosts ...string) *mux.Route {
func (r *Rules) path(paths ...string) *mux.Route {
router := r.route.route.Subrouter()
for _, path := range paths {
router.Path(strings.TrimSpace(path))
router.Path(types.CanonicalDomain(path))
}
return r.route.route
}
@ -51,7 +53,7 @@ func (r *Rules) path(paths ...string) *mux.Route {
func (r *Rules) pathPrefix(paths ...string) *mux.Route {
router := r.route.route.Subrouter()
for _, path := range paths {
router.PathPrefix(strings.TrimSpace(path))
router.PathPrefix(types.CanonicalDomain(path))
}
return r.route.route
}
@ -67,7 +69,7 @@ func (r *Rules) pathStrip(paths ...string) *mux.Route {
r.route.stripPrefixes = paths
router := r.route.route.Subrouter()
for _, path := range paths {
router.Path(strings.TrimSpace(path))
router.Path(types.CanonicalDomain(path))
}
return r.route.route
}
@ -77,7 +79,7 @@ func (r *Rules) pathPrefixStrip(paths ...string) *mux.Route {
r.route.stripPrefixes = paths
router := r.route.route.Subrouter()
for _, path := range paths {
router.PathPrefix(strings.TrimSpace(path))
router.PathPrefix(types.CanonicalDomain(path))
}
return r.route.route
}
@ -153,7 +155,6 @@ func (r *Rules) parseRules(expression string, onRule func(functionName string, f
}
}
return nil
}
// Parse parses rules expressions
@ -197,5 +198,5 @@ func (r *Rules) ParseDomains(expression string) ([]string, error) {
if err != nil {
return nil, fmt.Errorf("Error parsing domains: %v", err)
}
return domains, nil
return fun.Map(types.CanonicalDomain, domains).([]string), nil
}

View file

@ -36,7 +36,7 @@ func TestParseTwoRules(t *testing.T) {
serverRoute := &serverRoute{route: route}
rules := &Rules{route: serverRoute}
expression := "Host:foo.bar;Path:/foobar"
expression := "Host: Foo.Bar ; Path:/FOObar"
routeResult, err := rules.Parse(expression)
if err != nil {
@ -58,11 +58,13 @@ func TestParseDomains(t *testing.T) {
"Host:foo.bar,test.bar",
"Path:/test",
"Host:foo.bar;Path:/test",
"Host: Foo.Bar ;Path:/test",
}
domainsSlice := [][]string{
{"foo.bar", "test.bar"},
{},
{"foo.bar"},
{"foo.bar"},
}
for i, expression := range expressionsSlice {
domains, err := rules.ParseDomains(expression)

View file

@ -223,3 +223,8 @@ type Basic struct {
type Digest struct {
Users
}
// CanonicalDomain returns a lower case domain with trim space
func CanonicalDomain(domain string) string {
return strings.ToLower(strings.TrimSpace(domain))
}