diff --git a/server/rules.go b/server/rules.go index 661b80224..be7f2b033 100644 --- a/server/rules.go +++ b/server/rules.go @@ -54,11 +54,32 @@ func (r *Rules) path(paths ...string) *mux.Route { func (r *Rules) pathPrefix(paths ...string) *mux.Route { router := r.route.route.Subrouter() for _, path := range paths { - router.PathPrefix(strings.TrimSpace(path)) + buildPath(path, router) } return r.route.route } +func buildPath(path string, router *mux.Router) { + cleanPath := strings.TrimSpace(path) + // {} are used to define a regex pattern in http://www.gorillatoolkit.org/pkg/mux. + // if we find a { in the path, that means we use regex, then the gorilla/mux implementation is chosen + // otherwise, we use a lightweight implementation + if strings.Contains(cleanPath, "{") { + router.PathPrefix(cleanPath) + } else { + m := &prefixMatcher{prefix: cleanPath} + router.NewRoute().MatcherFunc(m.Match) + } +} + +type prefixMatcher struct { + prefix string +} + +func (m *prefixMatcher) Match(r *http.Request, _ *mux.RouteMatch) bool { + return strings.HasPrefix(r.URL.Path, m.prefix) || strings.HasPrefix(r.URL.Path, m.prefix+"/") +} + type bySize []string func (a bySize) Len() int { return len(a) } @@ -111,7 +132,7 @@ func (r *Rules) pathPrefixStrip(paths ...string) *mux.Route { r.route.stripPrefixes = paths router := r.route.route.Subrouter() for _, path := range paths { - router.PathPrefix(strings.TrimSpace(path)) + buildPath(path, router) } return r.route.route } diff --git a/server/rules_test.go b/server/rules_test.go index f0ad65edd..7d3712289 100644 --- a/server/rules_test.go +++ b/server/rules_test.go @@ -192,3 +192,67 @@ type fakeHandler struct { } func (h *fakeHandler) ServeHTTP(http.ResponseWriter, *http.Request) {} + +func TestPathPrefix(t *testing.T) { + testCases := []struct { + desc string + path string + urls map[string]bool + }{ + { + desc: "leading slash", + path: "/bar", + urls: map[string]bool{ + "http://foo.com/bar": true, + "http://foo.com/bar/": true, + }, + }, + { + desc: "leading trailing slash", + path: "/bar/", + urls: map[string]bool{ + "http://foo.com/bar": false, + "http://foo.com/bar/": true, + }, + }, + { + desc: "no slash", + path: "bar", + urls: map[string]bool{ + "http://foo.com/bar": false, + "http://foo.com/bar/": false, + }, + }, + { + desc: "trailing slash", + path: "bar/", + urls: map[string]bool{ + "http://foo.com/bar": false, + "http://foo.com/bar/": false, + }, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + rls := &Rules{ + route: &serverRoute{ + route: &mux.Route{}, + }, + } + + rt := rls.pathPrefix(test.path) + + for testURL, expectedMatch := range test.urls { + req := testhelpers.MustNewRequest(http.MethodGet, testURL, nil) + match := rt.Match(req, &mux.RouteMatch{}) + if match != expectedMatch { + t.Errorf("Error matching %s with %s, got %v expected %v", test.path, testURL, match, expectedMatch) + } + } + }) + } +}