Pass through certain forward auth negative response headers

This commit is contained in:
Kendrick Erickson 2017-11-02 05:06:03 -05:00 committed by Traefiker
parent 93a46089ce
commit 2b4d33e919
2 changed files with 108 additions and 1 deletions

View file

@ -14,7 +14,13 @@ import (
// Forward the authentication to a external server // Forward the authentication to a external server
func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
httpClient := http.Client{}
// Ensure our request client does not follow redirects
httpClient := http.Client{
CheckRedirect: func(r *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
if config.TLS != nil { if config.TLS != nil {
tlsConfig, err := config.TLS.CreateTLSConfig() tlsConfig, err := config.TLS.CreateTLSConfig()
@ -52,8 +58,30 @@ func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next
} }
defer forwardResponse.Body.Close() defer forwardResponse.Body.Close()
// Pass the forward response's body and selected headers if it
// didn't return a response within the range of [200, 300).
if forwardResponse.StatusCode < http.StatusOK || forwardResponse.StatusCode >= http.StatusMultipleChoices { if forwardResponse.StatusCode < http.StatusOK || forwardResponse.StatusCode >= http.StatusMultipleChoices {
log.Debugf("Remote error %s. StatusCode: %d", config.Address, forwardResponse.StatusCode) log.Debugf("Remote error %s. StatusCode: %d", config.Address, forwardResponse.StatusCode)
// Grab the location header, if any.
redirectURL, err := forwardResponse.Location()
if err != nil {
if err != http.ErrNoLocation {
log.Debugf("Error reading response location header %s. Cause: %s", config.Address, err)
w.WriteHeader(http.StatusInternalServerError)
return
}
} else if redirectURL.String() != "" {
// Set the location in our response if one was sent back.
w.Header().Add("Location", redirectURL.String())
}
// Pass any Set-Cookie headers the forward auth server provides
for _, cookie := range forwardResponse.Cookies() {
w.Header().Add("Set-Cookie", cookie.String())
}
w.WriteHeader(forwardResponse.StatusCode) w.WriteHeader(forwardResponse.StatusCode)
w.Write(body) w.Write(body)
return return

View file

@ -77,6 +77,85 @@ func TestForwardAuthSuccess(t *testing.T) {
assert.Equal(t, "traefik\n", string(body), "they should be equal") assert.Equal(t, "traefik\n", string(body), "they should be equal")
} }
func TestForwardAuthRedirect(t *testing.T) {
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "http://example.com/redirect-test", http.StatusFound)
}))
defer authTs.Close()
authMiddleware, err := NewAuthenticator(&types.Auth{
Forward: &types.Forward{
Address: authTs.URL,
},
})
assert.NoError(t, err, "there should be no error")
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
n := negroni.New(authMiddleware)
n.UseHandler(handler)
ts := httptest.NewServer(n)
defer ts.Close()
client := &http.Client{
CheckRedirect: func(r *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
res, err := client.Do(req)
assert.NoError(t, err, "there should be no error")
assert.Equal(t, http.StatusFound, res.StatusCode, "they should be equal")
location, err := res.Location()
assert.NoError(t, err, "there should be no error")
assert.Equal(t, "http://example.com/redirect-test", location.String(), "they should be equal")
body, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err, "there should be no error")
assert.NotEmpty(t, string(body), "there should be something in the body")
}
func TestForwardAuthCookie(t *testing.T) {
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cookie := &http.Cookie{Name: "example", Value: "testing", Path: "/"}
http.SetCookie(w, cookie)
http.Error(w, "Forbidden", http.StatusForbidden)
}))
defer authTs.Close()
authMiddleware, err := NewAuthenticator(&types.Auth{
Forward: &types.Forward{
Address: authTs.URL,
},
})
assert.NoError(t, err, "there should be no error")
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
n := negroni.New(authMiddleware)
n.UseHandler(handler)
ts := httptest.NewServer(n)
defer ts.Close()
client := &http.Client{}
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
res, err := client.Do(req)
assert.NoError(t, err, "there should be no error")
assert.Equal(t, http.StatusForbidden, res.StatusCode, "they should be equal")
for _, cookie := range res.Cookies() {
assert.Equal(t, "testing", cookie.Value, "they should be equal")
}
body, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err, "there should be no error")
assert.Equal(t, "Forbidden\n", string(body), "they should be equal")
}
func Test_writeHeader(t *testing.T) { func Test_writeHeader(t *testing.T) {
testCases := []struct { testCases := []struct {