diff --git a/middlewares/request_host.go b/middlewares/request_host.go new file mode 100644 index 000000000..263b026ae --- /dev/null +++ b/middlewares/request_host.go @@ -0,0 +1,42 @@ +package middlewares + +import ( + "context" + "net" + "net/http" + "strings" + + "github.com/containous/traefik/types" +) + +var requestHostKey struct{} + +// RequestHost is the struct for the middleware that adds the CanonicalDomain of the request Host into a context for later use. +type RequestHost struct{} + +func (rh *RequestHost) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + if next != nil { + host := types.CanonicalDomain(parseHost(r.Host)) + next.ServeHTTP(rw, r.WithContext(context.WithValue(r.Context(), requestHostKey, host))) + } +} + +func parseHost(addr string) string { + if !strings.Contains(addr, ":") { + return addr + } + + host, _, err := net.SplitHostPort(addr) + if err != nil { + return addr + } + return host +} + +// GetCanonizedHost plucks the canonized host key from the request of a context that was put through the middleware +func GetCanonizedHost(ctx context.Context) string { + if val, ok := ctx.Value(requestHostKey).(string); ok { + return val + } + return "" +} diff --git a/middlewares/request_host_test.go b/middlewares/request_host_test.go new file mode 100644 index 000000000..27116a2e7 --- /dev/null +++ b/middlewares/request_host_test.go @@ -0,0 +1,94 @@ +package middlewares + +import ( + "net/http" + "testing" + + "github.com/containous/traefik/testhelpers" + "github.com/stretchr/testify/assert" +) + +func TestRequestHost(t *testing.T) { + testCases := []struct { + desc string + url string + expected string + }{ + { + desc: "host without :", + url: "http://host", + expected: "host", + }, + { + desc: "host with : and without port", + url: "http://host:", + expected: "host", + }, + { + desc: "IP host with : and with port", + url: "http://127.0.0.1:123", + expected: "127.0.0.1", + }, + { + desc: "IP host with : and without port", + url: "http://127.0.0.1:", + expected: "127.0.0.1", + }, + } + + rh := &RequestHost{} + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + req := testhelpers.MustNewRequest(http.MethodGet, test.url, nil) + + rh.ServeHTTP(nil, req, func(_ http.ResponseWriter, r *http.Request) { + host := GetCanonizedHost(r.Context()) + assert.Equal(t, test.expected, host) + }) + }) + } +} + +func TestRequestHostParseHost(t *testing.T) { + testCases := []struct { + desc string + host string + expected string + }{ + { + desc: "host without :", + host: "host", + expected: "host", + }, + { + desc: "host with : and without port", + host: "host:", + expected: "host", + }, + { + desc: "IP host with : and with port", + host: "127.0.0.1:123", + expected: "127.0.0.1", + }, + { + desc: "IP host with : and without port", + host: "127.0.0.1:", + expected: "127.0.0.1", + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + actual := parseHost(test.host) + + assert.Equal(t, test.expected, actual) + }) + } +} diff --git a/rules/rules.go b/rules/rules.go index 986c0dfe9..4a5658cd7 100644 --- a/rules/rules.go +++ b/rules/rules.go @@ -3,7 +3,6 @@ package rules import ( "errors" "fmt" - "net" "net/http" "reflect" "sort" @@ -13,6 +12,7 @@ import ( "github.com/containous/mux" "github.com/containous/traefik/hostresolver" "github.com/containous/traefik/log" + "github.com/containous/traefik/middlewares" "github.com/containous/traefik/types" ) @@ -24,17 +24,20 @@ type Rules struct { } func (r *Rules) host(hosts ...string) *mux.Route { + for i, host := range hosts { + hosts[i] = strings.ToLower(host) + } + return r.Route.Route.MatcherFunc(func(req *http.Request, route *mux.RouteMatch) bool { - reqHost, _, err := net.SplitHostPort(req.Host) - if err != nil { - reqHost = req.Host + reqHost := middlewares.GetCanonizedHost(req.Context()) + if len(reqHost) == 0 { + return false } if r.HostResolver != nil && r.HostResolver.CnameFlattening { - reqH, flatH := r.HostResolver.CNAMEFlatten(types.CanonicalDomain(reqHost)) + reqH, flatH := r.HostResolver.CNAMEFlatten(reqHost) for _, host := range hosts { - if types.CanonicalDomain(reqH) == types.CanonicalDomain(host) || - types.CanonicalDomain(flatH) == types.CanonicalDomain(host) { + if strings.EqualFold(reqH, host) || strings.EqualFold(flatH, host) { return true } log.Debugf("CNAMEFlattening: request %s which resolved to %s, is not matched to route %s", reqH, flatH, host) @@ -43,7 +46,7 @@ func (r *Rules) host(hosts ...string) *mux.Route { } for _, host := range hosts { - if types.CanonicalDomain(reqHost) == types.CanonicalDomain(host) { + if reqHost == host { return true } } @@ -54,7 +57,7 @@ func (r *Rules) host(hosts ...string) *mux.Route { func (r *Rules) hostRegexp(hosts ...string) *mux.Route { router := r.Route.Route.Subrouter() for _, host := range hosts { - router.Host(types.CanonicalDomain(host)) + router.Host(strings.ToLower(host)) } return r.Route.Route } @@ -62,7 +65,7 @@ func (r *Rules) hostRegexp(hosts ...string) *mux.Route { func (r *Rules) path(paths ...string) *mux.Route { router := r.Route.Route.Subrouter() for _, path := range paths { - router.Path(strings.TrimSpace(path)) + router.Path(path) } return r.Route.Route } @@ -76,14 +79,13 @@ func (r *Rules) pathPrefix(paths ...string) *mux.Route { } func buildPath(path string, router *mux.Router) { - cleanPath := strings.TrimSpace(path) // {} are used to define a regex pattern in http://www.gorillatoolkit.org/pkg/mux. // if we find a { in the path, that means we use regex, then the gorilla/mux implementation is chosen // otherwise, we use a lightweight implementation - if strings.Contains(cleanPath, "{") { - router.PathPrefix(cleanPath) + if strings.Contains(path, "{") { + router.PathPrefix(path) } else { - m := &prefixMatcher{prefix: cleanPath} + m := &prefixMatcher{prefix: path} router.NewRoute().MatcherFunc(m.Match) } } @@ -117,7 +119,7 @@ func (r *Rules) pathStripRegex(paths ...string) *mux.Route { r.Route.StripPrefixesRegex = paths router := r.Route.Route.Subrouter() for _, path := range paths { - router.Path(strings.TrimSpace(path)) + router.Path(path) } return r.Route.Route } @@ -158,7 +160,7 @@ func (r *Rules) pathPrefixStripRegex(paths ...string) *mux.Route { r.Route.StripPrefixesRegex = paths router := r.Route.Route.Subrouter() for _, path := range paths { - router.PathPrefix(strings.TrimSpace(path)) + router.PathPrefix(path) } return r.Route.Route } @@ -297,5 +299,5 @@ func (r *Rules) ParseDomains(expression string) ([]string, error) { return nil, fmt.Errorf("error parsing domains: %v", err) } - return fun.Map(types.CanonicalDomain, domains).([]string), nil + return fun.Map(strings.ToLower, domains).([]string), nil } diff --git a/rules/rules_test.go b/rules/rules_test.go index fe40f8e26..7578cedca 100644 --- a/rules/rules_test.go +++ b/rules/rules_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/containous/mux" + "github.com/containous/traefik/middlewares" "github.com/containous/traefik/testhelpers" "github.com/containous/traefik/types" "github.com/stretchr/testify/assert" @@ -13,41 +14,50 @@ import ( ) func TestParseOneRule(t *testing.T) { - router := mux.NewRouter() - route := router.NewRoute() - serverRoute := &types.ServerRoute{Route: route} - rules := &Rules{Route: serverRoute} + reqHostMid := &middlewares.RequestHost{} + rules := &Rules{ + Route: &types.ServerRoute{ + Route: mux.NewRouter().NewRoute(), + }, + } expression := "Host:foo.bar" + routeResult, err := rules.Parse(expression) require.NoError(t, err, "Error while building route for %s", expression) request := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar", nil) - routeMatch := routeResult.Match(request, &mux.RouteMatch{Route: routeResult}) - assert.True(t, routeMatch, "Rule %s don't match.", expression) + reqHostMid.ServeHTTP(nil, request, func(w http.ResponseWriter, r *http.Request) { + routeMatch := routeResult.Match(r, &mux.RouteMatch{Route: routeResult}) + assert.True(t, routeMatch, "Rule %s don't match.", expression) + }) } func TestParseTwoRules(t *testing.T) { - router := mux.NewRouter() - route := router.NewRoute() - serverRoute := &types.ServerRoute{Route: route} - rules := &Rules{Route: serverRoute} + reqHostMid := &middlewares.RequestHost{} + rules := &Rules{ + Route: &types.ServerRoute{ + Route: mux.NewRouter().NewRoute(), + }, + } expression := "Host: Foo.Bar ; Path:/FOObar" - routeResult, err := rules.Parse(expression) + routeResult, err := rules.Parse(expression) require.NoError(t, err, "Error while building route for %s.", expression) request := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/foobar", nil) - routeMatch := routeResult.Match(request, &mux.RouteMatch{Route: routeResult}) - - assert.False(t, routeMatch, "Rule %s don't match.", expression) + reqHostMid.ServeHTTP(nil, request, func(w http.ResponseWriter, r *http.Request) { + routeMatch := routeResult.Match(r, &mux.RouteMatch{Route: routeResult}) + assert.False(t, routeMatch, "Rule %s don't match.", expression) + }) request = testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/FOObar", nil) - routeMatch = routeResult.Match(request, &mux.RouteMatch{Route: routeResult}) - - assert.True(t, routeMatch, "Rule %s don't match.", expression) + reqHostMid.ServeHTTP(nil, request, func(w http.ResponseWriter, r *http.Request) { + routeMatch := routeResult.Match(r, &mux.RouteMatch{Route: routeResult}) + assert.True(t, routeMatch, "Rule %s don't match.", expression) + }) } func TestParseDomains(t *testing.T) { @@ -91,6 +101,7 @@ func TestParseDomains(t *testing.T) { func TestPriorites(t *testing.T) { router := mux.NewRouter() router.StrictSlash(true) + rules := &Rules{Route: &types.ServerRoute{Route: router.NewRoute()}} expression01 := "PathPrefix:/foo" diff --git a/server/server_configuration_test.go b/server/server_configuration_test.go index 2a932f195..fa945c830 100644 --- a/server/server_configuration_test.go +++ b/server/server_configuration_test.go @@ -12,6 +12,7 @@ import ( "github.com/containous/mux" "github.com/containous/traefik/configuration" "github.com/containous/traefik/healthcheck" + "github.com/containous/traefik/middlewares" "github.com/containous/traefik/rules" th "github.com/containous/traefik/testhelpers" "github.com/containous/traefik/tls" @@ -385,6 +386,7 @@ func TestServerMultipleFrontendRules(t *testing.T) { router := mux.NewRouter() route := router.NewRoute() serverRoute := &types.ServerRoute{Route: route} + reqHostMid := &middlewares.RequestHost{} rls := &rules.Rules{Route: serverRoute} expression := test.expression @@ -395,7 +397,10 @@ func TestServerMultipleFrontendRules(t *testing.T) { } request := th.MustNewRequest(http.MethodGet, test.requestURL, nil) - routeMatch := routeResult.Match(request, &mux.RouteMatch{Route: routeResult}) + var routeMatch bool + reqHostMid.ServeHTTP(nil, request, func(w http.ResponseWriter, r *http.Request) { + routeMatch = routeResult.Match(r, &mux.RouteMatch{Route: routeResult}) + }) if !routeMatch { t.Fatalf("Rule %s doesn't match", expression) diff --git a/server/server_middlewares.go b/server/server_middlewares.go index d3abdfb96..1a1a1a89e 100644 --- a/server/server_middlewares.go +++ b/server/server_middlewares.go @@ -172,6 +172,9 @@ func (s *Server) buildServerEntryPointMiddlewares(serverEntryPointName string, s serverMiddlewares = append(serverMiddlewares, s.wrapNegroniHandlerWithAccessLog(ipWhitelistMiddleware, fmt.Sprintf("ipwhitelister for entrypoint %s", serverEntryPointName))) } + // RequestHost Cannonizer + serverMiddlewares = append(serverMiddlewares, &middlewares.RequestHost{}) + return serverMiddlewares, nil }