Headers response modifier is directly applied by headers middleware
Co-authored-by: Ludovic Fernandez <ldez@users.noreply.github.com>
This commit is contained in:
parent
3677252e17
commit
52790d3c37
25 changed files with 1144 additions and 1240 deletions
33
integration/fixtures/headers/secure_multiple.toml
Normal file
33
integration/fixtures/headers/secure_multiple.toml
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
[global]
|
||||||
|
checkNewVersion = false
|
||||||
|
sendAnonymousUsage = false
|
||||||
|
|
||||||
|
[log]
|
||||||
|
level = "DEBUG"
|
||||||
|
|
||||||
|
[entryPoints]
|
||||||
|
[entryPoints.web]
|
||||||
|
address = ":8000"
|
||||||
|
|
||||||
|
[providers.file]
|
||||||
|
filename = "{{ .SelfFilename }}"
|
||||||
|
|
||||||
|
## dynamic configuration ##
|
||||||
|
|
||||||
|
[http.routers]
|
||||||
|
[http.routers.router1]
|
||||||
|
rule = "Host(`test.localhost`)"
|
||||||
|
middlewares = ["foo", "bar"]
|
||||||
|
service = "service1"
|
||||||
|
|
||||||
|
|
||||||
|
[http.middlewares]
|
||||||
|
[http.middlewares.foo.headers]
|
||||||
|
frameDeny = true
|
||||||
|
[http.middlewares.bar.headers]
|
||||||
|
contentTypeNosniff = true
|
||||||
|
|
||||||
|
[http.services]
|
||||||
|
[http.services.service1.loadBalancer]
|
||||||
|
[[http.services.service1.loadBalancer.servers]]
|
||||||
|
url = "http://127.0.0.1:9000"
|
|
@ -162,3 +162,44 @@ func (s *HeadersSuite) TestSecureHeadersResponses(c *check.C) {
|
||||||
c.Assert(err, checker.IsNil)
|
c.Assert(err, checker.IsNil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *HeadersSuite) TestMultipleSecureHeadersResponses(c *check.C) {
|
||||||
|
file := s.adaptFile(c, "fixtures/headers/secure_multiple.toml", struct{}{})
|
||||||
|
defer os.Remove(file)
|
||||||
|
cmd, display := s.traefikCmd(withConfigFile(file))
|
||||||
|
defer display(c)
|
||||||
|
|
||||||
|
err := cmd.Start()
|
||||||
|
c.Assert(err, checker.IsNil)
|
||||||
|
defer cmd.Process.Kill()
|
||||||
|
|
||||||
|
backend := startTestServer("9000", http.StatusOK, "")
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
err = try.GetRequest(backend.URL, 500*time.Millisecond, try.StatusCodeIs(http.StatusOK))
|
||||||
|
c.Assert(err, checker.IsNil)
|
||||||
|
|
||||||
|
testCase := []struct {
|
||||||
|
desc string
|
||||||
|
expected http.Header
|
||||||
|
reqHost string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "Feature-Policy Set",
|
||||||
|
expected: http.Header{
|
||||||
|
"X-Frame-Options": {"DENY"},
|
||||||
|
"X-Content-Type-Options": {"nosniff"},
|
||||||
|
},
|
||||||
|
reqHost: "test.localhost",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range testCase {
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8000/", nil)
|
||||||
|
c.Assert(err, checker.IsNil)
|
||||||
|
req.Host = test.reqHost
|
||||||
|
|
||||||
|
err = try.Request(req, 500*time.Millisecond, try.HasHeaderStruct(test.expected))
|
||||||
|
c.Assert(err, checker.IsNil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -33,7 +33,7 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
type serviceBuilder interface {
|
type serviceBuilder interface {
|
||||||
BuildHTTP(ctx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error)
|
BuildHTTP(ctx context.Context, serviceName string) (http.Handler, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// customErrors is a middleware that provides the custom error pages..
|
// customErrors is a middleware that provides the custom error pages..
|
||||||
|
@ -54,7 +54,7 @@ func New(ctx context.Context, next http.Handler, config dynamic.ErrorPage, servi
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
backend, err := serviceBuilder.BuildHTTP(ctx, config.Service, nil)
|
backend, err := serviceBuilder.BuildHTTP(ctx, config.Service)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -150,7 +150,7 @@ type mockServiceBuilder struct {
|
||||||
handler http.Handler
|
handler http.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockServiceBuilder) BuildHTTP(_ context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error) {
|
func (m *mockServiceBuilder) BuildHTTP(ctx context.Context, serviceName string) (http.Handler, error) {
|
||||||
return m.handler, nil
|
return m.handler, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
170
pkg/middlewares/headers/header.go
Normal file
170
pkg/middlewares/headers/header.go
Normal file
|
@ -0,0 +1,170 @@
|
||||||
|
package headers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/containous/traefik/v2/pkg/config/dynamic"
|
||||||
|
"github.com/containous/traefik/v2/pkg/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Header is a middleware that helps setup a few basic security features.
|
||||||
|
// A single headerOptions struct can be provided to configure which features should be enabled,
|
||||||
|
// and the ability to override a few of the default values.
|
||||||
|
type Header struct {
|
||||||
|
next http.Handler
|
||||||
|
hasCustomHeaders bool
|
||||||
|
hasCorsHeaders bool
|
||||||
|
headers *dynamic.Headers
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHeader constructs a new header instance from supplied frontend header struct.
|
||||||
|
func NewHeader(next http.Handler, cfg dynamic.Headers) *Header {
|
||||||
|
hasCustomHeaders := cfg.HasCustomHeadersDefined()
|
||||||
|
hasCorsHeaders := cfg.HasCorsHeadersDefined()
|
||||||
|
|
||||||
|
ctx := log.With(context.Background(), log.Str(log.MiddlewareType, typeName))
|
||||||
|
handleDeprecation(ctx, &cfg)
|
||||||
|
|
||||||
|
return &Header{
|
||||||
|
next: next,
|
||||||
|
headers: &cfg,
|
||||||
|
hasCustomHeaders: hasCustomHeaders,
|
||||||
|
hasCorsHeaders: hasCorsHeaders,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Header) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
// Handle Cors headers and preflight if configured.
|
||||||
|
if isPreflight := s.processCorsHeaders(rw, req); isPreflight {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.hasCustomHeaders {
|
||||||
|
s.modifyCustomRequestHeaders(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there is a next, call it.
|
||||||
|
if s.next != nil {
|
||||||
|
s.next.ServeHTTP(newResponseModifier(rw, req, s.PostRequestModifyResponseHeaders), req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// modifyCustomRequestHeaders sets or deletes custom request headers.
|
||||||
|
func (s *Header) modifyCustomRequestHeaders(req *http.Request) {
|
||||||
|
// Loop through Custom request headers
|
||||||
|
for header, value := range s.headers.CustomRequestHeaders {
|
||||||
|
switch {
|
||||||
|
case value == "":
|
||||||
|
req.Header.Del(header)
|
||||||
|
|
||||||
|
case strings.EqualFold(header, "Host"):
|
||||||
|
req.Host = value
|
||||||
|
|
||||||
|
default:
|
||||||
|
req.Header.Set(header, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PostRequestModifyResponseHeaders set or delete response headers.
|
||||||
|
// This method is called AFTER the response is generated from the backend
|
||||||
|
// and can merge/override headers from the backend response.
|
||||||
|
func (s *Header) PostRequestModifyResponseHeaders(res *http.Response) error {
|
||||||
|
// Loop through Custom response headers
|
||||||
|
for header, value := range s.headers.CustomResponseHeaders {
|
||||||
|
if value == "" {
|
||||||
|
res.Header.Del(header)
|
||||||
|
} else {
|
||||||
|
res.Header.Set(header, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if res != nil && res.Request != nil {
|
||||||
|
originHeader := res.Request.Header.Get("Origin")
|
||||||
|
allowed, match := s.isOriginAllowed(originHeader)
|
||||||
|
|
||||||
|
if allowed {
|
||||||
|
res.Header.Set("Access-Control-Allow-Origin", match)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.headers.AccessControlAllowCredentials {
|
||||||
|
res.Header.Set("Access-Control-Allow-Credentials", "true")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(s.headers.AccessControlExposeHeaders) > 0 {
|
||||||
|
exposeHeaders := strings.Join(s.headers.AccessControlExposeHeaders, ",")
|
||||||
|
res.Header.Set("Access-Control-Expose-Headers", exposeHeaders)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !s.headers.AddVaryHeader {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
varyHeader := res.Header.Get("Vary")
|
||||||
|
if varyHeader == "Origin" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if varyHeader != "" {
|
||||||
|
varyHeader += ","
|
||||||
|
}
|
||||||
|
varyHeader += "Origin"
|
||||||
|
|
||||||
|
res.Header.Set("Vary", varyHeader)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// processCorsHeaders processes the incoming request,
|
||||||
|
// and returns if it is a preflight request.
|
||||||
|
// If not a preflight, it handles the preRequestModifyCorsResponseHeaders.
|
||||||
|
func (s *Header) processCorsHeaders(rw http.ResponseWriter, req *http.Request) bool {
|
||||||
|
if !s.hasCorsHeaders {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
reqAcMethod := req.Header.Get("Access-Control-Request-Method")
|
||||||
|
originHeader := req.Header.Get("Origin")
|
||||||
|
|
||||||
|
if reqAcMethod != "" && originHeader != "" && req.Method == http.MethodOptions {
|
||||||
|
// If the request is an OPTIONS request with an Access-Control-Request-Method header,
|
||||||
|
// and Origin headers, then it is a CORS preflight request,
|
||||||
|
// and we need to build a custom response: https://www.w3.org/TR/cors/#preflight-request
|
||||||
|
if s.headers.AccessControlAllowCredentials {
|
||||||
|
rw.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||||
|
}
|
||||||
|
|
||||||
|
allowHeaders := strings.Join(s.headers.AccessControlAllowHeaders, ",")
|
||||||
|
if allowHeaders != "" {
|
||||||
|
rw.Header().Set("Access-Control-Allow-Headers", allowHeaders)
|
||||||
|
}
|
||||||
|
|
||||||
|
allowMethods := strings.Join(s.headers.AccessControlAllowMethods, ",")
|
||||||
|
if allowMethods != "" {
|
||||||
|
rw.Header().Set("Access-Control-Allow-Methods", allowMethods)
|
||||||
|
}
|
||||||
|
|
||||||
|
allowed, match := s.isOriginAllowed(originHeader)
|
||||||
|
if allowed {
|
||||||
|
rw.Header().Set("Access-Control-Allow-Origin", match)
|
||||||
|
}
|
||||||
|
|
||||||
|
rw.Header().Set("Access-Control-Max-Age", strconv.Itoa(int(s.headers.AccessControlMaxAge)))
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Header) isOriginAllowed(origin string) (bool, string) {
|
||||||
|
for _, item := range s.headers.AccessControlAllowOriginList {
|
||||||
|
if item == "*" || item == origin {
|
||||||
|
return true, item
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, ""
|
||||||
|
}
|
492
pkg/middlewares/headers/header_test.go
Normal file
492
pkg/middlewares/headers/header_test.go
Normal file
|
@ -0,0 +1,492 @@
|
||||||
|
package headers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/containous/traefik/v2/pkg/config/dynamic"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewHeader_customRequestHeader(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
desc string
|
||||||
|
cfg dynamic.Headers
|
||||||
|
expected http.Header
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "adds a header",
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
CustomRequestHeaders: map[string]string{
|
||||||
|
"X-Custom-Request-Header": "test_request",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: http.Header{"Foo": []string{"bar"}, "X-Custom-Request-Header": []string{"test_request"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "delete a header",
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
CustomRequestHeaders: map[string]string{
|
||||||
|
"X-Custom-Request-Header": "",
|
||||||
|
"Foo": "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: http.Header{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "override a header",
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
CustomRequestHeaders: map[string]string{
|
||||||
|
"Foo": "test",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: http.Header{"Foo": []string{"test"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||||
|
|
||||||
|
for _, test := range testCases {
|
||||||
|
test := test
|
||||||
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mid := NewHeader(emptyHandler, test.cfg)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/foo", nil)
|
||||||
|
req.Header.Set("Foo", "bar")
|
||||||
|
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
mid.ServeHTTP(rw, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, rw.Code)
|
||||||
|
assert.Equal(t, test.expected, req.Header)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewHeader_customRequestHeader_Host(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
desc string
|
||||||
|
customHeaders map[string]string
|
||||||
|
expectedHost string
|
||||||
|
expectedURLHost string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "standard Host header",
|
||||||
|
customHeaders: map[string]string{},
|
||||||
|
expectedHost: "example.org",
|
||||||
|
expectedURLHost: "example.org",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "custom Host header",
|
||||||
|
customHeaders: map[string]string{
|
||||||
|
"Host": "example.com",
|
||||||
|
},
|
||||||
|
expectedHost: "example.com",
|
||||||
|
expectedURLHost: "example.org",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||||
|
|
||||||
|
for _, test := range testCases {
|
||||||
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
|
mid := NewHeader(emptyHandler, dynamic.Headers{CustomRequestHeaders: test.customHeaders})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "http://example.org/foo", nil)
|
||||||
|
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
mid.ServeHTTP(rw, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, rw.Code)
|
||||||
|
assert.Equal(t, test.expectedHost, req.Host)
|
||||||
|
assert.Equal(t, test.expectedURLHost, req.URL.Host)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewHeader_CORSPreflights(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
desc string
|
||||||
|
cfg dynamic.Headers
|
||||||
|
requestHeaders http.Header
|
||||||
|
expected http.Header
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "Test Simple Preflight",
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
AccessControlAllowMethods: []string{"GET", "OPTIONS", "PUT"},
|
||||||
|
AccessControlAllowOriginList: []string{"https://foo.bar.org"},
|
||||||
|
AccessControlMaxAge: 600,
|
||||||
|
},
|
||||||
|
requestHeaders: map[string][]string{
|
||||||
|
"Access-Control-Request-Headers": {"origin"},
|
||||||
|
"Access-Control-Request-Method": {"GET", "OPTIONS"},
|
||||||
|
"Origin": {"https://foo.bar.org"},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{
|
||||||
|
"Access-Control-Allow-Origin": {"https://foo.bar.org"},
|
||||||
|
"Access-Control-Max-Age": {"600"},
|
||||||
|
"Access-Control-Allow-Methods": {"GET,OPTIONS,PUT"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Wildcard origin Preflight",
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
AccessControlAllowMethods: []string{"GET", "OPTIONS", "PUT"},
|
||||||
|
AccessControlAllowOriginList: []string{"*"},
|
||||||
|
AccessControlMaxAge: 600,
|
||||||
|
},
|
||||||
|
requestHeaders: map[string][]string{
|
||||||
|
"Access-Control-Request-Headers": {"origin"},
|
||||||
|
"Access-Control-Request-Method": {"GET", "OPTIONS"},
|
||||||
|
"Origin": {"https://foo.bar.org"},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{
|
||||||
|
"Access-Control-Allow-Origin": {"*"},
|
||||||
|
"Access-Control-Max-Age": {"600"},
|
||||||
|
"Access-Control-Allow-Methods": {"GET,OPTIONS,PUT"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Allow Credentials Preflight",
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
AccessControlAllowMethods: []string{"GET", "OPTIONS", "PUT"},
|
||||||
|
AccessControlAllowOriginList: []string{"*"},
|
||||||
|
AccessControlAllowCredentials: true,
|
||||||
|
AccessControlMaxAge: 600,
|
||||||
|
},
|
||||||
|
requestHeaders: map[string][]string{
|
||||||
|
"Access-Control-Request-Headers": {"origin"},
|
||||||
|
"Access-Control-Request-Method": {"GET", "OPTIONS"},
|
||||||
|
"Origin": {"https://foo.bar.org"},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{
|
||||||
|
"Access-Control-Allow-Origin": {"*"},
|
||||||
|
"Access-Control-Max-Age": {"600"},
|
||||||
|
"Access-Control-Allow-Methods": {"GET,OPTIONS,PUT"},
|
||||||
|
"Access-Control-Allow-Credentials": {"true"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Allow Headers Preflight",
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
AccessControlAllowMethods: []string{"GET", "OPTIONS", "PUT"},
|
||||||
|
AccessControlAllowOriginList: []string{"*"},
|
||||||
|
AccessControlAllowHeaders: []string{"origin", "X-Forwarded-For"},
|
||||||
|
AccessControlMaxAge: 600,
|
||||||
|
},
|
||||||
|
requestHeaders: map[string][]string{
|
||||||
|
"Access-Control-Request-Headers": {"origin"},
|
||||||
|
"Access-Control-Request-Method": {"GET", "OPTIONS"},
|
||||||
|
"Origin": {"https://foo.bar.org"},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{
|
||||||
|
"Access-Control-Allow-Origin": {"*"},
|
||||||
|
"Access-Control-Max-Age": {"600"},
|
||||||
|
"Access-Control-Allow-Methods": {"GET,OPTIONS,PUT"},
|
||||||
|
"Access-Control-Allow-Headers": {"origin,X-Forwarded-For"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "No Request Headers Preflight",
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
AccessControlAllowMethods: []string{"GET", "OPTIONS", "PUT"},
|
||||||
|
AccessControlAllowOriginList: []string{"*"},
|
||||||
|
AccessControlAllowHeaders: []string{"origin", "X-Forwarded-For"},
|
||||||
|
AccessControlMaxAge: 600,
|
||||||
|
},
|
||||||
|
requestHeaders: map[string][]string{
|
||||||
|
"Access-Control-Request-Method": {"GET", "OPTIONS"},
|
||||||
|
"Origin": {"https://foo.bar.org"},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{
|
||||||
|
"Access-Control-Allow-Origin": {"*"},
|
||||||
|
"Access-Control-Max-Age": {"600"},
|
||||||
|
"Access-Control-Allow-Methods": {"GET,OPTIONS,PUT"},
|
||||||
|
"Access-Control-Allow-Headers": {"origin,X-Forwarded-For"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||||
|
|
||||||
|
for _, test := range testCases {
|
||||||
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
|
mid := NewHeader(emptyHandler, test.cfg)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodOptions, "/foo", nil)
|
||||||
|
req.Header = test.requestHeaders
|
||||||
|
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
mid.ServeHTTP(rw, req)
|
||||||
|
|
||||||
|
assert.Equal(t, test.expected, rw.Result().Header)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewHeader_CORSResponses(t *testing.T) {
|
||||||
|
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
desc string
|
||||||
|
next http.Handler
|
||||||
|
cfg dynamic.Headers
|
||||||
|
requestHeaders http.Header
|
||||||
|
expected http.Header
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "Test Simple Request",
|
||||||
|
next: emptyHandler,
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
AccessControlAllowOriginList: []string{"https://foo.bar.org"},
|
||||||
|
},
|
||||||
|
requestHeaders: map[string][]string{
|
||||||
|
"Origin": {"https://foo.bar.org"},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{
|
||||||
|
"Access-Control-Allow-Origin": {"https://foo.bar.org"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Wildcard origin Request",
|
||||||
|
next: emptyHandler,
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
AccessControlAllowOriginList: []string{"*"},
|
||||||
|
},
|
||||||
|
requestHeaders: map[string][]string{
|
||||||
|
"Origin": {"https://foo.bar.org"},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{
|
||||||
|
"Access-Control-Allow-Origin": {"*"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Empty origin Request",
|
||||||
|
next: emptyHandler,
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
AccessControlAllowOriginList: []string{"https://foo.bar.org"},
|
||||||
|
},
|
||||||
|
requestHeaders: map[string][]string{},
|
||||||
|
expected: map[string][]string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Not Defined origin Request",
|
||||||
|
next: emptyHandler,
|
||||||
|
requestHeaders: map[string][]string{},
|
||||||
|
expected: map[string][]string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Allow Credentials Request",
|
||||||
|
next: emptyHandler,
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
AccessControlAllowOriginList: []string{"*"},
|
||||||
|
AccessControlAllowCredentials: true,
|
||||||
|
},
|
||||||
|
requestHeaders: map[string][]string{
|
||||||
|
"Origin": {"https://foo.bar.org"},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{
|
||||||
|
"Access-Control-Allow-Origin": {"*"},
|
||||||
|
"Access-Control-Allow-Credentials": {"true"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Expose Headers Request",
|
||||||
|
next: emptyHandler,
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
AccessControlAllowOriginList: []string{"*"},
|
||||||
|
AccessControlExposeHeaders: []string{"origin", "X-Forwarded-For"},
|
||||||
|
},
|
||||||
|
requestHeaders: map[string][]string{
|
||||||
|
"Origin": {"https://foo.bar.org"},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{
|
||||||
|
"Access-Control-Allow-Origin": {"*"},
|
||||||
|
"Access-Control-Expose-Headers": {"origin,X-Forwarded-For"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Test Simple Request with Vary Headers",
|
||||||
|
next: emptyHandler,
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
AccessControlAllowOriginList: []string{"https://foo.bar.org"},
|
||||||
|
AddVaryHeader: true,
|
||||||
|
},
|
||||||
|
requestHeaders: map[string][]string{
|
||||||
|
"Origin": {"https://foo.bar.org"},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{
|
||||||
|
"Access-Control-Allow-Origin": {"https://foo.bar.org"},
|
||||||
|
"Vary": {"Origin"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Test Simple Request with Vary Headers and non-empty response",
|
||||||
|
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// nonEmptyHandler
|
||||||
|
w.Header().Set("Vary", "Testing")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}),
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
AccessControlAllowOriginList: []string{"https://foo.bar.org"},
|
||||||
|
AddVaryHeader: true,
|
||||||
|
},
|
||||||
|
requestHeaders: map[string][]string{
|
||||||
|
"Origin": {"https://foo.bar.org"},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{
|
||||||
|
"Access-Control-Allow-Origin": {"https://foo.bar.org"},
|
||||||
|
"Vary": {"Testing,Origin"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Test Simple Request with Vary Headers and existing vary:origin response",
|
||||||
|
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// existingOriginHandler
|
||||||
|
w.Header().Set("Vary", "Origin")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}),
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
AccessControlAllowOriginList: []string{"https://foo.bar.org"},
|
||||||
|
AddVaryHeader: true,
|
||||||
|
},
|
||||||
|
requestHeaders: map[string][]string{
|
||||||
|
"Origin": {"https://foo.bar.org"},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{
|
||||||
|
"Access-Control-Allow-Origin": {"https://foo.bar.org"},
|
||||||
|
"Vary": {"Origin"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Test Simple Request with non-empty response: set ACAO",
|
||||||
|
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// existingAccessControlAllowOriginHandlerSet
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", "http://foo.bar.org")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}),
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
AccessControlAllowOriginList: []string{"*"},
|
||||||
|
},
|
||||||
|
requestHeaders: map[string][]string{
|
||||||
|
"Origin": {"https://foo.bar.org"},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{
|
||||||
|
"Access-Control-Allow-Origin": {"*"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Test Simple Request with non-empty response: add ACAO",
|
||||||
|
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// existingAccessControlAllowOriginHandlerAdd
|
||||||
|
w.Header().Add("Access-Control-Allow-Origin", "http://foo.bar.org")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}),
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
AccessControlAllowOriginList: []string{"*"},
|
||||||
|
},
|
||||||
|
requestHeaders: map[string][]string{
|
||||||
|
"Origin": {"https://foo.bar.org"},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{
|
||||||
|
"Access-Control-Allow-Origin": {"*"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Test Simple CustomRequestHeaders Not Hijacked by CORS",
|
||||||
|
next: emptyHandler,
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
CustomRequestHeaders: map[string]string{"foo": "bar"},
|
||||||
|
},
|
||||||
|
requestHeaders: map[string][]string{
|
||||||
|
"Access-Control-Request-Headers": {"origin"},
|
||||||
|
"Access-Control-Request-Method": {"GET", "OPTIONS"},
|
||||||
|
"Origin": {"https://foo.bar.org"},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range testCases {
|
||||||
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
|
mid := NewHeader(test.next, test.cfg)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/foo", nil)
|
||||||
|
req.Header = test.requestHeaders
|
||||||
|
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
mid.ServeHTTP(rw, req)
|
||||||
|
|
||||||
|
assert.Equal(t, test.expected, rw.Result().Header)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewHeader_customResponseHeaders(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
desc string
|
||||||
|
config map[string]string
|
||||||
|
expected http.Header
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "Test Simple Response",
|
||||||
|
config: map[string]string{
|
||||||
|
"Testing": "foo",
|
||||||
|
"Testing2": "bar",
|
||||||
|
},
|
||||||
|
expected: map[string][]string{
|
||||||
|
"Foo": {"bar"},
|
||||||
|
"Testing": {"foo"},
|
||||||
|
"Testing2": {"bar"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "empty Custom Header",
|
||||||
|
config: map[string]string{
|
||||||
|
"Testing": "foo",
|
||||||
|
"Testing2": "",
|
||||||
|
},
|
||||||
|
expected: map[string][]string{
|
||||||
|
"Foo": {"bar"},
|
||||||
|
"Testing": {"foo"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Deleting Custom Header",
|
||||||
|
config: map[string]string{
|
||||||
|
"Testing": "foo",
|
||||||
|
"Foo": "",
|
||||||
|
},
|
||||||
|
expected: map[string][]string{
|
||||||
|
"Testing": {"foo"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Foo", "bar")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, test := range testCases {
|
||||||
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
|
mid := NewHeader(emptyHandler, dynamic.Headers{CustomResponseHeaders: test.config})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/foo", nil)
|
||||||
|
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
mid.ServeHTTP(rw, req)
|
||||||
|
|
||||||
|
assert.Equal(t, test.expected, rw.Result().Header)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -5,15 +5,12 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/containous/traefik/v2/pkg/config/dynamic"
|
"github.com/containous/traefik/v2/pkg/config/dynamic"
|
||||||
"github.com/containous/traefik/v2/pkg/log"
|
"github.com/containous/traefik/v2/pkg/log"
|
||||||
"github.com/containous/traefik/v2/pkg/middlewares"
|
"github.com/containous/traefik/v2/pkg/middlewares"
|
||||||
"github.com/containous/traefik/v2/pkg/tracing"
|
"github.com/containous/traefik/v2/pkg/tracing"
|
||||||
"github.com/opentracing/opentracing-go/ext"
|
"github.com/opentracing/opentracing-go/ext"
|
||||||
"github.com/unrolled/secure"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -55,7 +52,7 @@ func New(ctx context.Context, next http.Handler, cfg dynamic.Headers, name strin
|
||||||
|
|
||||||
if hasSecureHeaders {
|
if hasSecureHeaders {
|
||||||
logger.Debugf("Setting up secureHeaders from %v", cfg)
|
logger.Debugf("Setting up secureHeaders from %v", cfg)
|
||||||
handler = newSecure(next, cfg)
|
handler = newSecure(next, cfg, name)
|
||||||
nextHandler = handler
|
nextHandler = handler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -77,203 +74,3 @@ func (h *headers) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||||
func (h *headers) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
func (h *headers) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
h.handler.ServeHTTP(rw, req)
|
h.handler.ServeHTTP(rw, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
type secureHeader struct {
|
|
||||||
next http.Handler
|
|
||||||
secure *secure.Secure
|
|
||||||
}
|
|
||||||
|
|
||||||
// newSecure constructs a new secure instance with supplied options.
|
|
||||||
func newSecure(next http.Handler, cfg dynamic.Headers) *secureHeader {
|
|
||||||
opt := secure.Options{
|
|
||||||
BrowserXssFilter: cfg.BrowserXSSFilter,
|
|
||||||
ContentTypeNosniff: cfg.ContentTypeNosniff,
|
|
||||||
ForceSTSHeader: cfg.ForceSTSHeader,
|
|
||||||
FrameDeny: cfg.FrameDeny,
|
|
||||||
IsDevelopment: cfg.IsDevelopment,
|
|
||||||
SSLRedirect: cfg.SSLRedirect,
|
|
||||||
SSLForceHost: cfg.SSLForceHost,
|
|
||||||
SSLTemporaryRedirect: cfg.SSLTemporaryRedirect,
|
|
||||||
STSIncludeSubdomains: cfg.STSIncludeSubdomains,
|
|
||||||
STSPreload: cfg.STSPreload,
|
|
||||||
ContentSecurityPolicy: cfg.ContentSecurityPolicy,
|
|
||||||
CustomBrowserXssValue: cfg.CustomBrowserXSSValue,
|
|
||||||
CustomFrameOptionsValue: cfg.CustomFrameOptionsValue,
|
|
||||||
PublicKey: cfg.PublicKey,
|
|
||||||
ReferrerPolicy: cfg.ReferrerPolicy,
|
|
||||||
SSLHost: cfg.SSLHost,
|
|
||||||
AllowedHosts: cfg.AllowedHosts,
|
|
||||||
HostsProxyHeaders: cfg.HostsProxyHeaders,
|
|
||||||
SSLProxyHeaders: cfg.SSLProxyHeaders,
|
|
||||||
STSSeconds: cfg.STSSeconds,
|
|
||||||
FeaturePolicy: cfg.FeaturePolicy,
|
|
||||||
}
|
|
||||||
|
|
||||||
return &secureHeader{
|
|
||||||
next: next,
|
|
||||||
secure: secure.New(opt),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s secureHeader) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|
||||||
s.secure.HandlerFuncWithNextForRequestOnly(rw, req, s.next.ServeHTTP)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Header is a middleware that helps setup a few basic security features.
|
|
||||||
// A single headerOptions struct can be provided to configure which features should be enabled,
|
|
||||||
// and the ability to override a few of the default values.
|
|
||||||
type Header struct {
|
|
||||||
next http.Handler
|
|
||||||
hasCustomHeaders bool
|
|
||||||
hasCorsHeaders bool
|
|
||||||
headers *dynamic.Headers
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewHeader constructs a new header instance from supplied frontend header struct.
|
|
||||||
func NewHeader(next http.Handler, cfg dynamic.Headers) *Header {
|
|
||||||
hasCustomHeaders := cfg.HasCustomHeadersDefined()
|
|
||||||
hasCorsHeaders := cfg.HasCorsHeadersDefined()
|
|
||||||
|
|
||||||
ctx := log.With(context.Background(), log.Str(log.MiddlewareType, typeName))
|
|
||||||
handleDeprecation(ctx, &cfg)
|
|
||||||
|
|
||||||
return &Header{
|
|
||||||
next: next,
|
|
||||||
headers: &cfg,
|
|
||||||
hasCustomHeaders: hasCustomHeaders,
|
|
||||||
hasCorsHeaders: hasCorsHeaders,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Header) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|
||||||
// Handle Cors headers and preflight if configured.
|
|
||||||
if isPreflight := s.processCorsHeaders(rw, req); isPreflight {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.hasCustomHeaders {
|
|
||||||
s.modifyCustomRequestHeaders(req)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If there is a next, call it.
|
|
||||||
if s.next != nil {
|
|
||||||
s.next.ServeHTTP(rw, req)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// modifyCustomRequestHeaders sets or deletes custom request headers.
|
|
||||||
func (s *Header) modifyCustomRequestHeaders(req *http.Request) {
|
|
||||||
// Loop through Custom request headers
|
|
||||||
for header, value := range s.headers.CustomRequestHeaders {
|
|
||||||
switch {
|
|
||||||
case value == "":
|
|
||||||
req.Header.Del(header)
|
|
||||||
|
|
||||||
case strings.EqualFold(header, "Host"):
|
|
||||||
req.Host = value
|
|
||||||
|
|
||||||
default:
|
|
||||||
req.Header.Set(header, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// PostRequestModifyResponseHeaders set or delete response headers.
|
|
||||||
// This method is called AFTER the response is generated from the backend
|
|
||||||
// and can merge/override headers from the backend response.
|
|
||||||
func (s *Header) PostRequestModifyResponseHeaders(res *http.Response) error {
|
|
||||||
// Loop through Custom response headers
|
|
||||||
for header, value := range s.headers.CustomResponseHeaders {
|
|
||||||
if value == "" {
|
|
||||||
res.Header.Del(header)
|
|
||||||
} else {
|
|
||||||
res.Header.Set(header, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if res != nil && res.Request != nil {
|
|
||||||
originHeader := res.Request.Header.Get("Origin")
|
|
||||||
allowed, match := s.isOriginAllowed(originHeader)
|
|
||||||
|
|
||||||
if allowed {
|
|
||||||
res.Header.Set("Access-Control-Allow-Origin", match)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.headers.AccessControlAllowCredentials {
|
|
||||||
res.Header.Set("Access-Control-Allow-Credentials", "true")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(s.headers.AccessControlExposeHeaders) > 0 {
|
|
||||||
exposeHeaders := strings.Join(s.headers.AccessControlExposeHeaders, ",")
|
|
||||||
res.Header.Set("Access-Control-Expose-Headers", exposeHeaders)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !s.headers.AddVaryHeader {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
varyHeader := res.Header.Get("Vary")
|
|
||||||
if varyHeader == "Origin" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if varyHeader != "" {
|
|
||||||
varyHeader += ","
|
|
||||||
}
|
|
||||||
varyHeader += "Origin"
|
|
||||||
|
|
||||||
res.Header.Set("Vary", varyHeader)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// processCorsHeaders processes the incoming request,
|
|
||||||
// and returns if it is a preflight request.
|
|
||||||
// If not a preflight, it handles the preRequestModifyCorsResponseHeaders.
|
|
||||||
func (s *Header) processCorsHeaders(rw http.ResponseWriter, req *http.Request) bool {
|
|
||||||
if !s.hasCorsHeaders {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
reqAcMethod := req.Header.Get("Access-Control-Request-Method")
|
|
||||||
originHeader := req.Header.Get("Origin")
|
|
||||||
|
|
||||||
if reqAcMethod != "" && originHeader != "" && req.Method == http.MethodOptions {
|
|
||||||
// If the request is an OPTIONS request with an Access-Control-Request-Method header,
|
|
||||||
// and Origin headers, then it is a CORS preflight request,
|
|
||||||
// and we need to build a custom response: https://www.w3.org/TR/cors/#preflight-request
|
|
||||||
if s.headers.AccessControlAllowCredentials {
|
|
||||||
rw.Header().Set("Access-Control-Allow-Credentials", "true")
|
|
||||||
}
|
|
||||||
|
|
||||||
allowHeaders := strings.Join(s.headers.AccessControlAllowHeaders, ",")
|
|
||||||
if allowHeaders != "" {
|
|
||||||
rw.Header().Set("Access-Control-Allow-Headers", allowHeaders)
|
|
||||||
}
|
|
||||||
|
|
||||||
allowMethods := strings.Join(s.headers.AccessControlAllowMethods, ",")
|
|
||||||
if allowMethods != "" {
|
|
||||||
rw.Header().Set("Access-Control-Allow-Methods", allowMethods)
|
|
||||||
}
|
|
||||||
|
|
||||||
allowed, match := s.isOriginAllowed(originHeader)
|
|
||||||
if allowed {
|
|
||||||
rw.Header().Set("Access-Control-Allow-Origin", match)
|
|
||||||
}
|
|
||||||
|
|
||||||
rw.Header().Set("Access-Control-Max-Age", strconv.Itoa(int(s.headers.AccessControlMaxAge)))
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Header) isOriginAllowed(origin string) (bool, string) {
|
|
||||||
for _, item := range s.headers.AccessControlAllowOriginList {
|
|
||||||
if item == "*" || item == origin {
|
|
||||||
return true, item
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, ""
|
|
||||||
}
|
|
||||||
|
|
|
@ -9,104 +9,21 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/containous/traefik/v2/pkg/config/dynamic"
|
"github.com/containous/traefik/v2/pkg/config/dynamic"
|
||||||
"github.com/containous/traefik/v2/pkg/testhelpers"
|
|
||||||
"github.com/containous/traefik/v2/pkg/tracing"
|
"github.com/containous/traefik/v2/pkg/tracing"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCustomRequestHeader(t *testing.T) {
|
func TestNew_withoutOptions(t *testing.T) {
|
||||||
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })
|
||||||
|
|
||||||
header := NewHeader(emptyHandler, dynamic.Headers{
|
mid, err := New(context.Background(), next, dynamic.Headers{}, "testing")
|
||||||
CustomRequestHeaders: map[string]string{
|
require.Errorf(t, err, "headers configuration not valid")
|
||||||
"X-Custom-Request-Header": "test_request",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
res := httptest.NewRecorder()
|
assert.Nil(t, mid)
|
||||||
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
|
|
||||||
|
|
||||||
header.ServeHTTP(res, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, res.Code)
|
|
||||||
assert.Equal(t, "test_request", req.Header.Get("X-Custom-Request-Header"))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCustomRequestHeader_Host(t *testing.T) {
|
func TestNew_allowedHosts(t *testing.T) {
|
||||||
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
desc string
|
|
||||||
customHeaders map[string]string
|
|
||||||
expectedHost string
|
|
||||||
expectedURLHost string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
desc: "standard Host header",
|
|
||||||
customHeaders: map[string]string{},
|
|
||||||
expectedHost: "example.org",
|
|
||||||
expectedURLHost: "example.org",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "custom Host header",
|
|
||||||
customHeaders: map[string]string{
|
|
||||||
"Host": "example.com",
|
|
||||||
},
|
|
||||||
expectedHost: "example.com",
|
|
||||||
expectedURLHost: "example.org",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range testCases {
|
|
||||||
t.Run(test.desc, func(t *testing.T) {
|
|
||||||
header := NewHeader(emptyHandler, dynamic.Headers{
|
|
||||||
CustomRequestHeaders: test.customHeaders,
|
|
||||||
})
|
|
||||||
|
|
||||||
res := httptest.NewRecorder()
|
|
||||||
req, err := http.NewRequest(http.MethodGet, "http://example.org/foo", nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
header.ServeHTTP(res, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, res.Code)
|
|
||||||
assert.Equal(t, test.expectedHost, req.Host)
|
|
||||||
assert.Equal(t, test.expectedURLHost, req.URL.Host)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCustomRequestHeaderEmptyValue(t *testing.T) {
|
|
||||||
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
|
||||||
|
|
||||||
header := NewHeader(emptyHandler, dynamic.Headers{
|
|
||||||
CustomRequestHeaders: map[string]string{
|
|
||||||
"X-Custom-Request-Header": "test_request",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
res := httptest.NewRecorder()
|
|
||||||
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
|
|
||||||
|
|
||||||
header.ServeHTTP(res, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, res.Code)
|
|
||||||
assert.Equal(t, "test_request", req.Header.Get("X-Custom-Request-Header"))
|
|
||||||
|
|
||||||
header = NewHeader(emptyHandler, dynamic.Headers{
|
|
||||||
CustomRequestHeaders: map[string]string{
|
|
||||||
"X-Custom-Request-Header": "",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
header.ServeHTTP(res, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, res.Code)
|
|
||||||
assert.Equal(t, "", req.Header.Get("X-Custom-Request-Header"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSecureHeader(t *testing.T) {
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
desc string
|
desc string
|
||||||
fromHost string
|
fromHost string
|
||||||
|
@ -129,10 +46,13 @@ func TestSecureHeader(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })
|
||||||
header, err := New(context.Background(), emptyHandler, dynamic.Headers{
|
|
||||||
|
cfg := dynamic.Headers{
|
||||||
AllowedHosts: []string{"foo.com", "bar.com"},
|
AllowedHosts: []string{"foo.com", "bar.com"},
|
||||||
}, "foo")
|
}
|
||||||
|
|
||||||
|
mid, err := New(context.Background(), emptyHandler, cfg, "foo")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
for _, test := range testCases {
|
for _, test := range testCases {
|
||||||
|
@ -140,479 +60,54 @@ func TestSecureHeader(t *testing.T) {
|
||||||
t.Run(test.desc, func(t *testing.T) {
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
res := httptest.NewRecorder()
|
req := httptest.NewRequest(http.MethodGet, "/foo", nil)
|
||||||
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
|
|
||||||
req.Host = test.fromHost
|
req.Host = test.fromHost
|
||||||
header.ServeHTTP(res, req)
|
|
||||||
assert.Equal(t, test.expected, res.Code)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSSLForceHost(t *testing.T) {
|
|
||||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
|
||||||
_, _ = rw.Write([]byte("OK"))
|
|
||||||
})
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
desc string
|
|
||||||
host string
|
|
||||||
secureMiddleware *secureHeader
|
|
||||||
expected int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
desc: "http should return a 301",
|
|
||||||
host: "http://powpow.example.com",
|
|
||||||
secureMiddleware: newSecure(next, dynamic.Headers{
|
|
||||||
SSLRedirect: true,
|
|
||||||
SSLForceHost: true,
|
|
||||||
SSLHost: "powpow.example.com",
|
|
||||||
}),
|
|
||||||
expected: http.StatusMovedPermanently,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "http sub domain should return a 301",
|
|
||||||
host: "http://www.powpow.example.com",
|
|
||||||
secureMiddleware: newSecure(next, dynamic.Headers{
|
|
||||||
SSLRedirect: true,
|
|
||||||
SSLForceHost: true,
|
|
||||||
SSLHost: "powpow.example.com",
|
|
||||||
}),
|
|
||||||
expected: http.StatusMovedPermanently,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "https should return a 200",
|
|
||||||
host: "https://powpow.example.com",
|
|
||||||
secureMiddleware: newSecure(next, dynamic.Headers{
|
|
||||||
SSLRedirect: true,
|
|
||||||
SSLForceHost: true,
|
|
||||||
SSLHost: "powpow.example.com",
|
|
||||||
}),
|
|
||||||
expected: http.StatusOK,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "https sub domain should return a 301",
|
|
||||||
host: "https://www.powpow.example.com",
|
|
||||||
secureMiddleware: newSecure(next, dynamic.Headers{
|
|
||||||
SSLRedirect: true,
|
|
||||||
SSLForceHost: true,
|
|
||||||
SSLHost: "powpow.example.com",
|
|
||||||
}),
|
|
||||||
expected: http.StatusMovedPermanently,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "http without force host and sub domain should return a 301",
|
|
||||||
host: "http://www.powpow.example.com",
|
|
||||||
secureMiddleware: newSecure(next, dynamic.Headers{
|
|
||||||
SSLRedirect: true,
|
|
||||||
SSLForceHost: false,
|
|
||||||
SSLHost: "powpow.example.com",
|
|
||||||
}),
|
|
||||||
expected: http.StatusMovedPermanently,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "https without force host and sub domain should return a 301",
|
|
||||||
host: "https://www.powpow.example.com",
|
|
||||||
secureMiddleware: newSecure(next, dynamic.Headers{
|
|
||||||
SSLRedirect: true,
|
|
||||||
SSLForceHost: false,
|
|
||||||
SSLHost: "powpow.example.com",
|
|
||||||
}),
|
|
||||||
expected: http.StatusOK,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range testCases {
|
|
||||||
t.Run(test.desc, func(t *testing.T) {
|
|
||||||
req := testhelpers.MustNewRequest(http.MethodGet, test.host, nil)
|
|
||||||
|
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
test.secureMiddleware.ServeHTTP(rw, req)
|
|
||||||
|
|
||||||
assert.Equal(t, test.expected, rw.Result().StatusCode)
|
mid.ServeHTTP(rw, req)
|
||||||
|
|
||||||
|
assert.Equal(t, test.expected, rw.Code)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCORSPreflights(t *testing.T) {
|
func TestNew_customHeaders(t *testing.T) {
|
||||||
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })
|
||||||
|
|
||||||
testCases := []struct {
|
cfg := dynamic.Headers{
|
||||||
desc string
|
|
||||||
header *Header
|
|
||||||
requestHeaders http.Header
|
|
||||||
expected http.Header
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
desc: "Test Simple Preflight",
|
|
||||||
header: NewHeader(emptyHandler, dynamic.Headers{
|
|
||||||
AccessControlAllowMethods: []string{"GET", "OPTIONS", "PUT"},
|
|
||||||
AccessControlAllowOriginList: []string{"https://foo.bar.org"},
|
|
||||||
AccessControlMaxAge: 600,
|
|
||||||
}),
|
|
||||||
requestHeaders: map[string][]string{
|
|
||||||
"Access-Control-Request-Headers": {"origin"},
|
|
||||||
"Access-Control-Request-Method": {"GET", "OPTIONS"},
|
|
||||||
"Origin": {"https://foo.bar.org"},
|
|
||||||
},
|
|
||||||
expected: map[string][]string{
|
|
||||||
"Access-Control-Allow-Origin": {"https://foo.bar.org"},
|
|
||||||
"Access-Control-Max-Age": {"600"},
|
|
||||||
"Access-Control-Allow-Methods": {"GET,OPTIONS,PUT"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "Wildcard origin Preflight",
|
|
||||||
header: NewHeader(emptyHandler, dynamic.Headers{
|
|
||||||
AccessControlAllowMethods: []string{"GET", "OPTIONS", "PUT"},
|
|
||||||
AccessControlAllowOriginList: []string{"*"},
|
|
||||||
AccessControlMaxAge: 600,
|
|
||||||
}),
|
|
||||||
requestHeaders: map[string][]string{
|
|
||||||
"Access-Control-Request-Headers": {"origin"},
|
|
||||||
"Access-Control-Request-Method": {"GET", "OPTIONS"},
|
|
||||||
"Origin": {"https://foo.bar.org"},
|
|
||||||
},
|
|
||||||
expected: map[string][]string{
|
|
||||||
"Access-Control-Allow-Origin": {"*"},
|
|
||||||
"Access-Control-Max-Age": {"600"},
|
|
||||||
"Access-Control-Allow-Methods": {"GET,OPTIONS,PUT"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "Allow Credentials Preflight",
|
|
||||||
header: NewHeader(emptyHandler, dynamic.Headers{
|
|
||||||
AccessControlAllowMethods: []string{"GET", "OPTIONS", "PUT"},
|
|
||||||
AccessControlAllowOriginList: []string{"*"},
|
|
||||||
AccessControlAllowCredentials: true,
|
|
||||||
AccessControlMaxAge: 600,
|
|
||||||
}),
|
|
||||||
requestHeaders: map[string][]string{
|
|
||||||
"Access-Control-Request-Headers": {"origin"},
|
|
||||||
"Access-Control-Request-Method": {"GET", "OPTIONS"},
|
|
||||||
"Origin": {"https://foo.bar.org"},
|
|
||||||
},
|
|
||||||
expected: map[string][]string{
|
|
||||||
"Access-Control-Allow-Origin": {"*"},
|
|
||||||
"Access-Control-Max-Age": {"600"},
|
|
||||||
"Access-Control-Allow-Methods": {"GET,OPTIONS,PUT"},
|
|
||||||
"Access-Control-Allow-Credentials": {"true"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "Allow Headers Preflight",
|
|
||||||
header: NewHeader(emptyHandler, dynamic.Headers{
|
|
||||||
AccessControlAllowMethods: []string{"GET", "OPTIONS", "PUT"},
|
|
||||||
AccessControlAllowOriginList: []string{"*"},
|
|
||||||
AccessControlAllowHeaders: []string{"origin", "X-Forwarded-For"},
|
|
||||||
AccessControlMaxAge: 600,
|
|
||||||
}),
|
|
||||||
requestHeaders: map[string][]string{
|
|
||||||
"Access-Control-Request-Headers": {"origin"},
|
|
||||||
"Access-Control-Request-Method": {"GET", "OPTIONS"},
|
|
||||||
"Origin": {"https://foo.bar.org"},
|
|
||||||
},
|
|
||||||
expected: map[string][]string{
|
|
||||||
"Access-Control-Allow-Origin": {"*"},
|
|
||||||
"Access-Control-Max-Age": {"600"},
|
|
||||||
"Access-Control-Allow-Methods": {"GET,OPTIONS,PUT"},
|
|
||||||
"Access-Control-Allow-Headers": {"origin,X-Forwarded-For"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "No Request Headers Preflight",
|
|
||||||
header: NewHeader(emptyHandler, dynamic.Headers{
|
|
||||||
AccessControlAllowMethods: []string{"GET", "OPTIONS", "PUT"},
|
|
||||||
AccessControlAllowOriginList: []string{"*"},
|
|
||||||
AccessControlAllowHeaders: []string{"origin", "X-Forwarded-For"},
|
|
||||||
AccessControlMaxAge: 600,
|
|
||||||
}),
|
|
||||||
requestHeaders: map[string][]string{
|
|
||||||
"Access-Control-Request-Method": {"GET", "OPTIONS"},
|
|
||||||
"Origin": {"https://foo.bar.org"},
|
|
||||||
},
|
|
||||||
expected: map[string][]string{
|
|
||||||
"Access-Control-Allow-Origin": {"*"},
|
|
||||||
"Access-Control-Max-Age": {"600"},
|
|
||||||
"Access-Control-Allow-Methods": {"GET,OPTIONS,PUT"},
|
|
||||||
"Access-Control-Allow-Headers": {"origin,X-Forwarded-For"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range testCases {
|
|
||||||
t.Run(test.desc, func(t *testing.T) {
|
|
||||||
req := testhelpers.MustNewRequest(http.MethodOptions, "/foo", nil)
|
|
||||||
req.Header = test.requestHeaders
|
|
||||||
|
|
||||||
rw := httptest.NewRecorder()
|
|
||||||
test.header.ServeHTTP(rw, req)
|
|
||||||
|
|
||||||
assert.Equal(t, test.expected, rw.Result().Header)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEmptyHeaderObject(t *testing.T) {
|
|
||||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
|
||||||
|
|
||||||
_, err := New(context.Background(), next, dynamic.Headers{}, "testing")
|
|
||||||
require.Errorf(t, err, "headers configuration not valid")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCustomHeaderHandler(t *testing.T) {
|
|
||||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
|
||||||
|
|
||||||
header, _ := New(context.Background(), next, dynamic.Headers{
|
|
||||||
CustomRequestHeaders: map[string]string{
|
CustomRequestHeaders: map[string]string{
|
||||||
"X-Custom-Request-Header": "test_request",
|
"X-Custom-Request-Header": "test_request",
|
||||||
},
|
},
|
||||||
}, "testing")
|
CustomResponseHeaders: map[string]string{
|
||||||
|
"X-Custom-Response-Header": "test_response",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
res := httptest.NewRecorder()
|
mid, err := New(context.Background(), next, cfg, "testing")
|
||||||
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
|
require.NoError(t, err)
|
||||||
|
|
||||||
header.ServeHTTP(res, req)
|
req := httptest.NewRequest(http.MethodGet, "/foo", nil)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, res.Code)
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
mid.ServeHTTP(rw, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, rw.Code)
|
||||||
assert.Equal(t, "test_request", req.Header.Get("X-Custom-Request-Header"))
|
assert.Equal(t, "test_request", req.Header.Get("X-Custom-Request-Header"))
|
||||||
|
assert.Equal(t, "test_response", rw.Header().Get("X-Custom-Response-Header"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetTracingInformation(t *testing.T) {
|
func Test_headers_getTracingInformation(t *testing.T) {
|
||||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||||
|
|
||||||
header := &headers{
|
mid := &headers{
|
||||||
handler: next,
|
handler: next,
|
||||||
name: "testing",
|
name: "testing",
|
||||||
}
|
}
|
||||||
|
|
||||||
name, trace := header.GetTracingInformation()
|
name, trace := mid.GetTracingInformation()
|
||||||
|
|
||||||
assert.Equal(t, "testing", name)
|
assert.Equal(t, "testing", name)
|
||||||
assert.Equal(t, tracing.SpanKindNoneEnum, trace)
|
assert.Equal(t, tracing.SpanKindNoneEnum, trace)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCORSResponses(t *testing.T) {
|
|
||||||
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
|
||||||
nonEmptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Vary", "Testing") })
|
|
||||||
existingOriginHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Vary", "Origin") })
|
|
||||||
existingAccessControlAllowOriginHandlerSet := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "http://foo.bar.org")
|
|
||||||
})
|
|
||||||
existingAccessControlAllowOriginHandlerAdd := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Add("Access-Control-Allow-Origin", "http://foo.bar.org")
|
|
||||||
})
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
desc string
|
|
||||||
header *Header
|
|
||||||
requestHeaders http.Header
|
|
||||||
expected http.Header
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
desc: "Test Simple Request",
|
|
||||||
header: NewHeader(emptyHandler, dynamic.Headers{
|
|
||||||
AccessControlAllowOriginList: []string{"https://foo.bar.org"},
|
|
||||||
}),
|
|
||||||
requestHeaders: map[string][]string{
|
|
||||||
"Origin": {"https://foo.bar.org"},
|
|
||||||
},
|
|
||||||
expected: map[string][]string{
|
|
||||||
"Access-Control-Allow-Origin": {"https://foo.bar.org"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "Wildcard origin Request",
|
|
||||||
header: NewHeader(emptyHandler, dynamic.Headers{
|
|
||||||
AccessControlAllowOriginList: []string{"*"},
|
|
||||||
}),
|
|
||||||
requestHeaders: map[string][]string{
|
|
||||||
"Origin": {"https://foo.bar.org"},
|
|
||||||
},
|
|
||||||
expected: map[string][]string{
|
|
||||||
"Access-Control-Allow-Origin": {"*"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "Empty origin Request",
|
|
||||||
header: NewHeader(emptyHandler, dynamic.Headers{
|
|
||||||
AccessControlAllowOriginList: []string{"https://foo.bar.org"},
|
|
||||||
}),
|
|
||||||
requestHeaders: map[string][]string{},
|
|
||||||
expected: map[string][]string{},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "Not Defined origin Request",
|
|
||||||
header: NewHeader(emptyHandler, dynamic.Headers{}),
|
|
||||||
requestHeaders: map[string][]string{},
|
|
||||||
expected: map[string][]string{},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "Allow Credentials Request",
|
|
||||||
header: NewHeader(emptyHandler, dynamic.Headers{
|
|
||||||
AccessControlAllowOriginList: []string{"*"},
|
|
||||||
AccessControlAllowCredentials: true,
|
|
||||||
}),
|
|
||||||
requestHeaders: map[string][]string{
|
|
||||||
"Origin": {"https://foo.bar.org"},
|
|
||||||
},
|
|
||||||
expected: map[string][]string{
|
|
||||||
"Access-Control-Allow-Origin": {"*"},
|
|
||||||
"Access-Control-Allow-Credentials": {"true"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "Expose Headers Request",
|
|
||||||
header: NewHeader(emptyHandler, dynamic.Headers{
|
|
||||||
AccessControlAllowOriginList: []string{"*"},
|
|
||||||
AccessControlExposeHeaders: []string{"origin", "X-Forwarded-For"},
|
|
||||||
}),
|
|
||||||
requestHeaders: map[string][]string{
|
|
||||||
"Origin": {"https://foo.bar.org"},
|
|
||||||
},
|
|
||||||
expected: map[string][]string{
|
|
||||||
"Access-Control-Allow-Origin": {"*"},
|
|
||||||
"Access-Control-Expose-Headers": {"origin,X-Forwarded-For"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "Test Simple Request with Vary Headers",
|
|
||||||
header: NewHeader(emptyHandler, dynamic.Headers{
|
|
||||||
AccessControlAllowOriginList: []string{"https://foo.bar.org"},
|
|
||||||
AddVaryHeader: true,
|
|
||||||
}),
|
|
||||||
requestHeaders: map[string][]string{
|
|
||||||
"Origin": {"https://foo.bar.org"},
|
|
||||||
},
|
|
||||||
expected: map[string][]string{
|
|
||||||
"Access-Control-Allow-Origin": {"https://foo.bar.org"},
|
|
||||||
"Vary": {"Origin"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "Test Simple Request with Vary Headers and non-empty response",
|
|
||||||
header: NewHeader(nonEmptyHandler, dynamic.Headers{
|
|
||||||
AccessControlAllowOriginList: []string{"https://foo.bar.org"},
|
|
||||||
AddVaryHeader: true,
|
|
||||||
}),
|
|
||||||
requestHeaders: map[string][]string{
|
|
||||||
"Origin": {"https://foo.bar.org"},
|
|
||||||
},
|
|
||||||
expected: map[string][]string{
|
|
||||||
"Access-Control-Allow-Origin": {"https://foo.bar.org"},
|
|
||||||
"Vary": {"Testing,Origin"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "Test Simple Request with Vary Headers and existing vary:origin response",
|
|
||||||
header: NewHeader(existingOriginHandler, dynamic.Headers{
|
|
||||||
AccessControlAllowOriginList: []string{"https://foo.bar.org"},
|
|
||||||
AddVaryHeader: true,
|
|
||||||
}),
|
|
||||||
requestHeaders: map[string][]string{
|
|
||||||
"Origin": {"https://foo.bar.org"},
|
|
||||||
},
|
|
||||||
expected: map[string][]string{
|
|
||||||
"Access-Control-Allow-Origin": {"https://foo.bar.org"},
|
|
||||||
"Vary": {"Origin"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "Test Simple Request with non-empty response: set ACAO",
|
|
||||||
header: NewHeader(existingAccessControlAllowOriginHandlerSet, dynamic.Headers{
|
|
||||||
AccessControlAllowOriginList: []string{"*"},
|
|
||||||
}),
|
|
||||||
requestHeaders: map[string][]string{
|
|
||||||
"Origin": {"https://foo.bar.org"},
|
|
||||||
},
|
|
||||||
expected: map[string][]string{
|
|
||||||
"Access-Control-Allow-Origin": {"*"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "Test Simple Request with non-empty response: add ACAO",
|
|
||||||
header: NewHeader(existingAccessControlAllowOriginHandlerAdd, dynamic.Headers{
|
|
||||||
AccessControlAllowOriginList: []string{"*"},
|
|
||||||
}),
|
|
||||||
requestHeaders: map[string][]string{
|
|
||||||
"Origin": {"https://foo.bar.org"},
|
|
||||||
},
|
|
||||||
expected: map[string][]string{
|
|
||||||
"Access-Control-Allow-Origin": {"*"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "Test Simple CustomRequestHeaders Not Hijacked by CORS",
|
|
||||||
header: NewHeader(emptyHandler, dynamic.Headers{
|
|
||||||
CustomRequestHeaders: map[string]string{"foo": "bar"},
|
|
||||||
}),
|
|
||||||
requestHeaders: map[string][]string{
|
|
||||||
"Access-Control-Request-Headers": {"origin"},
|
|
||||||
"Access-Control-Request-Method": {"GET", "OPTIONS"},
|
|
||||||
"Origin": {"https://foo.bar.org"},
|
|
||||||
},
|
|
||||||
expected: map[string][]string{},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range testCases {
|
|
||||||
t.Run(test.desc, func(t *testing.T) {
|
|
||||||
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
|
|
||||||
req.Header = test.requestHeaders
|
|
||||||
rw := httptest.NewRecorder()
|
|
||||||
test.header.ServeHTTP(rw, req)
|
|
||||||
res := rw.Result()
|
|
||||||
res.Request = req
|
|
||||||
err := test.header.PostRequestModifyResponseHeaders(res)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, test.expected, rw.Result().Header)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCustomResponseHeaders(t *testing.T) {
|
|
||||||
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
desc string
|
|
||||||
header *Header
|
|
||||||
expected http.Header
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
desc: "Test Simple Response",
|
|
||||||
header: NewHeader(emptyHandler, dynamic.Headers{
|
|
||||||
CustomResponseHeaders: map[string]string{
|
|
||||||
"Testing": "foo",
|
|
||||||
"Testing2": "bar",
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
expected: map[string][]string{
|
|
||||||
"Testing": {"foo"},
|
|
||||||
"Testing2": {"bar"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "Deleting Custom Header",
|
|
||||||
header: NewHeader(emptyHandler, dynamic.Headers{
|
|
||||||
CustomResponseHeaders: map[string]string{
|
|
||||||
"Testing": "foo",
|
|
||||||
"Testing2": "",
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
expected: map[string][]string{
|
|
||||||
"Testing": {"foo"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range testCases {
|
|
||||||
t.Run(test.desc, func(t *testing.T) {
|
|
||||||
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
|
|
||||||
rw := httptest.NewRecorder()
|
|
||||||
test.header.ServeHTTP(rw, req)
|
|
||||||
err := test.header.PostRequestModifyResponseHeaders(rw.Result())
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, test.expected, rw.Result().Header)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
75
pkg/middlewares/headers/responsewriter.go
Normal file
75
pkg/middlewares/headers/responsewriter.go
Normal file
|
@ -0,0 +1,75 @@
|
||||||
|
package headers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/containous/traefik/v2/pkg/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
type responseModifier struct {
|
||||||
|
r *http.Request
|
||||||
|
w http.ResponseWriter
|
||||||
|
|
||||||
|
headersSent bool // whether headers have already been sent
|
||||||
|
code int // status code, must default to 200
|
||||||
|
|
||||||
|
modifier func(*http.Response) error // can be nil
|
||||||
|
modified bool // whether modifier has already been called for the current request
|
||||||
|
modifierErr error // returned by modifier call
|
||||||
|
}
|
||||||
|
|
||||||
|
// modifier can be nil.
|
||||||
|
func newResponseModifier(w http.ResponseWriter, r *http.Request, modifier func(*http.Response) error) *responseModifier {
|
||||||
|
return &responseModifier{
|
||||||
|
r: r,
|
||||||
|
w: w,
|
||||||
|
modifier: modifier,
|
||||||
|
code: http.StatusOK,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *responseModifier) WriteHeader(code int) {
|
||||||
|
if w.headersSent {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
w.code = code
|
||||||
|
w.headersSent = true
|
||||||
|
}()
|
||||||
|
|
||||||
|
if w.modifier == nil || w.modified {
|
||||||
|
w.w.WriteHeader(code)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := http.Response{
|
||||||
|
Header: w.w.Header(),
|
||||||
|
Request: w.r,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.modifier(&resp); err != nil {
|
||||||
|
w.modifierErr = err
|
||||||
|
// we are propagating when we are called in Write, but we're logging anyway,
|
||||||
|
// because we could be called from another place which does not take care of
|
||||||
|
// checking w.modifierErr.
|
||||||
|
log.WithoutContext().Errorf("Error when applying response modifier: %v", err)
|
||||||
|
w.w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.modified = true
|
||||||
|
w.w.WriteHeader(code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *responseModifier) Header() http.Header {
|
||||||
|
return w.w.Header()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *responseModifier) Write(b []byte) (int, error) {
|
||||||
|
w.WriteHeader(w.code)
|
||||||
|
if w.modifierErr != nil {
|
||||||
|
return 0, w.modifierErr
|
||||||
|
}
|
||||||
|
|
||||||
|
return w.w.Write(b)
|
||||||
|
}
|
54
pkg/middlewares/headers/secure.go
Normal file
54
pkg/middlewares/headers/secure.go
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
package headers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/containous/traefik/v2/pkg/config/dynamic"
|
||||||
|
"github.com/unrolled/secure"
|
||||||
|
)
|
||||||
|
|
||||||
|
type secureHeader struct {
|
||||||
|
next http.Handler
|
||||||
|
secure *secure.Secure
|
||||||
|
cfg dynamic.Headers
|
||||||
|
}
|
||||||
|
|
||||||
|
// newSecure constructs a new secure instance with supplied options.
|
||||||
|
func newSecure(next http.Handler, cfg dynamic.Headers, contextKey string) *secureHeader {
|
||||||
|
opt := secure.Options{
|
||||||
|
BrowserXssFilter: cfg.BrowserXSSFilter,
|
||||||
|
ContentTypeNosniff: cfg.ContentTypeNosniff,
|
||||||
|
ForceSTSHeader: cfg.ForceSTSHeader,
|
||||||
|
FrameDeny: cfg.FrameDeny,
|
||||||
|
IsDevelopment: cfg.IsDevelopment,
|
||||||
|
SSLRedirect: cfg.SSLRedirect,
|
||||||
|
SSLForceHost: cfg.SSLForceHost,
|
||||||
|
SSLTemporaryRedirect: cfg.SSLTemporaryRedirect,
|
||||||
|
STSIncludeSubdomains: cfg.STSIncludeSubdomains,
|
||||||
|
STSPreload: cfg.STSPreload,
|
||||||
|
ContentSecurityPolicy: cfg.ContentSecurityPolicy,
|
||||||
|
CustomBrowserXssValue: cfg.CustomBrowserXSSValue,
|
||||||
|
CustomFrameOptionsValue: cfg.CustomFrameOptionsValue,
|
||||||
|
PublicKey: cfg.PublicKey,
|
||||||
|
ReferrerPolicy: cfg.ReferrerPolicy,
|
||||||
|
SSLHost: cfg.SSLHost,
|
||||||
|
AllowedHosts: cfg.AllowedHosts,
|
||||||
|
HostsProxyHeaders: cfg.HostsProxyHeaders,
|
||||||
|
SSLProxyHeaders: cfg.SSLProxyHeaders,
|
||||||
|
STSSeconds: cfg.STSSeconds,
|
||||||
|
FeaturePolicy: cfg.FeaturePolicy,
|
||||||
|
SecureContextKey: contextKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &secureHeader{
|
||||||
|
next: next,
|
||||||
|
secure: secure.New(opt),
|
||||||
|
cfg: cfg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s secureHeader) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
s.secure.HandlerFuncWithNextForRequestOnly(rw, req, func(writer http.ResponseWriter, request *http.Request) {
|
||||||
|
s.next.ServeHTTP(newResponseModifier(writer, request, s.secure.ModifyResponseHeaders), request)
|
||||||
|
})
|
||||||
|
}
|
191
pkg/middlewares/headers/secure_test.go
Normal file
191
pkg/middlewares/headers/secure_test.go
Normal file
|
@ -0,0 +1,191 @@
|
||||||
|
package headers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/containous/traefik/v2/pkg/config/dynamic"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Middleware tests based on https://github.com/unrolled/secure
|
||||||
|
|
||||||
|
func Test_newSecure_sslForceHost(t *testing.T) {
|
||||||
|
type expected struct {
|
||||||
|
statusCode int
|
||||||
|
location string
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
desc string
|
||||||
|
host string
|
||||||
|
cfg dynamic.Headers
|
||||||
|
expected
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "http should return a 301",
|
||||||
|
host: "http://powpow.example.com",
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
SSLRedirect: true,
|
||||||
|
SSLForceHost: true,
|
||||||
|
SSLHost: "powpow.example.com",
|
||||||
|
},
|
||||||
|
expected: expected{
|
||||||
|
statusCode: http.StatusMovedPermanently,
|
||||||
|
location: "https://powpow.example.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "http sub domain should return a 301",
|
||||||
|
host: "http://www.powpow.example.com",
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
SSLRedirect: true,
|
||||||
|
SSLForceHost: true,
|
||||||
|
SSLHost: "powpow.example.com",
|
||||||
|
},
|
||||||
|
expected: expected{
|
||||||
|
statusCode: http.StatusMovedPermanently,
|
||||||
|
location: "https://powpow.example.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "https should return a 200",
|
||||||
|
host: "https://powpow.example.com",
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
SSLRedirect: true,
|
||||||
|
SSLForceHost: true,
|
||||||
|
SSLHost: "powpow.example.com",
|
||||||
|
},
|
||||||
|
expected: expected{statusCode: http.StatusOK},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "https sub domain should return a 301",
|
||||||
|
host: "https://www.powpow.example.com",
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
SSLRedirect: true,
|
||||||
|
SSLForceHost: true,
|
||||||
|
SSLHost: "powpow.example.com",
|
||||||
|
},
|
||||||
|
expected: expected{
|
||||||
|
statusCode: http.StatusMovedPermanently,
|
||||||
|
location: "https://powpow.example.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "http without force host and sub domain should return a 301",
|
||||||
|
host: "http://www.powpow.example.com",
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
SSLRedirect: true,
|
||||||
|
SSLForceHost: false,
|
||||||
|
SSLHost: "powpow.example.com",
|
||||||
|
},
|
||||||
|
expected: expected{
|
||||||
|
statusCode: http.StatusMovedPermanently,
|
||||||
|
location: "https://powpow.example.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "https without force host and sub domain should return a 301",
|
||||||
|
host: "https://www.powpow.example.com",
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
SSLRedirect: true,
|
||||||
|
SSLForceHost: false,
|
||||||
|
SSLHost: "powpow.example.com",
|
||||||
|
},
|
||||||
|
expected: expected{statusCode: http.StatusOK},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
_, _ = rw.Write([]byte("OK"))
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, test := range testCases {
|
||||||
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
|
mid := newSecure(next, test.cfg, "mymiddleware")
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, test.host, nil)
|
||||||
|
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
mid.ServeHTTP(rw, req)
|
||||||
|
|
||||||
|
assert.Equal(t, test.expected.statusCode, rw.Result().StatusCode)
|
||||||
|
assert.Equal(t, test.expected.location, rw.Header().Get("Location"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_newSecure_modifyResponse(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
desc string
|
||||||
|
cfg dynamic.Headers
|
||||||
|
expected http.Header
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "FeaturePolicy",
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
FeaturePolicy: "vibrate 'none';",
|
||||||
|
},
|
||||||
|
expected: http.Header{"Feature-Policy": []string{"vibrate 'none';"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "STSSeconds",
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
STSSeconds: 1,
|
||||||
|
ForceSTSHeader: true,
|
||||||
|
},
|
||||||
|
expected: http.Header{"Strict-Transport-Security": []string{"max-age=1"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "STSSeconds and STSPreload",
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
STSSeconds: 1,
|
||||||
|
ForceSTSHeader: true,
|
||||||
|
STSPreload: true,
|
||||||
|
},
|
||||||
|
expected: http.Header{"Strict-Transport-Security": []string{"max-age=1; preload"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "CustomFrameOptionsValue",
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
CustomFrameOptionsValue: "foo",
|
||||||
|
},
|
||||||
|
expected: http.Header{"X-Frame-Options": []string{"foo"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "FrameDeny",
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
FrameDeny: true,
|
||||||
|
},
|
||||||
|
expected: http.Header{"X-Frame-Options": []string{"DENY"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "ContentTypeNosniff",
|
||||||
|
cfg: dynamic.Headers{
|
||||||
|
ContentTypeNosniff: true,
|
||||||
|
},
|
||||||
|
expected: http.Header{"X-Content-Type-Options": []string{"nosniff"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })
|
||||||
|
|
||||||
|
for _, test := range testCases {
|
||||||
|
test := test
|
||||||
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
secure := newSecure(emptyHandler, test.cfg, "mymiddleware")
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/foo", nil)
|
||||||
|
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
secure.ServeHTTP(rw, req)
|
||||||
|
|
||||||
|
assert.Equal(t, test.expected, rw.Result().Header)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,53 +0,0 @@
|
||||||
package responsemodifiers
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/containous/traefik/v2/pkg/config/dynamic"
|
|
||||||
"github.com/containous/traefik/v2/pkg/middlewares/headers"
|
|
||||||
"github.com/unrolled/secure"
|
|
||||||
)
|
|
||||||
|
|
||||||
func buildHeaders(hdrs *dynamic.Headers) func(*http.Response) error {
|
|
||||||
opt := secure.Options{
|
|
||||||
BrowserXssFilter: hdrs.BrowserXSSFilter,
|
|
||||||
ContentTypeNosniff: hdrs.ContentTypeNosniff,
|
|
||||||
ForceSTSHeader: hdrs.ForceSTSHeader,
|
|
||||||
FrameDeny: hdrs.FrameDeny,
|
|
||||||
IsDevelopment: hdrs.IsDevelopment,
|
|
||||||
SSLRedirect: hdrs.SSLRedirect,
|
|
||||||
SSLForceHost: hdrs.SSLForceHost,
|
|
||||||
SSLTemporaryRedirect: hdrs.SSLTemporaryRedirect,
|
|
||||||
STSIncludeSubdomains: hdrs.STSIncludeSubdomains,
|
|
||||||
STSPreload: hdrs.STSPreload,
|
|
||||||
ContentSecurityPolicy: hdrs.ContentSecurityPolicy,
|
|
||||||
CustomBrowserXssValue: hdrs.CustomBrowserXSSValue,
|
|
||||||
CustomFrameOptionsValue: hdrs.CustomFrameOptionsValue,
|
|
||||||
PublicKey: hdrs.PublicKey,
|
|
||||||
ReferrerPolicy: hdrs.ReferrerPolicy,
|
|
||||||
SSLHost: hdrs.SSLHost,
|
|
||||||
AllowedHosts: hdrs.AllowedHosts,
|
|
||||||
HostsProxyHeaders: hdrs.HostsProxyHeaders,
|
|
||||||
SSLProxyHeaders: hdrs.SSLProxyHeaders,
|
|
||||||
STSSeconds: hdrs.STSSeconds,
|
|
||||||
FeaturePolicy: hdrs.FeaturePolicy,
|
|
||||||
}
|
|
||||||
|
|
||||||
return func(resp *http.Response) error {
|
|
||||||
if hdrs.HasCustomHeadersDefined() || hdrs.HasCorsHeadersDefined() {
|
|
||||||
err := headers.NewHeader(nil, *hdrs).PostRequestModifyResponseHeaders(resp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if hdrs.HasSecureHeadersDefined() {
|
|
||||||
err := secure.New(opt).ModifyResponseHeaders(resp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,13 +0,0 @@
|
||||||
package responsemodifiers
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/containous/traefik/v2/pkg/log"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
// getLogger creates a logger configured with the middleware fields.
|
|
||||||
func getLogger(ctx context.Context, middleware, middlewareType string) logrus.FieldLogger {
|
|
||||||
return log.FromContext(ctx).WithField(log.MiddlewareName, middleware).WithField(log.MiddlewareType, middlewareType)
|
|
||||||
}
|
|
|
@ -1,68 +0,0 @@
|
||||||
package responsemodifiers
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/containous/traefik/v2/pkg/config/runtime"
|
|
||||||
"github.com/containous/traefik/v2/pkg/server/provider"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewBuilder creates a builder.
|
|
||||||
func NewBuilder(configs map[string]*runtime.MiddlewareInfo) *Builder {
|
|
||||||
return &Builder{configs: configs}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Builder holds builder configuration.
|
|
||||||
type Builder struct {
|
|
||||||
configs map[string]*runtime.MiddlewareInfo
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build Builds the response modifier.
|
|
||||||
// It returns nil if there is no modifier to apply.
|
|
||||||
func (f *Builder) Build(ctx context.Context, names []string) func(*http.Response) error {
|
|
||||||
var modifiers []func(*http.Response) error
|
|
||||||
|
|
||||||
for _, middleName := range names {
|
|
||||||
conf, ok := f.configs[middleName]
|
|
||||||
if !ok {
|
|
||||||
getLogger(ctx, middleName, "undefined").Debug("Middleware name not found in config (ResponseModifier)")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if conf == nil || conf.Middleware == nil {
|
|
||||||
getLogger(ctx, middleName, "undefined").Error("Invalid Middleware configuration (ResponseModifier)")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if conf.Headers != nil {
|
|
||||||
getLogger(ctx, middleName, "Headers").Debug("Creating Middleware (ResponseModifier)")
|
|
||||||
|
|
||||||
modifiers = append(modifiers, buildHeaders(conf.Headers))
|
|
||||||
} else if conf.Chain != nil {
|
|
||||||
chainCtx := provider.AddInContext(ctx, middleName)
|
|
||||||
getLogger(chainCtx, middleName, "Chain").Debug("Creating Middleware (ResponseModifier)")
|
|
||||||
var qualifiedNames []string
|
|
||||||
for _, name := range conf.Chain.Middlewares {
|
|
||||||
qualifiedNames = append(qualifiedNames, provider.GetQualifiedName(chainCtx, name))
|
|
||||||
}
|
|
||||||
|
|
||||||
if rm := f.Build(ctx, qualifiedNames); rm != nil {
|
|
||||||
modifiers = append(modifiers, rm)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(modifiers) > 0 {
|
|
||||||
return func(resp *http.Response) error {
|
|
||||||
for i := len(modifiers); i > 0; i-- {
|
|
||||||
err := modifiers[i-1](resp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,214 +0,0 @@
|
||||||
package responsemodifiers
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/containous/traefik/v2/pkg/config/dynamic"
|
|
||||||
"github.com/containous/traefik/v2/pkg/config/runtime"
|
|
||||||
"github.com/containous/traefik/v2/pkg/middlewares/headers"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func stubResponse(_ map[string]*dynamic.Middleware) *http.Response {
|
|
||||||
return &http.Response{Header: make(http.Header)}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuilderBuild(t *testing.T) {
|
|
||||||
testCases := []struct {
|
|
||||||
desc string
|
|
||||||
middlewares []string
|
|
||||||
// buildResponse is needed because secure use a private context key
|
|
||||||
buildResponse func(map[string]*dynamic.Middleware) *http.Response
|
|
||||||
conf map[string]*dynamic.Middleware
|
|
||||||
assertResponse func(*testing.T, *http.Response)
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
desc: "no configuration",
|
|
||||||
middlewares: []string{"foo", "bar"},
|
|
||||||
buildResponse: stubResponse,
|
|
||||||
conf: map[string]*dynamic.Middleware{},
|
|
||||||
assertResponse: func(t *testing.T, resp *http.Response) {},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "one modifier",
|
|
||||||
middlewares: []string{"foo", "bar"},
|
|
||||||
buildResponse: stubResponse,
|
|
||||||
conf: map[string]*dynamic.Middleware{
|
|
||||||
"foo": {
|
|
||||||
Headers: &dynamic.Headers{
|
|
||||||
CustomResponseHeaders: map[string]string{"X-Foo": "foo"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
assertResponse: func(t *testing.T, resp *http.Response) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
assert.Equal(t, "foo", resp.Header.Get("X-Foo"))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "secure: one modifier",
|
|
||||||
middlewares: []string{"foo", "bar"},
|
|
||||||
buildResponse: func(middlewares map[string]*dynamic.Middleware) *http.Response {
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
var request *http.Request
|
|
||||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
|
||||||
request = req
|
|
||||||
})
|
|
||||||
|
|
||||||
headerM := *middlewares["foo"].Headers
|
|
||||||
handler, err := headers.New(ctx, next, headerM, "secure")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
handler.ServeHTTP(httptest.NewRecorder(),
|
|
||||||
httptest.NewRequest(http.MethodGet, "http://foo.com", nil))
|
|
||||||
|
|
||||||
return &http.Response{Header: make(http.Header), Request: request}
|
|
||||||
},
|
|
||||||
conf: map[string]*dynamic.Middleware{
|
|
||||||
"foo": {
|
|
||||||
Headers: &dynamic.Headers{
|
|
||||||
ReferrerPolicy: "no-referrer",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"bar": {
|
|
||||||
Headers: &dynamic.Headers{
|
|
||||||
CustomResponseHeaders: map[string]string{"X-Bar": "bar"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
assertResponse: func(t *testing.T, resp *http.Response) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
assert.Equal(t, "no-referrer", resp.Header.Get("Referrer-Policy"))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "two modifiers",
|
|
||||||
middlewares: []string{"foo", "bar"},
|
|
||||||
buildResponse: stubResponse,
|
|
||||||
conf: map[string]*dynamic.Middleware{
|
|
||||||
"foo": {
|
|
||||||
Headers: &dynamic.Headers{
|
|
||||||
CustomResponseHeaders: map[string]string{"X-Foo": "foo"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"bar": {
|
|
||||||
Headers: &dynamic.Headers{
|
|
||||||
CustomResponseHeaders: map[string]string{"X-Bar": "bar"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
assertResponse: func(t *testing.T, resp *http.Response) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
assert.Equal(t, "foo", resp.Header.Get("X-Foo"))
|
|
||||||
assert.Equal(t, "bar", resp.Header.Get("X-Bar"))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "modifier order",
|
|
||||||
middlewares: []string{"foo", "bar"},
|
|
||||||
buildResponse: stubResponse,
|
|
||||||
conf: map[string]*dynamic.Middleware{
|
|
||||||
"foo": {
|
|
||||||
Headers: &dynamic.Headers{
|
|
||||||
CustomResponseHeaders: map[string]string{"X-Foo": "foo"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"bar": {
|
|
||||||
Headers: &dynamic.Headers{
|
|
||||||
CustomResponseHeaders: map[string]string{"X-Foo": "bar"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
assertResponse: func(t *testing.T, resp *http.Response) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
assert.Equal(t, "foo", resp.Header.Get("X-Foo"))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "chain",
|
|
||||||
middlewares: []string{"chain"},
|
|
||||||
buildResponse: stubResponse,
|
|
||||||
conf: map[string]*dynamic.Middleware{
|
|
||||||
"foo": {
|
|
||||||
Headers: &dynamic.Headers{
|
|
||||||
CustomResponseHeaders: map[string]string{"X-Foo": "foo"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"bar": {
|
|
||||||
Headers: &dynamic.Headers{
|
|
||||||
CustomResponseHeaders: map[string]string{"X-Foo": "bar"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"chain": {
|
|
||||||
Chain: &dynamic.Chain{
|
|
||||||
Middlewares: []string{"foo", "bar"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
assertResponse: func(t *testing.T, resp *http.Response) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
assert.Equal(t, "foo", resp.Header.Get("X-Foo"))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "nil middleware",
|
|
||||||
middlewares: []string{"foo"},
|
|
||||||
buildResponse: stubResponse,
|
|
||||||
conf: map[string]*dynamic.Middleware{
|
|
||||||
"foo": nil,
|
|
||||||
},
|
|
||||||
assertResponse: func(t *testing.T, resp *http.Response) {},
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
desc: "chain without headers",
|
|
||||||
middlewares: []string{"chain"},
|
|
||||||
buildResponse: stubResponse,
|
|
||||||
conf: map[string]*dynamic.Middleware{
|
|
||||||
"foo": {IPWhiteList: &dynamic.IPWhiteList{}},
|
|
||||||
"chain": {
|
|
||||||
Chain: &dynamic.Chain{
|
|
||||||
Middlewares: []string{"foo"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
assertResponse: func(t *testing.T, resp *http.Response) {},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range testCases {
|
|
||||||
test := test
|
|
||||||
t.Run(test.desc, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
rtConf := runtime.NewConfig(dynamic.Configuration{
|
|
||||||
HTTP: &dynamic.HTTPConfiguration{
|
|
||||||
Middlewares: test.conf,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
builder := NewBuilder(rtConf.Middlewares)
|
|
||||||
|
|
||||||
rm := builder.Build(context.Background(), test.middlewares)
|
|
||||||
if rm == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := test.buildResponse(test.conf)
|
|
||||||
|
|
||||||
err := rm(resp)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
test.assertResponse(t, resp)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -44,7 +44,7 @@ type Builder struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type serviceBuilder interface {
|
type serviceBuilder interface {
|
||||||
BuildHTTP(ctx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error)
|
BuildHTTP(ctx context.Context, serviceName string) (http.Handler, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewBuilder creates a new Builder.
|
// NewBuilder creates a new Builder.
|
||||||
|
|
|
@ -24,12 +24,8 @@ type middlewareBuilder interface {
|
||||||
BuildChain(ctx context.Context, names []string) *alice.Chain
|
BuildChain(ctx context.Context, names []string) *alice.Chain
|
||||||
}
|
}
|
||||||
|
|
||||||
type responseModifierBuilder interface {
|
|
||||||
Build(ctx context.Context, names []string) func(*http.Response) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type serviceManager interface {
|
type serviceManager interface {
|
||||||
BuildHTTP(rootCtx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error)
|
BuildHTTP(rootCtx context.Context, serviceName string) (http.Handler, error)
|
||||||
LaunchHealthCheck()
|
LaunchHealthCheck()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -39,22 +35,15 @@ type Manager struct {
|
||||||
serviceManager serviceManager
|
serviceManager serviceManager
|
||||||
middlewaresBuilder middlewareBuilder
|
middlewaresBuilder middlewareBuilder
|
||||||
chainBuilder *middleware.ChainBuilder
|
chainBuilder *middleware.ChainBuilder
|
||||||
modifierBuilder responseModifierBuilder
|
|
||||||
conf *runtime.Configuration
|
conf *runtime.Configuration
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewManager Creates a new Manager.
|
// NewManager Creates a new Manager.
|
||||||
func NewManager(conf *runtime.Configuration,
|
func NewManager(conf *runtime.Configuration, serviceManager serviceManager, middlewaresBuilder middlewareBuilder, chainBuilder *middleware.ChainBuilder) *Manager {
|
||||||
serviceManager serviceManager,
|
|
||||||
middlewaresBuilder middlewareBuilder,
|
|
||||||
modifierBuilder responseModifierBuilder,
|
|
||||||
chainBuilder *middleware.ChainBuilder,
|
|
||||||
) *Manager {
|
|
||||||
return &Manager{
|
return &Manager{
|
||||||
routerHandlers: make(map[string]http.Handler),
|
routerHandlers: make(map[string]http.Handler),
|
||||||
serviceManager: serviceManager,
|
serviceManager: serviceManager,
|
||||||
middlewaresBuilder: middlewaresBuilder,
|
middlewaresBuilder: middlewaresBuilder,
|
||||||
modifierBuilder: modifierBuilder,
|
|
||||||
chainBuilder: chainBuilder,
|
chainBuilder: chainBuilder,
|
||||||
conf: conf,
|
conf: conf,
|
||||||
}
|
}
|
||||||
|
@ -176,13 +165,12 @@ func (m *Manager) buildHTTPHandler(ctx context.Context, router *runtime.RouterIn
|
||||||
qualifiedNames = append(qualifiedNames, provider.GetQualifiedName(ctx, name))
|
qualifiedNames = append(qualifiedNames, provider.GetQualifiedName(ctx, name))
|
||||||
}
|
}
|
||||||
router.Middlewares = qualifiedNames
|
router.Middlewares = qualifiedNames
|
||||||
rm := m.modifierBuilder.Build(ctx, qualifiedNames)
|
|
||||||
|
|
||||||
if router.Service == "" {
|
if router.Service == "" {
|
||||||
return nil, errors.New("the service is missing on the router")
|
return nil, errors.New("the service is missing on the router")
|
||||||
}
|
}
|
||||||
|
|
||||||
sHandler, err := m.serviceManager.BuildHTTP(ctx, router.Service, rm)
|
sHandler, err := m.serviceManager.BuildHTTP(ctx, router.Service)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,7 +13,6 @@ import (
|
||||||
"github.com/containous/traefik/v2/pkg/config/static"
|
"github.com/containous/traefik/v2/pkg/config/static"
|
||||||
"github.com/containous/traefik/v2/pkg/middlewares/accesslog"
|
"github.com/containous/traefik/v2/pkg/middlewares/accesslog"
|
||||||
"github.com/containous/traefik/v2/pkg/middlewares/requestdecorator"
|
"github.com/containous/traefik/v2/pkg/middlewares/requestdecorator"
|
||||||
"github.com/containous/traefik/v2/pkg/responsemodifiers"
|
|
||||||
"github.com/containous/traefik/v2/pkg/server/middleware"
|
"github.com/containous/traefik/v2/pkg/server/middleware"
|
||||||
"github.com/containous/traefik/v2/pkg/server/service"
|
"github.com/containous/traefik/v2/pkg/server/service"
|
||||||
"github.com/containous/traefik/v2/pkg/testhelpers"
|
"github.com/containous/traefik/v2/pkg/testhelpers"
|
||||||
|
@ -290,10 +289,9 @@ func TestRouterManager_Get(t *testing.T) {
|
||||||
|
|
||||||
serviceManager := service.NewManager(rtConf.Services, http.DefaultTransport, nil, nil)
|
serviceManager := service.NewManager(rtConf.Services, http.DefaultTransport, nil, nil)
|
||||||
middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager)
|
middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager)
|
||||||
responseModifierFactory := responsemodifiers.NewBuilder(rtConf.Middlewares)
|
|
||||||
chainBuilder := middleware.NewChainBuilder(static.Configuration{}, nil, nil)
|
chainBuilder := middleware.NewChainBuilder(static.Configuration{}, nil, nil)
|
||||||
|
|
||||||
routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, responseModifierFactory, chainBuilder)
|
routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, chainBuilder)
|
||||||
|
|
||||||
handlers := routerManager.BuildHandlers(context.Background(), test.entryPoints, false)
|
handlers := routerManager.BuildHandlers(context.Background(), test.entryPoints, false)
|
||||||
|
|
||||||
|
@ -395,10 +393,9 @@ func TestAccessLog(t *testing.T) {
|
||||||
|
|
||||||
serviceManager := service.NewManager(rtConf.Services, http.DefaultTransport, nil, nil)
|
serviceManager := service.NewManager(rtConf.Services, http.DefaultTransport, nil, nil)
|
||||||
middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager)
|
middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager)
|
||||||
responseModifierFactory := responsemodifiers.NewBuilder(rtConf.Middlewares)
|
|
||||||
chainBuilder := middleware.NewChainBuilder(static.Configuration{}, nil, nil)
|
chainBuilder := middleware.NewChainBuilder(static.Configuration{}, nil, nil)
|
||||||
|
|
||||||
routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, responseModifierFactory, chainBuilder)
|
routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, chainBuilder)
|
||||||
|
|
||||||
handlers := routerManager.BuildHandlers(context.Background(), test.entryPoints, false)
|
handlers := routerManager.BuildHandlers(context.Background(), test.entryPoints, false)
|
||||||
|
|
||||||
|
@ -683,10 +680,9 @@ func TestRuntimeConfiguration(t *testing.T) {
|
||||||
|
|
||||||
serviceManager := service.NewManager(rtConf.Services, http.DefaultTransport, nil, nil)
|
serviceManager := service.NewManager(rtConf.Services, http.DefaultTransport, nil, nil)
|
||||||
middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager)
|
middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager)
|
||||||
responseModifierFactory := responsemodifiers.NewBuilder(map[string]*runtime.MiddlewareInfo{})
|
|
||||||
chainBuilder := middleware.NewChainBuilder(static.Configuration{}, nil, nil)
|
chainBuilder := middleware.NewChainBuilder(static.Configuration{}, nil, nil)
|
||||||
|
|
||||||
routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, responseModifierFactory, chainBuilder)
|
routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, chainBuilder)
|
||||||
|
|
||||||
_ = routerManager.BuildHandlers(context.Background(), entryPoints, false)
|
_ = routerManager.BuildHandlers(context.Background(), entryPoints, false)
|
||||||
|
|
||||||
|
@ -765,10 +761,9 @@ func TestProviderOnMiddlewares(t *testing.T) {
|
||||||
|
|
||||||
serviceManager := service.NewManager(rtConf.Services, http.DefaultTransport, nil, nil)
|
serviceManager := service.NewManager(rtConf.Services, http.DefaultTransport, nil, nil)
|
||||||
middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager)
|
middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager)
|
||||||
responseModifierFactory := responsemodifiers.NewBuilder(map[string]*runtime.MiddlewareInfo{})
|
|
||||||
chainBuilder := middleware.NewChainBuilder(staticCfg, nil, nil)
|
chainBuilder := middleware.NewChainBuilder(staticCfg, nil, nil)
|
||||||
|
|
||||||
routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, responseModifierFactory, chainBuilder)
|
routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, chainBuilder)
|
||||||
|
|
||||||
_ = routerManager.BuildHandlers(context.Background(), entryPoints, false)
|
_ = routerManager.BuildHandlers(context.Background(), entryPoints, false)
|
||||||
|
|
||||||
|
@ -826,10 +821,9 @@ func BenchmarkRouterServe(b *testing.B) {
|
||||||
|
|
||||||
serviceManager := service.NewManager(rtConf.Services, &staticTransport{res}, nil, nil)
|
serviceManager := service.NewManager(rtConf.Services, &staticTransport{res}, nil, nil)
|
||||||
middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager)
|
middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager)
|
||||||
responseModifierFactory := responsemodifiers.NewBuilder(rtConf.Middlewares)
|
|
||||||
chainBuilder := middleware.NewChainBuilder(static.Configuration{}, nil, nil)
|
chainBuilder := middleware.NewChainBuilder(static.Configuration{}, nil, nil)
|
||||||
|
|
||||||
routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, responseModifierFactory, chainBuilder)
|
routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, chainBuilder)
|
||||||
|
|
||||||
handlers := routerManager.BuildHandlers(context.Background(), entryPoints, false)
|
handlers := routerManager.BuildHandlers(context.Background(), entryPoints, false)
|
||||||
|
|
||||||
|
@ -871,7 +865,7 @@ func BenchmarkService(b *testing.B) {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
|
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
|
||||||
|
|
||||||
handler, _ := serviceManager.BuildHTTP(context.Background(), "foo-service", nil)
|
handler, _ := serviceManager.BuildHTTP(context.Background(), "foo-service")
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
|
|
|
@ -7,7 +7,6 @@ import (
|
||||||
"github.com/containous/traefik/v2/pkg/config/runtime"
|
"github.com/containous/traefik/v2/pkg/config/runtime"
|
||||||
"github.com/containous/traefik/v2/pkg/config/static"
|
"github.com/containous/traefik/v2/pkg/config/static"
|
||||||
"github.com/containous/traefik/v2/pkg/log"
|
"github.com/containous/traefik/v2/pkg/log"
|
||||||
"github.com/containous/traefik/v2/pkg/responsemodifiers"
|
|
||||||
"github.com/containous/traefik/v2/pkg/server/middleware"
|
"github.com/containous/traefik/v2/pkg/server/middleware"
|
||||||
"github.com/containous/traefik/v2/pkg/server/router"
|
"github.com/containous/traefik/v2/pkg/server/router"
|
||||||
routertcp "github.com/containous/traefik/v2/pkg/server/router/tcp"
|
routertcp "github.com/containous/traefik/v2/pkg/server/router/tcp"
|
||||||
|
@ -67,9 +66,8 @@ func (f *RouterFactory) CreateRouters(conf dynamic.Configuration) (map[string]*t
|
||||||
serviceManager := f.managerFactory.Build(rtConf)
|
serviceManager := f.managerFactory.Build(rtConf)
|
||||||
|
|
||||||
middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager)
|
middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager)
|
||||||
responseModifierFactory := responsemodifiers.NewBuilder(rtConf.Middlewares)
|
|
||||||
|
|
||||||
routerManager := router.NewManager(rtConf, serviceManager, middlewaresBuilder, responseModifierFactory, f.chainBuilder)
|
routerManager := router.NewManager(rtConf, serviceManager, middlewaresBuilder, f.chainBuilder)
|
||||||
|
|
||||||
handlersNonTLS := routerManager.BuildHandlers(ctx, f.entryPointsTCP, false)
|
handlersNonTLS := routerManager.BuildHandlers(ctx, f.entryPointsTCP, false)
|
||||||
handlersTLS := routerManager.BuildHandlers(ctx, f.entryPointsTCP, true)
|
handlersTLS := routerManager.BuildHandlers(ctx, f.entryPointsTCP, true)
|
||||||
|
|
|
@ -8,11 +8,10 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/containous/traefik/v2/pkg/config/runtime"
|
"github.com/containous/traefik/v2/pkg/config/runtime"
|
||||||
"github.com/containous/traefik/v2/pkg/log"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type serviceManager interface {
|
type serviceManager interface {
|
||||||
BuildHTTP(rootCtx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error)
|
BuildHTTP(rootCtx context.Context, serviceName string) (http.Handler, error)
|
||||||
LaunchHealthCheck()
|
LaunchHealthCheck()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -43,87 +42,18 @@ func NewInternalHandlers(api func(configuration *runtime.Configuration) http.Han
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type responseModifier struct {
|
|
||||||
r *http.Request
|
|
||||||
w http.ResponseWriter
|
|
||||||
|
|
||||||
headersSent bool // whether headers have already been sent
|
|
||||||
code int // status code, must default to 200
|
|
||||||
|
|
||||||
modifier func(*http.Response) error // can be nil
|
|
||||||
modified bool // whether modifier has already been called for the current request
|
|
||||||
modifierErr error // returned by modifier call
|
|
||||||
}
|
|
||||||
|
|
||||||
// modifier can be nil.
|
|
||||||
func newResponseModifier(w http.ResponseWriter, r *http.Request, modifier func(*http.Response) error) *responseModifier {
|
|
||||||
return &responseModifier{
|
|
||||||
r: r,
|
|
||||||
w: w,
|
|
||||||
modifier: modifier,
|
|
||||||
code: http.StatusOK,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *responseModifier) WriteHeader(code int) {
|
|
||||||
if w.headersSent {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
w.code = code
|
|
||||||
w.headersSent = true
|
|
||||||
}()
|
|
||||||
|
|
||||||
if w.modifier == nil || w.modified {
|
|
||||||
w.w.WriteHeader(code)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := http.Response{
|
|
||||||
Header: w.w.Header(),
|
|
||||||
Request: w.r,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := w.modifier(&resp); err != nil {
|
|
||||||
w.modifierErr = err
|
|
||||||
// we are propagating when we are called in Write, but we're logging anyway,
|
|
||||||
// because we could be called from another place which does not take care of
|
|
||||||
// checking w.modifierErr.
|
|
||||||
log.Errorf("Error when applying response modifier: %v", err)
|
|
||||||
w.w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.modified = true
|
|
||||||
w.w.WriteHeader(code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *responseModifier) Header() http.Header {
|
|
||||||
return w.w.Header()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *responseModifier) Write(b []byte) (int, error) {
|
|
||||||
w.WriteHeader(w.code)
|
|
||||||
if w.modifierErr != nil {
|
|
||||||
return 0, w.modifierErr
|
|
||||||
}
|
|
||||||
|
|
||||||
return w.w.Write(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// BuildHTTP builds an HTTP handler.
|
// BuildHTTP builds an HTTP handler.
|
||||||
func (m *InternalHandlers) BuildHTTP(rootCtx context.Context, serviceName string, respModifier func(*http.Response) error) (http.Handler, error) {
|
func (m *InternalHandlers) BuildHTTP(rootCtx context.Context, serviceName string) (http.Handler, error) {
|
||||||
if !strings.HasSuffix(serviceName, "@internal") {
|
if !strings.HasSuffix(serviceName, "@internal") {
|
||||||
return m.serviceManager.BuildHTTP(rootCtx, serviceName, respModifier)
|
return m.serviceManager.BuildHTTP(rootCtx, serviceName)
|
||||||
}
|
}
|
||||||
|
|
||||||
internalHandler, err := m.get(serviceName)
|
internalHandler, err := m.get(serviceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
internalHandler.ServeHTTP(newResponseModifier(w, r, respModifier), r)
|
return internalHandler, nil
|
||||||
}), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *InternalHandlers) get(serviceName string) (http.Handler, error) {
|
func (m *InternalHandlers) get(serviceName string) (http.Handler, error) {
|
||||||
|
|
|
@ -21,7 +21,7 @@ const StatusClientClosedRequest = 499
|
||||||
// StatusClientClosedRequestText non-standard HTTP status for client disconnection.
|
// StatusClientClosedRequestText non-standard HTTP status for client disconnection.
|
||||||
const StatusClientClosedRequestText = "Client Closed Request"
|
const StatusClientClosedRequestText = "Client Closed Request"
|
||||||
|
|
||||||
func buildProxy(passHostHeader *bool, responseForwarding *dynamic.ResponseForwarding, defaultRoundTripper http.RoundTripper, bufferPool httputil.BufferPool, responseModifier func(*http.Response) error) (http.Handler, error) {
|
func buildProxy(passHostHeader *bool, responseForwarding *dynamic.ResponseForwarding, defaultRoundTripper http.RoundTripper, bufferPool httputil.BufferPool) (http.Handler, error) {
|
||||||
var flushInterval types.Duration
|
var flushInterval types.Duration
|
||||||
if responseForwarding != nil {
|
if responseForwarding != nil {
|
||||||
err := flushInterval.Set(responseForwarding.FlushInterval)
|
err := flushInterval.Set(responseForwarding.FlushInterval)
|
||||||
|
@ -76,10 +76,9 @@ func buildProxy(passHostHeader *bool, responseForwarding *dynamic.ResponseForwar
|
||||||
delete(outReq.Header, "Sec-Websocket-Protocol")
|
delete(outReq.Header, "Sec-Websocket-Protocol")
|
||||||
delete(outReq.Header, "Sec-Websocket-Version")
|
delete(outReq.Header, "Sec-Websocket-Version")
|
||||||
},
|
},
|
||||||
Transport: defaultRoundTripper,
|
Transport: defaultRoundTripper,
|
||||||
FlushInterval: time.Duration(flushInterval),
|
FlushInterval: time.Duration(flushInterval),
|
||||||
ModifyResponse: responseModifier,
|
BufferPool: bufferPool,
|
||||||
BufferPool: bufferPool,
|
|
||||||
ErrorHandler: func(w http.ResponseWriter, request *http.Request, err error) {
|
ErrorHandler: func(w http.ResponseWriter, request *http.Request, err error) {
|
||||||
statusCode := http.StatusInternalServerError
|
statusCode := http.StatusInternalServerError
|
||||||
|
|
||||||
|
|
|
@ -28,7 +28,7 @@ func BenchmarkProxy(b *testing.B) {
|
||||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
|
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
|
||||||
|
|
||||||
pool := newBufferPool()
|
pool := newBufferPool()
|
||||||
handler, _ := buildProxy(Bool(false), nil, &staticTransport{res}, pool, nil)
|
handler, _ := buildProxy(Bool(false), nil, &staticTransport{res}, pool)
|
||||||
|
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
|
|
|
@ -20,7 +20,7 @@ import (
|
||||||
func Bool(v bool) *bool { return &v }
|
func Bool(v bool) *bool { return &v }
|
||||||
|
|
||||||
func TestWebSocketTCPClose(t *testing.T) {
|
func TestWebSocketTCPClose(t *testing.T) {
|
||||||
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil, nil)
|
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
|
@ -59,7 +59,7 @@ func TestWebSocketTCPClose(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWebSocketPingPong(t *testing.T) {
|
func TestWebSocketPingPong(t *testing.T) {
|
||||||
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil, nil)
|
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil)
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
@ -125,7 +125,7 @@ func TestWebSocketPingPong(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWebSocketEcho(t *testing.T) {
|
func TestWebSocketEcho(t *testing.T) {
|
||||||
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil, nil)
|
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
|
@ -191,7 +191,7 @@ func TestWebSocketPassHost(t *testing.T) {
|
||||||
|
|
||||||
for _, test := range testCases {
|
for _, test := range testCases {
|
||||||
t.Run(test.desc, func(t *testing.T) {
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
f, err := buildProxy(Bool(test.passHost), nil, http.DefaultTransport, nil, nil)
|
f, err := buildProxy(Bool(test.passHost), nil, http.DefaultTransport, nil)
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
@ -250,7 +250,7 @@ func TestWebSocketPassHost(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWebSocketServerWithoutCheckOrigin(t *testing.T) {
|
func TestWebSocketServerWithoutCheckOrigin(t *testing.T) {
|
||||||
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil, nil)
|
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(r *http.Request) bool {
|
upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(r *http.Request) bool {
|
||||||
|
@ -291,7 +291,7 @@ func TestWebSocketServerWithoutCheckOrigin(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWebSocketRequestWithOrigin(t *testing.T) {
|
func TestWebSocketRequestWithOrigin(t *testing.T) {
|
||||||
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil, nil)
|
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
upgrader := gorillawebsocket.Upgrader{}
|
upgrader := gorillawebsocket.Upgrader{}
|
||||||
|
@ -337,7 +337,7 @@ func TestWebSocketRequestWithOrigin(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWebSocketRequestWithQueryParams(t *testing.T) {
|
func TestWebSocketRequestWithQueryParams(t *testing.T) {
|
||||||
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil, nil)
|
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
upgrader := gorillawebsocket.Upgrader{}
|
upgrader := gorillawebsocket.Upgrader{}
|
||||||
|
@ -377,7 +377,7 @@ func TestWebSocketRequestWithQueryParams(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) {
|
func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) {
|
||||||
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil, nil)
|
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
|
@ -409,7 +409,7 @@ func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWebSocketRequestWithEncodedChar(t *testing.T) {
|
func TestWebSocketRequestWithEncodedChar(t *testing.T) {
|
||||||
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil, nil)
|
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
upgrader := gorillawebsocket.Upgrader{}
|
upgrader := gorillawebsocket.Upgrader{}
|
||||||
|
@ -449,7 +449,7 @@ func TestWebSocketRequestWithEncodedChar(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWebSocketUpgradeFailed(t *testing.T) {
|
func TestWebSocketUpgradeFailed(t *testing.T) {
|
||||||
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil, nil)
|
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
|
@ -499,7 +499,7 @@ func TestWebSocketUpgradeFailed(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestForwardsWebsocketTraffic(t *testing.T) {
|
func TestForwardsWebsocketTraffic(t *testing.T) {
|
||||||
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil, nil)
|
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
|
@ -555,7 +555,7 @@ func TestWebSocketTransferTLSConfig(t *testing.T) {
|
||||||
srv := createTLSWebsocketServer()
|
srv := createTLSWebsocketServer()
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
forwarderWithoutTLSConfig, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil, nil)
|
forwarderWithoutTLSConfig, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
proxyWithoutTLSConfig := createProxyWithForwarder(t, forwarderWithoutTLSConfig, srv.URL)
|
proxyWithoutTLSConfig := createProxyWithForwarder(t, forwarderWithoutTLSConfig, srv.URL)
|
||||||
|
@ -574,7 +574,7 @@ func TestWebSocketTransferTLSConfig(t *testing.T) {
|
||||||
transport := &http.Transport{
|
transport := &http.Transport{
|
||||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
}
|
}
|
||||||
forwarderWithTLSConfig, err := buildProxy(Bool(true), nil, transport, nil, nil)
|
forwarderWithTLSConfig, err := buildProxy(Bool(true), nil, transport, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
proxyWithTLSConfig := createProxyWithForwarder(t, forwarderWithTLSConfig, srv.URL)
|
proxyWithTLSConfig := createProxyWithForwarder(t, forwarderWithTLSConfig, srv.URL)
|
||||||
|
@ -593,7 +593,7 @@ func TestWebSocketTransferTLSConfig(t *testing.T) {
|
||||||
|
|
||||||
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||||
|
|
||||||
forwarderWithTLSConfigFromDefaultTransport, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil, nil)
|
forwarderWithTLSConfigFromDefaultTransport, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
proxyWithTLSConfigFromDefaultTransport := createProxyWithForwarder(t, forwarderWithTLSConfigFromDefaultTransport, srv.URL)
|
proxyWithTLSConfigFromDefaultTransport := createProxyWithForwarder(t, forwarderWithTLSConfigFromDefaultTransport, srv.URL)
|
||||||
|
|
|
@ -62,7 +62,7 @@ type Manager struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildHTTP Creates a http.Handler for a service configuration.
|
// BuildHTTP Creates a http.Handler for a service configuration.
|
||||||
func (m *Manager) BuildHTTP(rootCtx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error) {
|
func (m *Manager) BuildHTTP(rootCtx context.Context, serviceName string) (http.Handler, error) {
|
||||||
ctx := log.With(rootCtx, log.Str(log.ServiceName, serviceName))
|
ctx := log.With(rootCtx, log.Str(log.ServiceName, serviceName))
|
||||||
|
|
||||||
serviceName = provider.GetQualifiedName(ctx, serviceName)
|
serviceName = provider.GetQualifiedName(ctx, serviceName)
|
||||||
|
@ -91,21 +91,21 @@ func (m *Manager) BuildHTTP(rootCtx context.Context, serviceName string, respons
|
||||||
switch {
|
switch {
|
||||||
case conf.LoadBalancer != nil:
|
case conf.LoadBalancer != nil:
|
||||||
var err error
|
var err error
|
||||||
lb, err = m.getLoadBalancerServiceHandler(ctx, serviceName, conf.LoadBalancer, responseModifier)
|
lb, err = m.getLoadBalancerServiceHandler(ctx, serviceName, conf.LoadBalancer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conf.AddError(err, true)
|
conf.AddError(err, true)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
case conf.Weighted != nil:
|
case conf.Weighted != nil:
|
||||||
var err error
|
var err error
|
||||||
lb, err = m.getWRRServiceHandler(ctx, serviceName, conf.Weighted, responseModifier)
|
lb, err = m.getWRRServiceHandler(ctx, serviceName, conf.Weighted)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conf.AddError(err, true)
|
conf.AddError(err, true)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
case conf.Mirroring != nil:
|
case conf.Mirroring != nil:
|
||||||
var err error
|
var err error
|
||||||
lb, err = m.getMirrorServiceHandler(ctx, conf.Mirroring, responseModifier)
|
lb, err = m.getMirrorServiceHandler(ctx, conf.Mirroring)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conf.AddError(err, true)
|
conf.AddError(err, true)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -119,8 +119,8 @@ func (m *Manager) BuildHTTP(rootCtx context.Context, serviceName string, respons
|
||||||
return lb, nil
|
return lb, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) getMirrorServiceHandler(ctx context.Context, config *dynamic.Mirroring, responseModifier func(*http.Response) error) (http.Handler, error) {
|
func (m *Manager) getMirrorServiceHandler(ctx context.Context, config *dynamic.Mirroring) (http.Handler, error) {
|
||||||
serviceHandler, err := m.BuildHTTP(ctx, config.Service, responseModifier)
|
serviceHandler, err := m.BuildHTTP(ctx, config.Service)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -131,7 +131,7 @@ func (m *Manager) getMirrorServiceHandler(ctx context.Context, config *dynamic.M
|
||||||
}
|
}
|
||||||
handler := mirror.New(serviceHandler, m.routinePool, maxBodySize)
|
handler := mirror.New(serviceHandler, m.routinePool, maxBodySize)
|
||||||
for _, mirrorConfig := range config.Mirrors {
|
for _, mirrorConfig := range config.Mirrors {
|
||||||
mirrorHandler, err := m.BuildHTTP(ctx, mirrorConfig.Name, responseModifier)
|
mirrorHandler, err := m.BuildHTTP(ctx, mirrorConfig.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -144,7 +144,7 @@ func (m *Manager) getMirrorServiceHandler(ctx context.Context, config *dynamic.M
|
||||||
return handler, nil
|
return handler, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) getWRRServiceHandler(ctx context.Context, serviceName string, config *dynamic.WeightedRoundRobin, responseModifier func(*http.Response) error) (http.Handler, error) {
|
func (m *Manager) getWRRServiceHandler(ctx context.Context, serviceName string, config *dynamic.WeightedRoundRobin) (http.Handler, error) {
|
||||||
// TODO Handle accesslog and metrics with multiple service name
|
// TODO Handle accesslog and metrics with multiple service name
|
||||||
if config.Sticky != nil && config.Sticky.Cookie != nil {
|
if config.Sticky != nil && config.Sticky.Cookie != nil {
|
||||||
config.Sticky.Cookie.Name = cookie.GetName(config.Sticky.Cookie.Name, serviceName)
|
config.Sticky.Cookie.Name = cookie.GetName(config.Sticky.Cookie.Name, serviceName)
|
||||||
|
@ -152,7 +152,7 @@ func (m *Manager) getWRRServiceHandler(ctx context.Context, serviceName string,
|
||||||
|
|
||||||
balancer := wrr.New(config.Sticky)
|
balancer := wrr.New(config.Sticky)
|
||||||
for _, service := range config.Services {
|
for _, service := range config.Services {
|
||||||
serviceHandler, err := m.BuildHTTP(ctx, service.Name, responseModifier)
|
serviceHandler, err := m.BuildHTTP(ctx, service.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -162,18 +162,13 @@ func (m *Manager) getWRRServiceHandler(ctx context.Context, serviceName string,
|
||||||
return balancer, nil
|
return balancer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) getLoadBalancerServiceHandler(
|
func (m *Manager) getLoadBalancerServiceHandler(ctx context.Context, serviceName string, service *dynamic.ServersLoadBalancer) (http.Handler, error) {
|
||||||
ctx context.Context,
|
|
||||||
serviceName string,
|
|
||||||
service *dynamic.ServersLoadBalancer,
|
|
||||||
responseModifier func(*http.Response) error,
|
|
||||||
) (http.Handler, error) {
|
|
||||||
if service.PassHostHeader == nil {
|
if service.PassHostHeader == nil {
|
||||||
defaultPassHostHeader := true
|
defaultPassHostHeader := true
|
||||||
service.PassHostHeader = &defaultPassHostHeader
|
service.PassHostHeader = &defaultPassHostHeader
|
||||||
}
|
}
|
||||||
|
|
||||||
fwd, err := buildProxy(service.PassHostHeader, service.ResponseForwarding, m.defaultRoundTripper, m.bufferPool, responseModifier)
|
fwd, err := buildProxy(service.PassHostHeader, service.ResponseForwarding, m.defaultRoundTripper, m.bufferPool)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -259,7 +259,7 @@ func TestGetLoadBalancerServiceHandler(t *testing.T) {
|
||||||
for _, test := range testCases {
|
for _, test := range testCases {
|
||||||
test := test
|
test := test
|
||||||
t.Run(test.desc, func(t *testing.T) {
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
handler, err := sm.getLoadBalancerServiceHandler(context.Background(), test.serviceName, test.service, test.responseModifier)
|
handler, err := sm.getLoadBalancerServiceHandler(context.Background(), test.serviceName, test.service)
|
||||||
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.NotNil(t, handler)
|
assert.NotNil(t, handler)
|
||||||
|
@ -339,7 +339,7 @@ func TestManager_Build(t *testing.T) {
|
||||||
ctx = provider.AddInContext(ctx, "foobar@"+test.providerName)
|
ctx = provider.AddInContext(ctx, "foobar@"+test.providerName)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := manager.BuildHTTP(ctx, test.serviceName, nil)
|
_, err := manager.BuildHTTP(ctx, test.serviceName)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -357,7 +357,7 @@ func TestMultipleTypeOnBuildHTTP(t *testing.T) {
|
||||||
|
|
||||||
manager := NewManager(services, http.DefaultTransport, nil, nil)
|
manager := NewManager(services, http.DefaultTransport, nil, nil)
|
||||||
|
|
||||||
_, err := manager.BuildHTTP(context.Background(), "test@file", nil)
|
_, err := manager.BuildHTTP(context.Background(), "test@file")
|
||||||
assert.Error(t, err, "cannot create service: multi-types service not supported, consider declaring two different pieces of service instead")
|
assert.Error(t, err, "cannot create service: multi-types service not supported, consider declaring two different pieces of service instead")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue