Merge pull request #733 from containous/fix-case-sensitive-hosts

Fix case sensitive host
This commit is contained in:
Vincent Demeester 2016-10-17 15:44:09 +02:00 committed by GitHub
commit 4476861d9f
4 changed files with 33 additions and 19 deletions

View file

@ -4,11 +4,13 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"github.com/BurntSushi/ty/fun"
"github.com/cenk/backoff" "github.com/cenk/backoff"
"github.com/containous/staert" "github.com/containous/staert"
"github.com/containous/traefik/cluster" "github.com/containous/traefik/cluster"
"github.com/containous/traefik/log" "github.com/containous/traefik/log"
"github.com/containous/traefik/safe" "github.com/containous/traefik/safe"
"github.com/containous/traefik/types"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/acme"
"golang.org/x/net/context" "golang.org/x/net/context"
"io/ioutil" "io/ioutil"
@ -311,22 +313,23 @@ func (a *ACME) CreateLocalConfig(tlsConfig *tls.Config, checkOnDemandDomain func
} }
func (a *ACME) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { func (a *ACME) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
domain := types.CanonicalDomain(clientHello.ServerName)
account := a.store.Get().(*Account) account := a.store.Get().(*Account)
if challengeCert, ok := a.challengeProvider.getCertificate(clientHello.ServerName); ok { if challengeCert, ok := a.challengeProvider.getCertificate(domain); ok {
log.Debugf("ACME got challenge %s", clientHello.ServerName) log.Debugf("ACME got challenge %s", domain)
return challengeCert, nil return challengeCert, nil
} }
if domainCert, ok := account.DomainsCertificate.getCertificateForDomain(clientHello.ServerName); ok { if domainCert, ok := account.DomainsCertificate.getCertificateForDomain(domain); ok {
log.Debugf("ACME got domain cert %s", clientHello.ServerName) log.Debugf("ACME got domain cert %s", domain)
return domainCert.tlsCert, nil return domainCert.tlsCert, nil
} }
if a.OnDemand { if a.OnDemand {
if a.checkOnDemandDomain != nil && !a.checkOnDemandDomain(clientHello.ServerName) { if a.checkOnDemandDomain != nil && !a.checkOnDemandDomain(domain) {
return nil, nil return nil, nil
} }
return a.loadCertificateOnDemand(clientHello) return a.loadCertificateOnDemand(clientHello)
} }
log.Debugf("ACME got nothing %s", clientHello.ServerName) log.Debugf("ACME got nothing %s", domain)
return nil, nil return nil, nil
} }
@ -429,22 +432,23 @@ func (a *ACME) buildACMEClient(account *Account) (*acme.Client, error) {
} }
func (a *ACME) loadCertificateOnDemand(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { func (a *ACME) loadCertificateOnDemand(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
domain := types.CanonicalDomain(clientHello.ServerName)
account := a.store.Get().(*Account) account := a.store.Get().(*Account)
if certificateResource, ok := account.DomainsCertificate.getCertificateForDomain(clientHello.ServerName); ok { if certificateResource, ok := account.DomainsCertificate.getCertificateForDomain(domain); ok {
return certificateResource.tlsCert, nil return certificateResource.tlsCert, nil
} }
certificate, err := a.getDomainsCertificates([]string{clientHello.ServerName}) certificate, err := a.getDomainsCertificates([]string{domain})
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.Debugf("Got certificate on demand for domain %s", clientHello.ServerName) log.Debugf("Got certificate on demand for domain %s", domain)
transaction, object, err := a.store.Begin() transaction, object, err := a.store.Begin()
if err != nil { if err != nil {
return nil, err return nil, err
} }
account = object.(*Account) account = object.(*Account)
cert, err := account.DomainsCertificate.addCertificateForDomains(certificate, Domain{Main: clientHello.ServerName}) cert, err := account.DomainsCertificate.addCertificateForDomains(certificate, Domain{Main: domain})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -456,6 +460,7 @@ func (a *ACME) loadCertificateOnDemand(clientHello *tls.ClientHelloInfo) (*tls.C
// LoadCertificateForDomains loads certificates from ACME for given domains // LoadCertificateForDomains loads certificates from ACME for given domains
func (a *ACME) LoadCertificateForDomains(domains []string) { func (a *ACME) LoadCertificateForDomains(domains []string) {
domains = fun.Map(types.CanonicalDomain, domains).([]string)
safe.Go(func() { safe.Go(func() {
operation := func() error { operation := func() error {
if a.client == nil { if a.client == nil {
@ -514,6 +519,7 @@ func (a *ACME) LoadCertificateForDomains(domains []string) {
} }
func (a *ACME) getDomainsCertificates(domains []string) (*Certificate, error) { func (a *ACME) getDomainsCertificates(domains []string) (*Certificate, error) {
domains = fun.Map(types.CanonicalDomain, domains).([]string)
log.Debugf("Loading ACME certificates %s...", domains) log.Debugf("Loading ACME certificates %s...", domains)
bundle := true bundle := true
certificate, failures := a.client.ObtainCertificate(domains, bundle, nil) certificate, failures := a.client.ObtainCertificate(domains, bundle, nil)

View file

@ -3,7 +3,9 @@ package main
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/BurntSushi/ty/fun"
"github.com/containous/mux" "github.com/containous/mux"
"github.com/containous/traefik/types"
"net" "net"
"net/http" "net/http"
"reflect" "reflect"
@ -24,7 +26,7 @@ func (r *Rules) host(hosts ...string) *mux.Route {
reqHost = req.Host reqHost = req.Host
} }
for _, host := range hosts { for _, host := range hosts {
if reqHost == strings.TrimSpace(host) { if types.CanonicalDomain(reqHost) == types.CanonicalDomain(host) {
return true return true
} }
} }
@ -35,7 +37,7 @@ func (r *Rules) host(hosts ...string) *mux.Route {
func (r *Rules) hostRegexp(hosts ...string) *mux.Route { func (r *Rules) hostRegexp(hosts ...string) *mux.Route {
router := r.route.route.Subrouter() router := r.route.route.Subrouter()
for _, host := range hosts { for _, host := range hosts {
router.Host(strings.TrimSpace(host)) router.Host(types.CanonicalDomain(host))
} }
return r.route.route return r.route.route
} }
@ -43,7 +45,7 @@ func (r *Rules) hostRegexp(hosts ...string) *mux.Route {
func (r *Rules) path(paths ...string) *mux.Route { func (r *Rules) path(paths ...string) *mux.Route {
router := r.route.route.Subrouter() router := r.route.route.Subrouter()
for _, path := range paths { for _, path := range paths {
router.Path(strings.TrimSpace(path)) router.Path(types.CanonicalDomain(path))
} }
return r.route.route return r.route.route
} }
@ -51,7 +53,7 @@ func (r *Rules) path(paths ...string) *mux.Route {
func (r *Rules) pathPrefix(paths ...string) *mux.Route { func (r *Rules) pathPrefix(paths ...string) *mux.Route {
router := r.route.route.Subrouter() router := r.route.route.Subrouter()
for _, path := range paths { for _, path := range paths {
router.PathPrefix(strings.TrimSpace(path)) router.PathPrefix(types.CanonicalDomain(path))
} }
return r.route.route return r.route.route
} }
@ -67,7 +69,7 @@ func (r *Rules) pathStrip(paths ...string) *mux.Route {
r.route.stripPrefixes = paths r.route.stripPrefixes = paths
router := r.route.route.Subrouter() router := r.route.route.Subrouter()
for _, path := range paths { for _, path := range paths {
router.Path(strings.TrimSpace(path)) router.Path(types.CanonicalDomain(path))
} }
return r.route.route return r.route.route
} }
@ -77,7 +79,7 @@ func (r *Rules) pathPrefixStrip(paths ...string) *mux.Route {
r.route.stripPrefixes = paths r.route.stripPrefixes = paths
router := r.route.route.Subrouter() router := r.route.route.Subrouter()
for _, path := range paths { for _, path := range paths {
router.PathPrefix(strings.TrimSpace(path)) router.PathPrefix(types.CanonicalDomain(path))
} }
return r.route.route return r.route.route
} }
@ -153,7 +155,6 @@ func (r *Rules) parseRules(expression string, onRule func(functionName string, f
} }
} }
return nil return nil
} }
// Parse parses rules expressions // Parse parses rules expressions
@ -197,5 +198,5 @@ func (r *Rules) ParseDomains(expression string) ([]string, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("Error parsing domains: %v", err) return nil, fmt.Errorf("Error parsing domains: %v", err)
} }
return domains, nil return fun.Map(types.CanonicalDomain, domains).([]string), nil
} }

View file

@ -36,7 +36,7 @@ func TestParseTwoRules(t *testing.T) {
serverRoute := &serverRoute{route: route} serverRoute := &serverRoute{route: route}
rules := &Rules{route: serverRoute} rules := &Rules{route: serverRoute}
expression := "Host:foo.bar;Path:/foobar" expression := "Host: Foo.Bar ; Path:/FOObar"
routeResult, err := rules.Parse(expression) routeResult, err := rules.Parse(expression)
if err != nil { if err != nil {
@ -58,11 +58,13 @@ func TestParseDomains(t *testing.T) {
"Host:foo.bar,test.bar", "Host:foo.bar,test.bar",
"Path:/test", "Path:/test",
"Host:foo.bar;Path:/test", "Host:foo.bar;Path:/test",
"Host: Foo.Bar ;Path:/test",
} }
domainsSlice := [][]string{ domainsSlice := [][]string{
{"foo.bar", "test.bar"}, {"foo.bar", "test.bar"},
{}, {},
{"foo.bar"}, {"foo.bar"},
{"foo.bar"},
} }
for i, expression := range expressionsSlice { for i, expression := range expressionsSlice {
domains, err := rules.ParseDomains(expression) domains, err := rules.ParseDomains(expression)

View file

@ -223,3 +223,8 @@ type Basic struct {
type Digest struct { type Digest struct {
Users Users
} }
// CanonicalDomain returns a lower case domain with trim space
func CanonicalDomain(domain string) string {
return strings.ToLower(strings.TrimSpace(domain))
}