Properly add response headers for CORS

This commit is contained in:
Daniel Tomcej 2019-07-12 03:46:04 -06:00 committed by Traefiker Bot
parent 74c5ec70a9
commit 3f6ea04048
8 changed files with 198 additions and 138 deletions

View file

@ -8,18 +8,3 @@
[entryPoints] [entryPoints]
[entryPoints.web] [entryPoints.web]
address = ":8000" address = ":8000"
[providers]
[providers.file]
## dynamic configuration ##
[http.routers]
[http.routers.router1]
rule = "Host(`test.localhost`)"
service = "service1"
[http.services]
[http.services.service1.loadBalancer]
[[http.services.service1.loadBalancer.servers]]
url = "http://172.17.0.2:80"

View file

@ -17,6 +17,12 @@
[http.routers] [http.routers]
[http.routers.router1] [http.routers.router1]
rule = "Host(`test.localhost`)" rule = "Host(`test.localhost`)"
middlewares = ["cors"]
service = "service1"
[http.routers.router2]
rule = "Host(`test2.localhost`)"
middlewares = ["nocors"]
service = "service1" service = "service1"
[http.middlewares] [http.middlewares]
@ -26,7 +32,11 @@
accessControlMaxAge = 100 accessControlMaxAge = 100
addVaryHeader = true addVaryHeader = true
[http.middlewares.nocors.Headers]
[http.middlewares.nocors.Headers.CustomResponseHeaders]
X-Custom-Response-Header = "True"
[http.services] [http.services]
[http.services.service1.loadBalancer] [http.services.service1.loadBalancer]
[[http.services.service1.loadBalancer.servers]] [[http.services.service1.loadBalancer.servers]]
url = "http://172.17.0.2:80" url = "http://127.0.0.1:9000"

View file

@ -12,12 +12,6 @@ import (
// Headers test suites // Headers test suites
type HeadersSuite struct{ BaseSuite } type HeadersSuite struct{ BaseSuite }
func (s *HeadersSuite) SetUpSuite(c *check.C) {
s.createComposeProject(c, "headers")
s.composeProject.Start(c)
}
func (s *HeadersSuite) TestSimpleConfiguration(c *check.C) { func (s *HeadersSuite) TestSimpleConfiguration(c *check.C) {
cmd, display := s.traefikCmd(withConfigFile("fixtures/headers/basic.toml")) cmd, display := s.traefikCmd(withConfigFile("fixtures/headers/basic.toml"))
defer display(c) defer display(c)
@ -38,10 +32,18 @@ func (s *HeadersSuite) TestCorsResponses(c *check.C) {
c.Assert(err, checker.IsNil) c.Assert(err, checker.IsNil)
defer cmd.Process.Kill() defer cmd.Process.Kill()
backend := startTestServer("9000", http.StatusOK)
defer backend.Close()
err = try.GetRequest(backend.URL, 500*time.Millisecond, try.StatusCodeIs(http.StatusOK))
c.Assert(err, checker.IsNil)
testCase := []struct { testCase := []struct {
desc string desc string
requestHeaders http.Header requestHeaders http.Header
expected http.Header expected http.Header
reqHost string
method string
}{ }{
{ {
desc: "simple access control allow origin", desc: "simple access control allow origin",
@ -52,33 +54,9 @@ func (s *HeadersSuite) TestCorsResponses(c *check.C) {
"Access-Control-Allow-Origin": {"https://foo.bar.org"}, "Access-Control-Allow-Origin": {"https://foo.bar.org"},
"Vary": {"Origin"}, "Vary": {"Origin"},
}, },
reqHost: "test.localhost",
method: http.MethodGet,
}, },
}
for _, test := range testCase {
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8000/", nil)
c.Assert(err, checker.IsNil)
req.Host = "test.localhost"
req.Header = test.requestHeaders
err = try.Request(req, 500*time.Millisecond, try.HasBody(), try.HasHeaderStruct(test.expected))
c.Assert(err, checker.IsNil)
}
}
func (s *HeadersSuite) TestCorsPreflightResponses(c *check.C) {
cmd, display := s.traefikCmd(withConfigFile("fixtures/headers/cors.toml"))
defer display(c)
err := cmd.Start()
c.Assert(err, checker.IsNil)
defer cmd.Process.Kill()
testCase := []struct {
desc string
requestHeaders http.Header
expected http.Header
}{
{ {
desc: "simple preflight request", desc: "simple preflight request",
requestHeaders: http.Header{ requestHeaders: http.Header{
@ -91,16 +69,44 @@ func (s *HeadersSuite) TestCorsPreflightResponses(c *check.C) {
"Access-Control-Max-Age": {"100"}, "Access-Control-Max-Age": {"100"},
"Access-Control-Allow-Methods": {"GET,OPTIONS,PUT"}, "Access-Control-Allow-Methods": {"GET,OPTIONS,PUT"},
}, },
reqHost: "test.localhost",
method: http.MethodOptions,
},
{
desc: "preflight Options request with no cors configured",
requestHeaders: http.Header{
"Access-Control-Request-Headers": {"origin"},
"Access-Control-Request-Method": {"GET", "OPTIONS"},
"Origin": {"https://foo.bar.org"},
},
expected: http.Header{
"X-Custom-Response-Header": {"True"},
},
reqHost: "test2.localhost",
method: http.MethodOptions,
},
{
desc: "preflight Get request with no cors configured",
requestHeaders: http.Header{
"Access-Control-Request-Headers": {"origin"},
"Access-Control-Request-Method": {"GET", "OPTIONS"},
"Origin": {"https://foo.bar.org"},
},
expected: http.Header{
"X-Custom-Response-Header": {"True"},
},
reqHost: "test2.localhost",
method: http.MethodGet,
}, },
} }
for _, test := range testCase { for _, test := range testCase {
req, err := http.NewRequest(http.MethodOptions, "http://127.0.0.1:8000/", nil) req, err := http.NewRequest(test.method, "http://127.0.0.1:8000/", nil)
c.Assert(err, checker.IsNil) c.Assert(err, checker.IsNil)
req.Host = "test.localhost" req.Host = test.reqHost
req.Header = test.requestHeaders req.Header = test.requestHeaders
err = try.Request(req, 500*time.Millisecond, try.HasBody(), try.HasHeaderStruct(test.expected)) err = try.Request(req, 500*time.Millisecond, try.HasHeaderStruct(test.expected))
c.Assert(err, checker.IsNil) c.Assert(err, checker.IsNil)
} }
} }

View file

@ -1,4 +0,0 @@
whoami1:
image: containous/whoami
ports:
- "8881:80"

View file

@ -168,18 +168,17 @@ func HasHeaderValue(header, value string, exactMatch bool) ResponseCondition {
func HasHeaderStruct(header http.Header) ResponseCondition { func HasHeaderStruct(header http.Header) ResponseCondition {
return func(res *http.Response) error { return func(res *http.Response) error {
for key := range header { for key := range header {
if _, ok := res.Header[key]; ok { if _, ok := res.Header[key]; !ok {
// Header exists in the response, test it. return fmt.Errorf("header %s not present in the response. Expected headers: %v Got response headers: %v", key, header, res.Header)
eq := reflect.DeepEqual(header[key], res.Header[key]) }
if !eq {
return fmt.Errorf("for header %s got values %v, wanted %v", key, res.Header[key], header[key])
}
// Header exists in the response, test it.
if !reflect.DeepEqual(header[key], res.Header[key]) {
return fmt.Errorf("for header %s got values %v, wanted %v", key, res.Header[key], header[key])
} }
} }
return nil return nil
} }
} }
// DoCondition is a retry condition function. // DoCondition is a retry condition function.

View file

@ -16,8 +16,7 @@ import (
) )
const ( const (
typeName = "Headers" typeName = "Headers"
originHeaderKey = "X-Request-Origin"
) )
type headers struct { type headers struct {
@ -107,29 +106,127 @@ func (s secureHeader) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
s.secure.HandlerFuncWithNextForRequestOnly(rw, req, s.next.ServeHTTP) s.secure.HandlerFuncWithNextForRequestOnly(rw, req, s.next.ServeHTTP)
} }
// Header is a middleware that helps setup a few basic security features. A single headerOptions struct can be // Header is a middleware that helps setup a few basic security features.
// provided to configure which features should be enabled, and the ability to override a few of the default values. // A single headerOptions struct can be provided to configure which features should be enabled,
// and the ability to override a few of the default values.
type Header struct { type Header struct {
next http.Handler next http.Handler
headers *dynamic.Headers hasCustomHeaders bool
hasCorsHeaders bool
headers *dynamic.Headers
} }
// NewHeader constructs a new header instance from supplied frontend header struct. // NewHeader constructs a new header instance from supplied frontend header struct.
func NewHeader(next http.Handler, headers dynamic.Headers) *Header { func NewHeader(next http.Handler, headers dynamic.Headers) *Header {
hasCustomHeaders := headers.HasCustomHeadersDefined()
hasCorsHeaders := headers.HasCorsHeadersDefined()
return &Header{ return &Header{
next: next, next: next,
headers: &headers, headers: &headers,
hasCustomHeaders: hasCustomHeaders,
hasCorsHeaders: hasCorsHeaders,
} }
} }
func (s *Header) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (s *Header) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// Handle Cors headers and preflight if configured.
if isPreflight := s.processCorsHeaders(rw, req); isPreflight {
return
}
if s.hasCustomHeaders {
s.modifyCustomRequestHeaders(req)
}
// If there is a next, call it.
if s.next != nil {
s.next.ServeHTTP(rw, req)
}
}
// modifyCustomRequestHeaders sets or deletes custom request headers.
func (s *Header) modifyCustomRequestHeaders(req *http.Request) {
// Loop through Custom request headers
for header, value := range s.headers.CustomRequestHeaders {
if value == "" {
req.Header.Del(header)
} else {
req.Header.Set(header, value)
}
}
}
// preRequestModifyCorsResponseHeaders sets during request processing time,
// all the CORS response headers that we already know that are supposed to be set,
// and which do not depend on a later state of the response.
// One notable example of a header that can only be modified later on is "Vary",
// And this is set in the post-response response modifier method
func (s *Header) preRequestModifyCorsResponseHeaders(rw http.ResponseWriter, req *http.Request) {
originHeader := req.Header.Get("Origin")
allowOrigin := s.getAllowOrigin(originHeader)
if allowOrigin != "" {
rw.Header().Set("Access-Control-Allow-Origin", allowOrigin)
}
if s.headers.AccessControlAllowCredentials {
rw.Header().Set("Access-Control-Allow-Credentials", "true")
}
if len(s.headers.AccessControlExposeHeaders) > 0 {
exposeHeaders := strings.Join(s.headers.AccessControlExposeHeaders, ",")
rw.Header().Set("Access-Control-Expose-Headers", exposeHeaders)
}
}
// PostRequestModifyResponseHeaders set or delete response headers.
// This method is called AFTER the response is generated from the backend
// and can merge/override headers from the backend response.
func (s *Header) PostRequestModifyResponseHeaders(res *http.Response) error {
// Loop through Custom response headers
for header, value := range s.headers.CustomResponseHeaders {
if value == "" {
res.Header.Del(header)
} else {
res.Header.Set(header, value)
}
}
if !s.headers.AddVaryHeader {
return nil
}
varyHeader := res.Header.Get("Vary")
if varyHeader == "Origin" {
return nil
}
if varyHeader != "" {
varyHeader += ","
}
varyHeader += "Origin"
res.Header.Set("Vary", varyHeader)
return nil
}
// processCorsHeaders processes the incoming request,
// and returns if it is a preflight request.
// If not a preflight, it handles the preRequestModifyCorsResponseHeaders.
func (s *Header) processCorsHeaders(rw http.ResponseWriter, req *http.Request) bool {
if !s.hasCorsHeaders {
return false
}
reqAcMethod := req.Header.Get("Access-Control-Request-Method") reqAcMethod := req.Header.Get("Access-Control-Request-Method")
reqAcHeaders := req.Header.Get("Access-Control-Request-Headers") reqAcHeaders := req.Header.Get("Access-Control-Request-Headers")
originHeader := req.Header.Get("Origin") originHeader := req.Header.Get("Origin")
if reqAcMethod != "" && reqAcHeaders != "" && originHeader != "" && req.Method == http.MethodOptions { if reqAcMethod != "" && reqAcHeaders != "" && originHeader != "" && req.Method == http.MethodOptions {
// If the request is an OPTIONS request with an Access-Control-Request-Method header, and Access-Control-Request-Headers headers, // If the request is an OPTIONS request with an Access-Control-Request-Method header,
// and Origin headers, then it is a CORS preflight request, and we need to build a custom response: https://www.w3.org/TR/cors/#preflight-request // and Access-Control-Request-Headers headers, and Origin headers,
// then it is a CORS preflight request,
// and we need to build a custom response: https://www.w3.org/TR/cors/#preflight-request
if s.headers.AccessControlAllowCredentials { if s.headers.AccessControlAllowCredentials {
rw.Header().Set("Access-Control-Allow-Credentials", "true") rw.Header().Set("Access-Control-Allow-Credentials", "true")
} }
@ -151,71 +248,11 @@ func (s *Header) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
} }
rw.Header().Set("Access-Control-Max-Age", strconv.Itoa(int(s.headers.AccessControlMaxAge))) rw.Header().Set("Access-Control-Max-Age", strconv.Itoa(int(s.headers.AccessControlMaxAge)))
return true
return
} }
if len(originHeader) > 0 { s.preRequestModifyCorsResponseHeaders(rw, req)
rw.Header().Set(originHeaderKey, originHeader) return false
}
s.modifyRequestHeaders(req)
// If there is a next, call it.
if s.next != nil {
s.next.ServeHTTP(rw, req)
}
}
// modifyRequestHeaders sets or deletes request headers.
func (s *Header) modifyRequestHeaders(req *http.Request) {
// Loop through Custom request headers
for header, value := range s.headers.CustomRequestHeaders {
if value == "" {
req.Header.Del(header)
} else {
req.Header.Set(header, value)
}
}
}
// ModifyResponseHeaders set or delete response headers
func (s *Header) ModifyResponseHeaders(res *http.Response) error {
// Loop through Custom response headers
for header, value := range s.headers.CustomResponseHeaders {
if value == "" {
res.Header.Del(header)
} else {
res.Header.Set(header, value)
}
}
originHeader := res.Header.Get(originHeaderKey)
allowOrigin := s.getAllowOrigin(originHeader)
// Delete the origin header key, since it is only used to pass data from the request for response handling
res.Header.Del(originHeaderKey)
if allowOrigin != "" {
res.Header.Set("Access-Control-Allow-Origin", allowOrigin)
if s.headers.AddVaryHeader {
varyHeader := res.Header.Get("Vary")
if varyHeader != "" {
varyHeader += ","
}
varyHeader += "Origin"
res.Header.Set("Vary", varyHeader)
}
}
if s.headers.AccessControlAllowCredentials {
res.Header.Set("Access-Control-Allow-Credentials", "true")
}
exposeHeaders := strings.Join(s.headers.AccessControlExposeHeaders, ",")
if exposeHeaders != "" {
res.Header.Set("Access-Control-Expose-Headers", exposeHeaders)
}
return nil
} }
func (s *Header) getAllowOrigin(header string) string { func (s *Header) getAllowOrigin(header string) string {

View file

@ -333,6 +333,7 @@ func TestGetTracingInformation(t *testing.T) {
func TestCORSResponses(t *testing.T) { func TestCORSResponses(t *testing.T) {
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
nonEmptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Vary", "Testing") }) nonEmptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Vary", "Testing") })
existingOriginHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Vary", "Origin") })
testCases := []struct { testCases := []struct {
desc string desc string
@ -436,6 +437,32 @@ func TestCORSResponses(t *testing.T) {
"Vary": {"Testing,Origin"}, "Vary": {"Testing,Origin"},
}, },
}, },
{
desc: "Test Simple Request with Vary Headers and existing vary:origin response",
header: NewHeader(existingOriginHandler, dynamic.Headers{
AccessControlAllowOrigin: "origin-list-or-null",
AddVaryHeader: true,
}),
requestHeaders: map[string][]string{
"Origin": {"https://foo.bar.org"},
},
expected: map[string][]string{
"Access-Control-Allow-Origin": {"https://foo.bar.org"},
"Vary": {"Origin"},
},
},
{
desc: "Test Simple CustomRequestHeaders Not Hijacked by CORS",
header: NewHeader(emptyHandler, dynamic.Headers{
CustomRequestHeaders: map[string]string{"foo": "bar"},
}),
requestHeaders: map[string][]string{
"Access-Control-Request-Headers": {"origin"},
"Access-Control-Request-Method": {"GET", "OPTIONS"},
"Origin": {"https://foo.bar.org"},
},
expected: map[string][]string{},
},
} }
for _, test := range testCases { for _, test := range testCases {
@ -445,7 +472,7 @@ func TestCORSResponses(t *testing.T) {
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
test.header.ServeHTTP(rw, req) test.header.ServeHTTP(rw, req)
err := test.header.ModifyResponseHeaders(rw.Result()) err := test.header.PostRequestModifyResponseHeaders(rw.Result())
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, test.expected, rw.Result().Header) assert.Equal(t, test.expected, rw.Result().Header)
}) })
@ -492,7 +519,7 @@ func TestCustomResponseHeaders(t *testing.T) {
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil) req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
test.header.ServeHTTP(rw, req) test.header.ServeHTTP(rw, req)
err := test.header.ModifyResponseHeaders(rw.Result()) err := test.header.PostRequestModifyResponseHeaders(rw.Result())
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, test.expected, rw.Result().Header) assert.Equal(t, test.expected, rw.Result().Header)
}) })

View file

@ -34,7 +34,7 @@ func buildHeaders(hdrs *dynamic.Headers) func(*http.Response) error {
return func(resp *http.Response) error { return func(resp *http.Response) error {
if hdrs.HasCustomHeadersDefined() || hdrs.HasCorsHeadersDefined() { if hdrs.HasCustomHeadersDefined() || hdrs.HasCorsHeadersDefined() {
err := headers.NewHeader(nil, *hdrs).ModifyResponseHeaders(resp) err := headers.NewHeader(nil, *hdrs).PostRequestModifyResponseHeaders(resp)
if err != nil { if err != nil {
return err return err
} }