From 2975acdc8217a3654c344f6d5dfdfedd51610e81 Mon Sep 17 00:00:00 2001 From: Ludovic Fernandez Date: Mon, 23 Apr 2018 15:28:04 +0200 Subject: [PATCH] Forward auth: copy response headers when auth failed. --- middlewares/auth/forward.go | 12 ++++++------ middlewares/auth/forward_test.go | 22 ++++++++++++++++++---- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/middlewares/auth/forward.go b/middlewares/auth/forward.go index e4eea0976..3cdb85d5b 100644 --- a/middlewares/auth/forward.go +++ b/middlewares/auth/forward.go @@ -25,6 +25,7 @@ func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next return http.ErrUseLastResponse }, } + if config.TLS != nil { tlsConfig, err := config.TLS.CreateTLSConfig() if err != nil { @@ -32,10 +33,12 @@ func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next w.WriteHeader(http.StatusInternalServerError) return } + httpClient.Transport = &http.Transport{ TLSClientConfig: tlsConfig, } } + forwardReq, err := http.NewRequest(http.MethodGet, config.Address, nil) tracing.LogRequest(tracing.GetSpan(r), forwardReq) if err != nil { @@ -68,6 +71,8 @@ func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next if forwardResponse.StatusCode < http.StatusOK || forwardResponse.StatusCode >= http.StatusMultipleChoices { log.Debugf("Remote error %s. StatusCode: %d", config.Address, forwardResponse.StatusCode) + utils.CopyHeaders(w.Header(), forwardResponse.Header) + // Grab the location header, if any. redirectURL, err := forwardResponse.Location() @@ -79,12 +84,7 @@ func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next } } else if redirectURL.String() != "" { // Set the location in our response if one was sent back. - w.Header().Add("Location", redirectURL.String()) - } - - // Pass any Set-Cookie headers the forward auth server provides - for _, cookie := range forwardResponse.Cookies() { - w.Header().Add("Set-Cookie", cookie.String()) + w.Header().Set("Location", redirectURL.String()) } tracing.LogResponseCode(tracing.GetSpan(r), forwardResponse.StatusCode) diff --git a/middlewares/auth/forward_test.go b/middlewares/auth/forward_test.go index 05ffdba84..a52014420 100644 --- a/middlewares/auth/forward_test.go +++ b/middlewares/auth/forward_test.go @@ -11,6 +11,7 @@ import ( "github.com/containous/traefik/testhelpers" "github.com/containous/traefik/types" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/urfave/negroni" ) @@ -110,7 +111,6 @@ func TestForwardAuthRedirect(t *testing.T) { assert.Equal(t, http.StatusFound, res.StatusCode, "they should be equal") location, err := res.Location() - assert.NoError(t, err, "there should be no error") assert.Equal(t, "http://example.com/redirect-test", location.String(), "they should be equal") @@ -119,10 +119,11 @@ func TestForwardAuthRedirect(t *testing.T) { assert.NotEmpty(t, string(body), "there should be something in the body") } -func TestForwardAuthCookie(t *testing.T) { +func TestForwardAuthFailResponseHeaders(t *testing.T) { authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { cookie := &http.Cookie{Name: "example", Value: "testing", Path: "/"} http.SetCookie(w, cookie) + w.Header().Add("X-Foo", "bar") http.Error(w, "Forbidden", http.StatusForbidden) })) defer authTs.Close() @@ -142,23 +143,36 @@ func TestForwardAuthCookie(t *testing.T) { ts := httptest.NewServer(n) defer ts.Close() - client := &http.Client{} req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil) + client := &http.Client{} res, err := client.Do(req) assert.NoError(t, err, "there should be no error") assert.Equal(t, http.StatusForbidden, res.StatusCode, "they should be equal") + require.Len(t, res.Cookies(), 1) for _, cookie := range res.Cookies() { assert.Equal(t, "testing", cookie.Value, "they should be equal") } + expectedHeaders := http.Header{ + "Content-Length": []string{"10"}, + "Content-Type": []string{"text/plain; charset=utf-8"}, + "X-Foo": []string{"bar"}, + "Set-Cookie": []string{"example=testing; Path=/"}, + "X-Content-Type-Options": []string{"nosniff"}, + } + + assert.Len(t, res.Header, 6) + for key, value := range expectedHeaders { + assert.Equal(t, value, res.Header[key]) + } + body, err := ioutil.ReadAll(res.Body) assert.NoError(t, err, "there should be no error") assert.Equal(t, "Forbidden\n", string(body), "they should be equal") } func Test_writeHeader(t *testing.T) { - testCases := []struct { name string headers map[string]string