Fix raw path handling in strip prefix

This commit is contained in:
Marco Jantke 2017-11-21 14:28:03 +01:00 committed by Traefiker
parent c9129b8ecf
commit 676b79db42
4 changed files with 47 additions and 6 deletions

View file

@ -16,8 +16,11 @@ type StripPrefix struct {
func (s *StripPrefix) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s *StripPrefix) ServeHTTP(w http.ResponseWriter, r *http.Request) {
for _, prefix := range s.Prefixes { for _, prefix := range s.Prefixes {
if p := strings.TrimPrefix(r.URL.Path, prefix); len(p) < len(r.URL.Path) { if strings.HasPrefix(r.URL.Path, prefix) {
r.URL.Path = "/" + strings.TrimPrefix(p, "/") r.URL.Path = stripPrefix(r.URL.Path, prefix)
if r.URL.RawPath != "" {
r.URL.RawPath = stripPrefix(r.URL.RawPath, prefix)
}
s.serveRequest(w, r, strings.TrimSpace(prefix)) s.serveRequest(w, r, strings.TrimSpace(prefix))
return return
} }
@ -35,3 +38,11 @@ func (s *StripPrefix) serveRequest(w http.ResponseWriter, r *http.Request, prefi
func (s *StripPrefix) SetHandler(Handler http.Handler) { func (s *StripPrefix) SetHandler(Handler http.Handler) {
s.Handler = Handler s.Handler = Handler
} }
func stripPrefix(s, prefix string) string {
return ensureLeadingSlash(strings.TrimPrefix(s, prefix))
}
func ensureLeadingSlash(str string) string {
return "/" + strings.TrimPrefix(str, "/")
}

View file

@ -40,6 +40,9 @@ func (s *StripPrefixRegex) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
r.URL.Path = r.URL.Path[len(prefix.Path):] r.URL.Path = r.URL.Path[len(prefix.Path):]
if r.URL.RawPath != "" {
r.URL.RawPath = r.URL.RawPath[len(prefix.Path):]
}
r.Header.Add(ForwardedPrefixHeader, prefix.Path) r.Header.Add(ForwardedPrefixHeader, prefix.Path)
r.RequestURI = r.URL.RequestURI() r.RequestURI = r.URL.RequestURI()
s.Handler.ServeHTTP(w, r) s.Handler.ServeHTTP(w, r)

View file

@ -10,13 +10,13 @@ import (
) )
func TestStripPrefixRegex(t *testing.T) { func TestStripPrefixRegex(t *testing.T) {
testPrefixRegex := []string{"/a/api/", "/b/{regex}/", "/c/{category}/{id:[0-9]+}/"} testPrefixRegex := []string{"/a/api/", "/b/{regex}/", "/c/{category}/{id:[0-9]+}/"}
tests := []struct { tests := []struct {
path string path string
expectedStatusCode int expectedStatusCode int
expectedPath string expectedPath string
expectedRawPath string
expectedHeader string expectedHeader string
}{ }{
{ {
@ -61,6 +61,13 @@ func TestStripPrefixRegex(t *testing.T) {
path: "/c/api/abc/test4", path: "/c/api/abc/test4",
expectedStatusCode: http.StatusNotFound, expectedStatusCode: http.StatusNotFound,
}, },
{
path: "/a/api/a%2Fb",
expectedStatusCode: http.StatusOK,
expectedPath: "a/b",
expectedRawPath: "a%2Fb",
expectedHeader: "/a/api/",
},
} }
for _, test := range tests { for _, test := range tests {
@ -68,9 +75,10 @@ func TestStripPrefixRegex(t *testing.T) {
t.Run(test.path, func(t *testing.T) { t.Run(test.path, func(t *testing.T) {
t.Parallel() t.Parallel()
var actualPath, actualHeader string var actualPath, actualRawPath, actualHeader string
handlerPath := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handlerPath := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
actualPath = r.URL.Path actualPath = r.URL.Path
actualRawPath = r.URL.RawPath
actualHeader = r.Header.Get(ForwardedPrefixHeader) actualHeader = r.Header.Get(ForwardedPrefixHeader)
}) })
handler := NewStripPrefixRegex(handlerPath, testPrefixRegex) handler := NewStripPrefixRegex(handlerPath, testPrefixRegex)
@ -82,6 +90,7 @@ func TestStripPrefixRegex(t *testing.T) {
assert.Equal(t, test.expectedStatusCode, resp.Code, "Unexpected status code.") assert.Equal(t, test.expectedStatusCode, resp.Code, "Unexpected status code.")
assert.Equal(t, test.expectedPath, actualPath, "Unexpected path.") assert.Equal(t, test.expectedPath, actualPath, "Unexpected path.")
assert.Equal(t, test.expectedRawPath, actualRawPath, "Unexpected raw path.")
assert.Equal(t, test.expectedHeader, actualHeader, "Unexpected '%s' header.", ForwardedPrefixHeader) assert.Equal(t, test.expectedHeader, actualHeader, "Unexpected '%s' header.", ForwardedPrefixHeader)
}) })
} }

View file

@ -16,6 +16,7 @@ func TestStripPrefix(t *testing.T) {
path string path string
expectedStatusCode int expectedStatusCode int
expectedPath string expectedPath string
expectedRawPath string
expectedHeader string expectedHeader string
}{ }{
{ {
@ -94,6 +95,15 @@ func TestStripPrefix(t *testing.T) {
expectedPath: "/us", expectedPath: "/us",
expectedHeader: "/stat", expectedHeader: "/stat",
}, },
{
desc: "raw path is also stripped",
prefixes: []string{"/stat"},
path: "/stat/a%2Fb",
expectedStatusCode: http.StatusOK,
expectedPath: "/a/b",
expectedRawPath: "/a%2Fb",
expectedHeader: "/stat",
},
} }
for _, test := range tests { for _, test := range tests {
@ -101,11 +111,12 @@ func TestStripPrefix(t *testing.T) {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
t.Parallel() t.Parallel()
var actualPath, actualHeader, requestURI string var actualPath, actualRawPath, actualHeader, requestURI string
handler := &StripPrefix{ handler := &StripPrefix{
Prefixes: test.prefixes, Prefixes: test.prefixes,
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
actualPath = r.URL.Path actualPath = r.URL.Path
actualRawPath = r.URL.RawPath
actualHeader = r.Header.Get(ForwardedPrefixHeader) actualHeader = r.Header.Get(ForwardedPrefixHeader)
requestURI = r.RequestURI requestURI = r.RequestURI
}), }),
@ -118,8 +129,15 @@ func TestStripPrefix(t *testing.T) {
assert.Equal(t, test.expectedStatusCode, resp.Code, "Unexpected status code.") assert.Equal(t, test.expectedStatusCode, resp.Code, "Unexpected status code.")
assert.Equal(t, test.expectedPath, actualPath, "Unexpected path.") assert.Equal(t, test.expectedPath, actualPath, "Unexpected path.")
assert.Equal(t, test.expectedRawPath, actualRawPath, "Unexpected raw path.")
assert.Equal(t, test.expectedHeader, actualHeader, "Unexpected '%s' header.", ForwardedPrefixHeader) assert.Equal(t, test.expectedHeader, actualHeader, "Unexpected '%s' header.", ForwardedPrefixHeader)
assert.Equal(t, test.expectedPath, requestURI, "Unexpected request URI.")
expectedURI := test.expectedPath
if test.expectedRawPath != "" {
// go HTTP uses the raw path when existent in the RequestURI
expectedURI = test.expectedRawPath
}
assert.Equal(t, expectedURI, requestURI, "Unexpected request URI.")
}) })
} }
} }