Properly add response headers for CORS
This commit is contained in:
parent
74c5ec70a9
commit
3f6ea04048
8 changed files with 198 additions and 138 deletions
|
@ -8,18 +8,3 @@
|
||||||
[entryPoints]
|
[entryPoints]
|
||||||
[entryPoints.web]
|
[entryPoints.web]
|
||||||
address = ":8000"
|
address = ":8000"
|
||||||
|
|
||||||
[providers]
|
|
||||||
[providers.file]
|
|
||||||
|
|
||||||
## dynamic configuration ##
|
|
||||||
|
|
||||||
[http.routers]
|
|
||||||
[http.routers.router1]
|
|
||||||
rule = "Host(`test.localhost`)"
|
|
||||||
service = "service1"
|
|
||||||
|
|
||||||
[http.services]
|
|
||||||
[http.services.service1.loadBalancer]
|
|
||||||
[[http.services.service1.loadBalancer.servers]]
|
|
||||||
url = "http://172.17.0.2:80"
|
|
||||||
|
|
|
@ -17,6 +17,12 @@
|
||||||
[http.routers]
|
[http.routers]
|
||||||
[http.routers.router1]
|
[http.routers.router1]
|
||||||
rule = "Host(`test.localhost`)"
|
rule = "Host(`test.localhost`)"
|
||||||
|
middlewares = ["cors"]
|
||||||
|
service = "service1"
|
||||||
|
|
||||||
|
[http.routers.router2]
|
||||||
|
rule = "Host(`test2.localhost`)"
|
||||||
|
middlewares = ["nocors"]
|
||||||
service = "service1"
|
service = "service1"
|
||||||
|
|
||||||
[http.middlewares]
|
[http.middlewares]
|
||||||
|
@ -26,7 +32,11 @@
|
||||||
accessControlMaxAge = 100
|
accessControlMaxAge = 100
|
||||||
addVaryHeader = true
|
addVaryHeader = true
|
||||||
|
|
||||||
|
[http.middlewares.nocors.Headers]
|
||||||
|
[http.middlewares.nocors.Headers.CustomResponseHeaders]
|
||||||
|
X-Custom-Response-Header = "True"
|
||||||
|
|
||||||
[http.services]
|
[http.services]
|
||||||
[http.services.service1.loadBalancer]
|
[http.services.service1.loadBalancer]
|
||||||
[[http.services.service1.loadBalancer.servers]]
|
[[http.services.service1.loadBalancer.servers]]
|
||||||
url = "http://172.17.0.2:80"
|
url = "http://127.0.0.1:9000"
|
||||||
|
|
|
@ -12,12 +12,6 @@ import (
|
||||||
// Headers test suites
|
// Headers test suites
|
||||||
type HeadersSuite struct{ BaseSuite }
|
type HeadersSuite struct{ BaseSuite }
|
||||||
|
|
||||||
func (s *HeadersSuite) SetUpSuite(c *check.C) {
|
|
||||||
s.createComposeProject(c, "headers")
|
|
||||||
|
|
||||||
s.composeProject.Start(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *HeadersSuite) TestSimpleConfiguration(c *check.C) {
|
func (s *HeadersSuite) TestSimpleConfiguration(c *check.C) {
|
||||||
cmd, display := s.traefikCmd(withConfigFile("fixtures/headers/basic.toml"))
|
cmd, display := s.traefikCmd(withConfigFile("fixtures/headers/basic.toml"))
|
||||||
defer display(c)
|
defer display(c)
|
||||||
|
@ -38,10 +32,18 @@ func (s *HeadersSuite) TestCorsResponses(c *check.C) {
|
||||||
c.Assert(err, checker.IsNil)
|
c.Assert(err, checker.IsNil)
|
||||||
defer cmd.Process.Kill()
|
defer cmd.Process.Kill()
|
||||||
|
|
||||||
|
backend := startTestServer("9000", http.StatusOK)
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
err = try.GetRequest(backend.URL, 500*time.Millisecond, try.StatusCodeIs(http.StatusOK))
|
||||||
|
c.Assert(err, checker.IsNil)
|
||||||
|
|
||||||
testCase := []struct {
|
testCase := []struct {
|
||||||
desc string
|
desc string
|
||||||
requestHeaders http.Header
|
requestHeaders http.Header
|
||||||
expected http.Header
|
expected http.Header
|
||||||
|
reqHost string
|
||||||
|
method string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
desc: "simple access control allow origin",
|
desc: "simple access control allow origin",
|
||||||
|
@ -52,33 +54,9 @@ func (s *HeadersSuite) TestCorsResponses(c *check.C) {
|
||||||
"Access-Control-Allow-Origin": {"https://foo.bar.org"},
|
"Access-Control-Allow-Origin": {"https://foo.bar.org"},
|
||||||
"Vary": {"Origin"},
|
"Vary": {"Origin"},
|
||||||
},
|
},
|
||||||
|
reqHost: "test.localhost",
|
||||||
|
method: http.MethodGet,
|
||||||
},
|
},
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range testCase {
|
|
||||||
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8000/", nil)
|
|
||||||
c.Assert(err, checker.IsNil)
|
|
||||||
req.Host = "test.localhost"
|
|
||||||
req.Header = test.requestHeaders
|
|
||||||
|
|
||||||
err = try.Request(req, 500*time.Millisecond, try.HasBody(), try.HasHeaderStruct(test.expected))
|
|
||||||
c.Assert(err, checker.IsNil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *HeadersSuite) TestCorsPreflightResponses(c *check.C) {
|
|
||||||
cmd, display := s.traefikCmd(withConfigFile("fixtures/headers/cors.toml"))
|
|
||||||
defer display(c)
|
|
||||||
|
|
||||||
err := cmd.Start()
|
|
||||||
c.Assert(err, checker.IsNil)
|
|
||||||
defer cmd.Process.Kill()
|
|
||||||
|
|
||||||
testCase := []struct {
|
|
||||||
desc string
|
|
||||||
requestHeaders http.Header
|
|
||||||
expected http.Header
|
|
||||||
}{
|
|
||||||
{
|
{
|
||||||
desc: "simple preflight request",
|
desc: "simple preflight request",
|
||||||
requestHeaders: http.Header{
|
requestHeaders: http.Header{
|
||||||
|
@ -91,16 +69,44 @@ func (s *HeadersSuite) TestCorsPreflightResponses(c *check.C) {
|
||||||
"Access-Control-Max-Age": {"100"},
|
"Access-Control-Max-Age": {"100"},
|
||||||
"Access-Control-Allow-Methods": {"GET,OPTIONS,PUT"},
|
"Access-Control-Allow-Methods": {"GET,OPTIONS,PUT"},
|
||||||
},
|
},
|
||||||
|
reqHost: "test.localhost",
|
||||||
|
method: http.MethodOptions,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "preflight Options request with no cors configured",
|
||||||
|
requestHeaders: http.Header{
|
||||||
|
"Access-Control-Request-Headers": {"origin"},
|
||||||
|
"Access-Control-Request-Method": {"GET", "OPTIONS"},
|
||||||
|
"Origin": {"https://foo.bar.org"},
|
||||||
|
},
|
||||||
|
expected: http.Header{
|
||||||
|
"X-Custom-Response-Header": {"True"},
|
||||||
|
},
|
||||||
|
reqHost: "test2.localhost",
|
||||||
|
method: http.MethodOptions,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "preflight Get request with no cors configured",
|
||||||
|
requestHeaders: http.Header{
|
||||||
|
"Access-Control-Request-Headers": {"origin"},
|
||||||
|
"Access-Control-Request-Method": {"GET", "OPTIONS"},
|
||||||
|
"Origin": {"https://foo.bar.org"},
|
||||||
|
},
|
||||||
|
expected: http.Header{
|
||||||
|
"X-Custom-Response-Header": {"True"},
|
||||||
|
},
|
||||||
|
reqHost: "test2.localhost",
|
||||||
|
method: http.MethodGet,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range testCase {
|
for _, test := range testCase {
|
||||||
req, err := http.NewRequest(http.MethodOptions, "http://127.0.0.1:8000/", nil)
|
req, err := http.NewRequest(test.method, "http://127.0.0.1:8000/", nil)
|
||||||
c.Assert(err, checker.IsNil)
|
c.Assert(err, checker.IsNil)
|
||||||
req.Host = "test.localhost"
|
req.Host = test.reqHost
|
||||||
req.Header = test.requestHeaders
|
req.Header = test.requestHeaders
|
||||||
|
|
||||||
err = try.Request(req, 500*time.Millisecond, try.HasBody(), try.HasHeaderStruct(test.expected))
|
err = try.Request(req, 500*time.Millisecond, try.HasHeaderStruct(test.expected))
|
||||||
c.Assert(err, checker.IsNil)
|
c.Assert(err, checker.IsNil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +0,0 @@
|
||||||
whoami1:
|
|
||||||
image: containous/whoami
|
|
||||||
ports:
|
|
||||||
- "8881:80"
|
|
|
@ -168,18 +168,17 @@ func HasHeaderValue(header, value string, exactMatch bool) ResponseCondition {
|
||||||
func HasHeaderStruct(header http.Header) ResponseCondition {
|
func HasHeaderStruct(header http.Header) ResponseCondition {
|
||||||
return func(res *http.Response) error {
|
return func(res *http.Response) error {
|
||||||
for key := range header {
|
for key := range header {
|
||||||
if _, ok := res.Header[key]; ok {
|
if _, ok := res.Header[key]; !ok {
|
||||||
// Header exists in the response, test it.
|
return fmt.Errorf("header %s not present in the response. Expected headers: %v Got response headers: %v", key, header, res.Header)
|
||||||
eq := reflect.DeepEqual(header[key], res.Header[key])
|
}
|
||||||
if !eq {
|
|
||||||
return fmt.Errorf("for header %s got values %v, wanted %v", key, res.Header[key], header[key])
|
|
||||||
}
|
|
||||||
|
|
||||||
|
// Header exists in the response, test it.
|
||||||
|
if !reflect.DeepEqual(header[key], res.Header[key]) {
|
||||||
|
return fmt.Errorf("for header %s got values %v, wanted %v", key, res.Header[key], header[key])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DoCondition is a retry condition function.
|
// DoCondition is a retry condition function.
|
||||||
|
|
|
@ -16,8 +16,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
typeName = "Headers"
|
typeName = "Headers"
|
||||||
originHeaderKey = "X-Request-Origin"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type headers struct {
|
type headers struct {
|
||||||
|
@ -107,29 +106,127 @@ func (s secureHeader) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
s.secure.HandlerFuncWithNextForRequestOnly(rw, req, s.next.ServeHTTP)
|
s.secure.HandlerFuncWithNextForRequestOnly(rw, req, s.next.ServeHTTP)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Header is a middleware that helps setup a few basic security features. A single headerOptions struct can be
|
// Header is a middleware that helps setup a few basic security features.
|
||||||
// provided to configure which features should be enabled, and the ability to override a few of the default values.
|
// A single headerOptions struct can be provided to configure which features should be enabled,
|
||||||
|
// and the ability to override a few of the default values.
|
||||||
type Header struct {
|
type Header struct {
|
||||||
next http.Handler
|
next http.Handler
|
||||||
headers *dynamic.Headers
|
hasCustomHeaders bool
|
||||||
|
hasCorsHeaders bool
|
||||||
|
headers *dynamic.Headers
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHeader constructs a new header instance from supplied frontend header struct.
|
// NewHeader constructs a new header instance from supplied frontend header struct.
|
||||||
func NewHeader(next http.Handler, headers dynamic.Headers) *Header {
|
func NewHeader(next http.Handler, headers dynamic.Headers) *Header {
|
||||||
|
hasCustomHeaders := headers.HasCustomHeadersDefined()
|
||||||
|
hasCorsHeaders := headers.HasCorsHeadersDefined()
|
||||||
|
|
||||||
return &Header{
|
return &Header{
|
||||||
next: next,
|
next: next,
|
||||||
headers: &headers,
|
headers: &headers,
|
||||||
|
hasCustomHeaders: hasCustomHeaders,
|
||||||
|
hasCorsHeaders: hasCorsHeaders,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Header) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
func (s *Header) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
// Handle Cors headers and preflight if configured.
|
||||||
|
if isPreflight := s.processCorsHeaders(rw, req); isPreflight {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.hasCustomHeaders {
|
||||||
|
s.modifyCustomRequestHeaders(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there is a next, call it.
|
||||||
|
if s.next != nil {
|
||||||
|
s.next.ServeHTTP(rw, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// modifyCustomRequestHeaders sets or deletes custom request headers.
|
||||||
|
func (s *Header) modifyCustomRequestHeaders(req *http.Request) {
|
||||||
|
// Loop through Custom request headers
|
||||||
|
for header, value := range s.headers.CustomRequestHeaders {
|
||||||
|
if value == "" {
|
||||||
|
req.Header.Del(header)
|
||||||
|
} else {
|
||||||
|
req.Header.Set(header, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// preRequestModifyCorsResponseHeaders sets during request processing time,
|
||||||
|
// all the CORS response headers that we already know that are supposed to be set,
|
||||||
|
// and which do not depend on a later state of the response.
|
||||||
|
// One notable example of a header that can only be modified later on is "Vary",
|
||||||
|
// And this is set in the post-response response modifier method
|
||||||
|
func (s *Header) preRequestModifyCorsResponseHeaders(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
originHeader := req.Header.Get("Origin")
|
||||||
|
allowOrigin := s.getAllowOrigin(originHeader)
|
||||||
|
|
||||||
|
if allowOrigin != "" {
|
||||||
|
rw.Header().Set("Access-Control-Allow-Origin", allowOrigin)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.headers.AccessControlAllowCredentials {
|
||||||
|
rw.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(s.headers.AccessControlExposeHeaders) > 0 {
|
||||||
|
exposeHeaders := strings.Join(s.headers.AccessControlExposeHeaders, ",")
|
||||||
|
rw.Header().Set("Access-Control-Expose-Headers", exposeHeaders)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PostRequestModifyResponseHeaders set or delete response headers.
|
||||||
|
// This method is called AFTER the response is generated from the backend
|
||||||
|
// and can merge/override headers from the backend response.
|
||||||
|
func (s *Header) PostRequestModifyResponseHeaders(res *http.Response) error {
|
||||||
|
// Loop through Custom response headers
|
||||||
|
for header, value := range s.headers.CustomResponseHeaders {
|
||||||
|
if value == "" {
|
||||||
|
res.Header.Del(header)
|
||||||
|
} else {
|
||||||
|
res.Header.Set(header, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !s.headers.AddVaryHeader {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
varyHeader := res.Header.Get("Vary")
|
||||||
|
if varyHeader == "Origin" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if varyHeader != "" {
|
||||||
|
varyHeader += ","
|
||||||
|
}
|
||||||
|
varyHeader += "Origin"
|
||||||
|
|
||||||
|
res.Header.Set("Vary", varyHeader)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// processCorsHeaders processes the incoming request,
|
||||||
|
// and returns if it is a preflight request.
|
||||||
|
// If not a preflight, it handles the preRequestModifyCorsResponseHeaders.
|
||||||
|
func (s *Header) processCorsHeaders(rw http.ResponseWriter, req *http.Request) bool {
|
||||||
|
if !s.hasCorsHeaders {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
reqAcMethod := req.Header.Get("Access-Control-Request-Method")
|
reqAcMethod := req.Header.Get("Access-Control-Request-Method")
|
||||||
reqAcHeaders := req.Header.Get("Access-Control-Request-Headers")
|
reqAcHeaders := req.Header.Get("Access-Control-Request-Headers")
|
||||||
originHeader := req.Header.Get("Origin")
|
originHeader := req.Header.Get("Origin")
|
||||||
|
|
||||||
if reqAcMethod != "" && reqAcHeaders != "" && originHeader != "" && req.Method == http.MethodOptions {
|
if reqAcMethod != "" && reqAcHeaders != "" && originHeader != "" && req.Method == http.MethodOptions {
|
||||||
// If the request is an OPTIONS request with an Access-Control-Request-Method header, and Access-Control-Request-Headers headers,
|
// If the request is an OPTIONS request with an Access-Control-Request-Method header,
|
||||||
// and Origin headers, then it is a CORS preflight request, and we need to build a custom response: https://www.w3.org/TR/cors/#preflight-request
|
// and Access-Control-Request-Headers headers, and Origin headers,
|
||||||
|
// then it is a CORS preflight request,
|
||||||
|
// and we need to build a custom response: https://www.w3.org/TR/cors/#preflight-request
|
||||||
if s.headers.AccessControlAllowCredentials {
|
if s.headers.AccessControlAllowCredentials {
|
||||||
rw.Header().Set("Access-Control-Allow-Credentials", "true")
|
rw.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||||
}
|
}
|
||||||
|
@ -151,71 +248,11 @@ func (s *Header) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
rw.Header().Set("Access-Control-Max-Age", strconv.Itoa(int(s.headers.AccessControlMaxAge)))
|
rw.Header().Set("Access-Control-Max-Age", strconv.Itoa(int(s.headers.AccessControlMaxAge)))
|
||||||
|
return true
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(originHeader) > 0 {
|
s.preRequestModifyCorsResponseHeaders(rw, req)
|
||||||
rw.Header().Set(originHeaderKey, originHeader)
|
return false
|
||||||
}
|
|
||||||
|
|
||||||
s.modifyRequestHeaders(req)
|
|
||||||
// If there is a next, call it.
|
|
||||||
if s.next != nil {
|
|
||||||
s.next.ServeHTTP(rw, req)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// modifyRequestHeaders sets or deletes request headers.
|
|
||||||
func (s *Header) modifyRequestHeaders(req *http.Request) {
|
|
||||||
// Loop through Custom request headers
|
|
||||||
for header, value := range s.headers.CustomRequestHeaders {
|
|
||||||
if value == "" {
|
|
||||||
req.Header.Del(header)
|
|
||||||
} else {
|
|
||||||
req.Header.Set(header, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ModifyResponseHeaders set or delete response headers
|
|
||||||
func (s *Header) ModifyResponseHeaders(res *http.Response) error {
|
|
||||||
// Loop through Custom response headers
|
|
||||||
for header, value := range s.headers.CustomResponseHeaders {
|
|
||||||
if value == "" {
|
|
||||||
res.Header.Del(header)
|
|
||||||
} else {
|
|
||||||
res.Header.Set(header, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
originHeader := res.Header.Get(originHeaderKey)
|
|
||||||
allowOrigin := s.getAllowOrigin(originHeader)
|
|
||||||
// Delete the origin header key, since it is only used to pass data from the request for response handling
|
|
||||||
res.Header.Del(originHeaderKey)
|
|
||||||
if allowOrigin != "" {
|
|
||||||
res.Header.Set("Access-Control-Allow-Origin", allowOrigin)
|
|
||||||
|
|
||||||
if s.headers.AddVaryHeader {
|
|
||||||
varyHeader := res.Header.Get("Vary")
|
|
||||||
if varyHeader != "" {
|
|
||||||
varyHeader += ","
|
|
||||||
}
|
|
||||||
varyHeader += "Origin"
|
|
||||||
|
|
||||||
res.Header.Set("Vary", varyHeader)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.headers.AccessControlAllowCredentials {
|
|
||||||
res.Header.Set("Access-Control-Allow-Credentials", "true")
|
|
||||||
}
|
|
||||||
|
|
||||||
exposeHeaders := strings.Join(s.headers.AccessControlExposeHeaders, ",")
|
|
||||||
if exposeHeaders != "" {
|
|
||||||
res.Header.Set("Access-Control-Expose-Headers", exposeHeaders)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Header) getAllowOrigin(header string) string {
|
func (s *Header) getAllowOrigin(header string) string {
|
||||||
|
|
|
@ -333,6 +333,7 @@ func TestGetTracingInformation(t *testing.T) {
|
||||||
func TestCORSResponses(t *testing.T) {
|
func TestCORSResponses(t *testing.T) {
|
||||||
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||||
nonEmptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Vary", "Testing") })
|
nonEmptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Vary", "Testing") })
|
||||||
|
existingOriginHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Vary", "Origin") })
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
desc string
|
desc string
|
||||||
|
@ -436,6 +437,32 @@ func TestCORSResponses(t *testing.T) {
|
||||||
"Vary": {"Testing,Origin"},
|
"Vary": {"Testing,Origin"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
desc: "Test Simple Request with Vary Headers and existing vary:origin response",
|
||||||
|
header: NewHeader(existingOriginHandler, dynamic.Headers{
|
||||||
|
AccessControlAllowOrigin: "origin-list-or-null",
|
||||||
|
AddVaryHeader: true,
|
||||||
|
}),
|
||||||
|
requestHeaders: map[string][]string{
|
||||||
|
"Origin": {"https://foo.bar.org"},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{
|
||||||
|
"Access-Control-Allow-Origin": {"https://foo.bar.org"},
|
||||||
|
"Vary": {"Origin"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Test Simple CustomRequestHeaders Not Hijacked by CORS",
|
||||||
|
header: NewHeader(emptyHandler, dynamic.Headers{
|
||||||
|
CustomRequestHeaders: map[string]string{"foo": "bar"},
|
||||||
|
}),
|
||||||
|
requestHeaders: map[string][]string{
|
||||||
|
"Access-Control-Request-Headers": {"origin"},
|
||||||
|
"Access-Control-Request-Method": {"GET", "OPTIONS"},
|
||||||
|
"Origin": {"https://foo.bar.org"},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range testCases {
|
for _, test := range testCases {
|
||||||
|
@ -445,7 +472,7 @@ func TestCORSResponses(t *testing.T) {
|
||||||
|
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
test.header.ServeHTTP(rw, req)
|
test.header.ServeHTTP(rw, req)
|
||||||
err := test.header.ModifyResponseHeaders(rw.Result())
|
err := test.header.PostRequestModifyResponseHeaders(rw.Result())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, test.expected, rw.Result().Header)
|
assert.Equal(t, test.expected, rw.Result().Header)
|
||||||
})
|
})
|
||||||
|
@ -492,7 +519,7 @@ func TestCustomResponseHeaders(t *testing.T) {
|
||||||
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
|
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
test.header.ServeHTTP(rw, req)
|
test.header.ServeHTTP(rw, req)
|
||||||
err := test.header.ModifyResponseHeaders(rw.Result())
|
err := test.header.PostRequestModifyResponseHeaders(rw.Result())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, test.expected, rw.Result().Header)
|
assert.Equal(t, test.expected, rw.Result().Header)
|
||||||
})
|
})
|
||||||
|
|
|
@ -34,7 +34,7 @@ func buildHeaders(hdrs *dynamic.Headers) func(*http.Response) error {
|
||||||
|
|
||||||
return func(resp *http.Response) error {
|
return func(resp *http.Response) error {
|
||||||
if hdrs.HasCustomHeadersDefined() || hdrs.HasCorsHeadersDefined() {
|
if hdrs.HasCustomHeadersDefined() || hdrs.HasCorsHeadersDefined() {
|
||||||
err := headers.NewHeader(nil, *hdrs).ModifyResponseHeaders(resp)
|
err := headers.NewHeader(nil, *hdrs).PostRequestModifyResponseHeaders(resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue