Forward auth: copy response headers when auth failed.
This commit is contained in:
parent
76dcbe3429
commit
2975acdc82
2 changed files with 24 additions and 10 deletions
|
@ -25,6 +25,7 @@ func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next
|
||||||
return http.ErrUseLastResponse
|
return http.ErrUseLastResponse
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.TLS != nil {
|
if config.TLS != nil {
|
||||||
tlsConfig, err := config.TLS.CreateTLSConfig()
|
tlsConfig, err := config.TLS.CreateTLSConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -32,10 +33,12 @@ func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
httpClient.Transport = &http.Transport{
|
httpClient.Transport = &http.Transport{
|
||||||
TLSClientConfig: tlsConfig,
|
TLSClientConfig: tlsConfig,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
forwardReq, err := http.NewRequest(http.MethodGet, config.Address, nil)
|
forwardReq, err := http.NewRequest(http.MethodGet, config.Address, nil)
|
||||||
tracing.LogRequest(tracing.GetSpan(r), forwardReq)
|
tracing.LogRequest(tracing.GetSpan(r), forwardReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -68,6 +71,8 @@ func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next
|
||||||
if forwardResponse.StatusCode < http.StatusOK || forwardResponse.StatusCode >= http.StatusMultipleChoices {
|
if forwardResponse.StatusCode < http.StatusOK || forwardResponse.StatusCode >= http.StatusMultipleChoices {
|
||||||
log.Debugf("Remote error %s. StatusCode: %d", config.Address, forwardResponse.StatusCode)
|
log.Debugf("Remote error %s. StatusCode: %d", config.Address, forwardResponse.StatusCode)
|
||||||
|
|
||||||
|
utils.CopyHeaders(w.Header(), forwardResponse.Header)
|
||||||
|
|
||||||
// Grab the location header, if any.
|
// Grab the location header, if any.
|
||||||
redirectURL, err := forwardResponse.Location()
|
redirectURL, err := forwardResponse.Location()
|
||||||
|
|
||||||
|
@ -79,12 +84,7 @@ func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next
|
||||||
}
|
}
|
||||||
} else if redirectURL.String() != "" {
|
} else if redirectURL.String() != "" {
|
||||||
// Set the location in our response if one was sent back.
|
// Set the location in our response if one was sent back.
|
||||||
w.Header().Add("Location", redirectURL.String())
|
w.Header().Set("Location", redirectURL.String())
|
||||||
}
|
|
||||||
|
|
||||||
// Pass any Set-Cookie headers the forward auth server provides
|
|
||||||
for _, cookie := range forwardResponse.Cookies() {
|
|
||||||
w.Header().Add("Set-Cookie", cookie.String())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tracing.LogResponseCode(tracing.GetSpan(r), forwardResponse.StatusCode)
|
tracing.LogResponseCode(tracing.GetSpan(r), forwardResponse.StatusCode)
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"github.com/containous/traefik/testhelpers"
|
"github.com/containous/traefik/testhelpers"
|
||||||
"github.com/containous/traefik/types"
|
"github.com/containous/traefik/types"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/urfave/negroni"
|
"github.com/urfave/negroni"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -110,7 +111,6 @@ func TestForwardAuthRedirect(t *testing.T) {
|
||||||
assert.Equal(t, http.StatusFound, res.StatusCode, "they should be equal")
|
assert.Equal(t, http.StatusFound, res.StatusCode, "they should be equal")
|
||||||
|
|
||||||
location, err := res.Location()
|
location, err := res.Location()
|
||||||
|
|
||||||
assert.NoError(t, err, "there should be no error")
|
assert.NoError(t, err, "there should be no error")
|
||||||
assert.Equal(t, "http://example.com/redirect-test", location.String(), "they should be equal")
|
assert.Equal(t, "http://example.com/redirect-test", location.String(), "they should be equal")
|
||||||
|
|
||||||
|
@ -119,10 +119,11 @@ func TestForwardAuthRedirect(t *testing.T) {
|
||||||
assert.NotEmpty(t, string(body), "there should be something in the body")
|
assert.NotEmpty(t, string(body), "there should be something in the body")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestForwardAuthCookie(t *testing.T) {
|
func TestForwardAuthFailResponseHeaders(t *testing.T) {
|
||||||
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
cookie := &http.Cookie{Name: "example", Value: "testing", Path: "/"}
|
cookie := &http.Cookie{Name: "example", Value: "testing", Path: "/"}
|
||||||
http.SetCookie(w, cookie)
|
http.SetCookie(w, cookie)
|
||||||
|
w.Header().Add("X-Foo", "bar")
|
||||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||||
}))
|
}))
|
||||||
defer authTs.Close()
|
defer authTs.Close()
|
||||||
|
@ -142,23 +143,36 @@ func TestForwardAuthCookie(t *testing.T) {
|
||||||
ts := httptest.NewServer(n)
|
ts := httptest.NewServer(n)
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
client := &http.Client{}
|
|
||||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||||
|
client := &http.Client{}
|
||||||
res, err := client.Do(req)
|
res, err := client.Do(req)
|
||||||
assert.NoError(t, err, "there should be no error")
|
assert.NoError(t, err, "there should be no error")
|
||||||
assert.Equal(t, http.StatusForbidden, res.StatusCode, "they should be equal")
|
assert.Equal(t, http.StatusForbidden, res.StatusCode, "they should be equal")
|
||||||
|
|
||||||
|
require.Len(t, res.Cookies(), 1)
|
||||||
for _, cookie := range res.Cookies() {
|
for _, cookie := range res.Cookies() {
|
||||||
assert.Equal(t, "testing", cookie.Value, "they should be equal")
|
assert.Equal(t, "testing", cookie.Value, "they should be equal")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
expectedHeaders := http.Header{
|
||||||
|
"Content-Length": []string{"10"},
|
||||||
|
"Content-Type": []string{"text/plain; charset=utf-8"},
|
||||||
|
"X-Foo": []string{"bar"},
|
||||||
|
"Set-Cookie": []string{"example=testing; Path=/"},
|
||||||
|
"X-Content-Type-Options": []string{"nosniff"},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Len(t, res.Header, 6)
|
||||||
|
for key, value := range expectedHeaders {
|
||||||
|
assert.Equal(t, value, res.Header[key])
|
||||||
|
}
|
||||||
|
|
||||||
body, err := ioutil.ReadAll(res.Body)
|
body, err := ioutil.ReadAll(res.Body)
|
||||||
assert.NoError(t, err, "there should be no error")
|
assert.NoError(t, err, "there should be no error")
|
||||||
assert.Equal(t, "Forbidden\n", string(body), "they should be equal")
|
assert.Equal(t, "Forbidden\n", string(body), "they should be equal")
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_writeHeader(t *testing.T) {
|
func Test_writeHeader(t *testing.T) {
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
headers map[string]string
|
headers map[string]string
|
||||||
|
|
Loading…
Add table
Reference in a new issue