From 52790d3c379058bb41929695748a7a810b7a3ec6 Mon Sep 17 00:00:00 2001 From: Julien Salleyron Date: Tue, 1 Sep 2020 18:16:04 +0200 Subject: [PATCH] Headers response modifier is directly applied by headers middleware Co-authored-by: Ludovic Fernandez --- .../fixtures/headers/secure_multiple.toml | 33 + integration/headers_test.go | 41 ++ pkg/middlewares/customerrors/custom_errors.go | 4 +- .../customerrors/custom_errors_test.go | 2 +- pkg/middlewares/headers/header.go | 170 ++++++ pkg/middlewares/headers/header_test.go | 492 +++++++++++++++ pkg/middlewares/headers/headers.go | 205 +------ pkg/middlewares/headers/headers_test.go | 575 ++---------------- pkg/middlewares/headers/responsewriter.go | 75 +++ pkg/middlewares/headers/secure.go | 54 ++ pkg/middlewares/headers/secure_test.go | 191 ++++++ pkg/responsemodifiers/headers.go | 53 -- pkg/responsemodifiers/log.go | 13 - pkg/responsemodifiers/response_modifier.go | 68 --- .../response_modifier_test.go | 214 ------- pkg/server/middleware/middlewares.go | 2 +- pkg/server/router/router.go | 18 +- pkg/server/router/router_test.go | 18 +- pkg/server/routerfactory.go | 4 +- pkg/server/service/internalhandler.go | 80 +-- pkg/server/service/proxy.go | 9 +- pkg/server/service/proxy_test.go | 2 +- pkg/server/service/proxy_websocket_test.go | 28 +- pkg/server/service/service.go | 27 +- pkg/server/service/service_test.go | 6 +- 25 files changed, 1144 insertions(+), 1240 deletions(-) create mode 100644 integration/fixtures/headers/secure_multiple.toml create mode 100644 pkg/middlewares/headers/header.go create mode 100644 pkg/middlewares/headers/header_test.go create mode 100644 pkg/middlewares/headers/responsewriter.go create mode 100644 pkg/middlewares/headers/secure.go create mode 100644 pkg/middlewares/headers/secure_test.go delete mode 100644 pkg/responsemodifiers/headers.go delete mode 100644 pkg/responsemodifiers/log.go delete mode 100644 pkg/responsemodifiers/response_modifier.go delete mode 100644 pkg/responsemodifiers/response_modifier_test.go diff --git a/integration/fixtures/headers/secure_multiple.toml b/integration/fixtures/headers/secure_multiple.toml new file mode 100644 index 000000000..9603046e5 --- /dev/null +++ b/integration/fixtures/headers/secure_multiple.toml @@ -0,0 +1,33 @@ +[global] + checkNewVersion = false + sendAnonymousUsage = false + +[log] + level = "DEBUG" + +[entryPoints] + [entryPoints.web] + address = ":8000" + +[providers.file] + filename = "{{ .SelfFilename }}" + +## dynamic configuration ## + +[http.routers] + [http.routers.router1] + rule = "Host(`test.localhost`)" + middlewares = ["foo", "bar"] + service = "service1" + + +[http.middlewares] + [http.middlewares.foo.headers] + frameDeny = true + [http.middlewares.bar.headers] + contentTypeNosniff = true + +[http.services] + [http.services.service1.loadBalancer] + [[http.services.service1.loadBalancer.servers]] + url = "http://127.0.0.1:9000" diff --git a/integration/headers_test.go b/integration/headers_test.go index a5cc4259a..c464bf97c 100644 --- a/integration/headers_test.go +++ b/integration/headers_test.go @@ -162,3 +162,44 @@ func (s *HeadersSuite) TestSecureHeadersResponses(c *check.C) { c.Assert(err, checker.IsNil) } } + +func (s *HeadersSuite) TestMultipleSecureHeadersResponses(c *check.C) { + file := s.adaptFile(c, "fixtures/headers/secure_multiple.toml", struct{}{}) + defer os.Remove(file) + cmd, display := s.traefikCmd(withConfigFile(file)) + defer display(c) + + err := cmd.Start() + c.Assert(err, checker.IsNil) + defer cmd.Process.Kill() + + backend := startTestServer("9000", http.StatusOK, "") + defer backend.Close() + + err = try.GetRequest(backend.URL, 500*time.Millisecond, try.StatusCodeIs(http.StatusOK)) + c.Assert(err, checker.IsNil) + + testCase := []struct { + desc string + expected http.Header + reqHost string + }{ + { + desc: "Feature-Policy Set", + expected: http.Header{ + "X-Frame-Options": {"DENY"}, + "X-Content-Type-Options": {"nosniff"}, + }, + reqHost: "test.localhost", + }, + } + + for _, test := range testCase { + req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8000/", nil) + c.Assert(err, checker.IsNil) + req.Host = test.reqHost + + err = try.Request(req, 500*time.Millisecond, try.HasHeaderStruct(test.expected)) + c.Assert(err, checker.IsNil) + } +} diff --git a/pkg/middlewares/customerrors/custom_errors.go b/pkg/middlewares/customerrors/custom_errors.go index d5f687fde..3789142e5 100644 --- a/pkg/middlewares/customerrors/custom_errors.go +++ b/pkg/middlewares/customerrors/custom_errors.go @@ -33,7 +33,7 @@ const ( ) type serviceBuilder interface { - BuildHTTP(ctx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error) + BuildHTTP(ctx context.Context, serviceName string) (http.Handler, error) } // customErrors is a middleware that provides the custom error pages.. @@ -54,7 +54,7 @@ func New(ctx context.Context, next http.Handler, config dynamic.ErrorPage, servi return nil, err } - backend, err := serviceBuilder.BuildHTTP(ctx, config.Service, nil) + backend, err := serviceBuilder.BuildHTTP(ctx, config.Service) if err != nil { return nil, err } diff --git a/pkg/middlewares/customerrors/custom_errors_test.go b/pkg/middlewares/customerrors/custom_errors_test.go index 58a3a1673..0a608cc44 100644 --- a/pkg/middlewares/customerrors/custom_errors_test.go +++ b/pkg/middlewares/customerrors/custom_errors_test.go @@ -150,7 +150,7 @@ type mockServiceBuilder struct { handler http.Handler } -func (m *mockServiceBuilder) BuildHTTP(_ context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error) { +func (m *mockServiceBuilder) BuildHTTP(ctx context.Context, serviceName string) (http.Handler, error) { return m.handler, nil } diff --git a/pkg/middlewares/headers/header.go b/pkg/middlewares/headers/header.go new file mode 100644 index 000000000..862b376ed --- /dev/null +++ b/pkg/middlewares/headers/header.go @@ -0,0 +1,170 @@ +package headers + +import ( + "context" + "net/http" + "strconv" + "strings" + + "github.com/containous/traefik/v2/pkg/config/dynamic" + "github.com/containous/traefik/v2/pkg/log" +) + +// Header is a middleware that helps setup a few basic security features. +// A single headerOptions struct can be provided to configure which features should be enabled, +// and the ability to override a few of the default values. +type Header struct { + next http.Handler + hasCustomHeaders bool + hasCorsHeaders bool + headers *dynamic.Headers +} + +// NewHeader constructs a new header instance from supplied frontend header struct. +func NewHeader(next http.Handler, cfg dynamic.Headers) *Header { + hasCustomHeaders := cfg.HasCustomHeadersDefined() + hasCorsHeaders := cfg.HasCorsHeadersDefined() + + ctx := log.With(context.Background(), log.Str(log.MiddlewareType, typeName)) + handleDeprecation(ctx, &cfg) + + return &Header{ + next: next, + headers: &cfg, + hasCustomHeaders: hasCustomHeaders, + hasCorsHeaders: hasCorsHeaders, + } +} + +func (s *Header) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + // Handle Cors headers and preflight if configured. + if isPreflight := s.processCorsHeaders(rw, req); isPreflight { + return + } + + if s.hasCustomHeaders { + s.modifyCustomRequestHeaders(req) + } + + // If there is a next, call it. + if s.next != nil { + s.next.ServeHTTP(newResponseModifier(rw, req, s.PostRequestModifyResponseHeaders), req) + } +} + +// modifyCustomRequestHeaders sets or deletes custom request headers. +func (s *Header) modifyCustomRequestHeaders(req *http.Request) { + // Loop through Custom request headers + for header, value := range s.headers.CustomRequestHeaders { + switch { + case value == "": + req.Header.Del(header) + + case strings.EqualFold(header, "Host"): + req.Host = value + + default: + req.Header.Set(header, value) + } + } +} + +// PostRequestModifyResponseHeaders set or delete response headers. +// This method is called AFTER the response is generated from the backend +// and can merge/override headers from the backend response. +func (s *Header) PostRequestModifyResponseHeaders(res *http.Response) error { + // Loop through Custom response headers + for header, value := range s.headers.CustomResponseHeaders { + if value == "" { + res.Header.Del(header) + } else { + res.Header.Set(header, value) + } + } + + if res != nil && res.Request != nil { + originHeader := res.Request.Header.Get("Origin") + allowed, match := s.isOriginAllowed(originHeader) + + if allowed { + res.Header.Set("Access-Control-Allow-Origin", match) + } + } + + if s.headers.AccessControlAllowCredentials { + res.Header.Set("Access-Control-Allow-Credentials", "true") + } + + if len(s.headers.AccessControlExposeHeaders) > 0 { + exposeHeaders := strings.Join(s.headers.AccessControlExposeHeaders, ",") + res.Header.Set("Access-Control-Expose-Headers", exposeHeaders) + } + + if !s.headers.AddVaryHeader { + return nil + } + + varyHeader := res.Header.Get("Vary") + if varyHeader == "Origin" { + return nil + } + + if varyHeader != "" { + varyHeader += "," + } + varyHeader += "Origin" + + res.Header.Set("Vary", varyHeader) + return nil +} + +// processCorsHeaders processes the incoming request, +// and returns if it is a preflight request. +// If not a preflight, it handles the preRequestModifyCorsResponseHeaders. +func (s *Header) processCorsHeaders(rw http.ResponseWriter, req *http.Request) bool { + if !s.hasCorsHeaders { + return false + } + + reqAcMethod := req.Header.Get("Access-Control-Request-Method") + originHeader := req.Header.Get("Origin") + + if reqAcMethod != "" && originHeader != "" && req.Method == http.MethodOptions { + // If the request is an OPTIONS request with an Access-Control-Request-Method header, + // and Origin headers, then it is a CORS preflight request, + // and we need to build a custom response: https://www.w3.org/TR/cors/#preflight-request + if s.headers.AccessControlAllowCredentials { + rw.Header().Set("Access-Control-Allow-Credentials", "true") + } + + allowHeaders := strings.Join(s.headers.AccessControlAllowHeaders, ",") + if allowHeaders != "" { + rw.Header().Set("Access-Control-Allow-Headers", allowHeaders) + } + + allowMethods := strings.Join(s.headers.AccessControlAllowMethods, ",") + if allowMethods != "" { + rw.Header().Set("Access-Control-Allow-Methods", allowMethods) + } + + allowed, match := s.isOriginAllowed(originHeader) + if allowed { + rw.Header().Set("Access-Control-Allow-Origin", match) + } + + rw.Header().Set("Access-Control-Max-Age", strconv.Itoa(int(s.headers.AccessControlMaxAge))) + return true + } + + return false +} + +func (s *Header) isOriginAllowed(origin string) (bool, string) { + for _, item := range s.headers.AccessControlAllowOriginList { + if item == "*" || item == origin { + return true, item + } + } + + return false, "" +} diff --git a/pkg/middlewares/headers/header_test.go b/pkg/middlewares/headers/header_test.go new file mode 100644 index 000000000..f0aa43218 --- /dev/null +++ b/pkg/middlewares/headers/header_test.go @@ -0,0 +1,492 @@ +package headers + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/containous/traefik/v2/pkg/config/dynamic" + "github.com/stretchr/testify/assert" +) + +func TestNewHeader_customRequestHeader(t *testing.T) { + testCases := []struct { + desc string + cfg dynamic.Headers + expected http.Header + }{ + { + desc: "adds a header", + cfg: dynamic.Headers{ + CustomRequestHeaders: map[string]string{ + "X-Custom-Request-Header": "test_request", + }, + }, + expected: http.Header{"Foo": []string{"bar"}, "X-Custom-Request-Header": []string{"test_request"}}, + }, + { + desc: "delete a header", + cfg: dynamic.Headers{ + CustomRequestHeaders: map[string]string{ + "X-Custom-Request-Header": "", + "Foo": "", + }, + }, + expected: http.Header{}, + }, + { + desc: "override a header", + cfg: dynamic.Headers{ + CustomRequestHeaders: map[string]string{ + "Foo": "test", + }, + }, + expected: http.Header{"Foo": []string{"test"}}, + }, + } + + emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + mid := NewHeader(emptyHandler, test.cfg) + + req := httptest.NewRequest(http.MethodGet, "/foo", nil) + req.Header.Set("Foo", "bar") + + rw := httptest.NewRecorder() + + mid.ServeHTTP(rw, req) + + assert.Equal(t, http.StatusOK, rw.Code) + assert.Equal(t, test.expected, req.Header) + }) + } +} + +func TestNewHeader_customRequestHeader_Host(t *testing.T) { + testCases := []struct { + desc string + customHeaders map[string]string + expectedHost string + expectedURLHost string + }{ + { + desc: "standard Host header", + customHeaders: map[string]string{}, + expectedHost: "example.org", + expectedURLHost: "example.org", + }, + { + desc: "custom Host header", + customHeaders: map[string]string{ + "Host": "example.com", + }, + expectedHost: "example.com", + expectedURLHost: "example.org", + }, + } + + emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + mid := NewHeader(emptyHandler, dynamic.Headers{CustomRequestHeaders: test.customHeaders}) + + req := httptest.NewRequest(http.MethodGet, "http://example.org/foo", nil) + + rw := httptest.NewRecorder() + + mid.ServeHTTP(rw, req) + + assert.Equal(t, http.StatusOK, rw.Code) + assert.Equal(t, test.expectedHost, req.Host) + assert.Equal(t, test.expectedURLHost, req.URL.Host) + }) + } +} + +func TestNewHeader_CORSPreflights(t *testing.T) { + testCases := []struct { + desc string + cfg dynamic.Headers + requestHeaders http.Header + expected http.Header + }{ + { + desc: "Test Simple Preflight", + cfg: dynamic.Headers{ + AccessControlAllowMethods: []string{"GET", "OPTIONS", "PUT"}, + AccessControlAllowOriginList: []string{"https://foo.bar.org"}, + AccessControlMaxAge: 600, + }, + requestHeaders: map[string][]string{ + "Access-Control-Request-Headers": {"origin"}, + "Access-Control-Request-Method": {"GET", "OPTIONS"}, + "Origin": {"https://foo.bar.org"}, + }, + expected: map[string][]string{ + "Access-Control-Allow-Origin": {"https://foo.bar.org"}, + "Access-Control-Max-Age": {"600"}, + "Access-Control-Allow-Methods": {"GET,OPTIONS,PUT"}, + }, + }, + { + desc: "Wildcard origin Preflight", + cfg: dynamic.Headers{ + AccessControlAllowMethods: []string{"GET", "OPTIONS", "PUT"}, + AccessControlAllowOriginList: []string{"*"}, + AccessControlMaxAge: 600, + }, + requestHeaders: map[string][]string{ + "Access-Control-Request-Headers": {"origin"}, + "Access-Control-Request-Method": {"GET", "OPTIONS"}, + "Origin": {"https://foo.bar.org"}, + }, + expected: map[string][]string{ + "Access-Control-Allow-Origin": {"*"}, + "Access-Control-Max-Age": {"600"}, + "Access-Control-Allow-Methods": {"GET,OPTIONS,PUT"}, + }, + }, + { + desc: "Allow Credentials Preflight", + cfg: dynamic.Headers{ + AccessControlAllowMethods: []string{"GET", "OPTIONS", "PUT"}, + AccessControlAllowOriginList: []string{"*"}, + AccessControlAllowCredentials: true, + AccessControlMaxAge: 600, + }, + requestHeaders: map[string][]string{ + "Access-Control-Request-Headers": {"origin"}, + "Access-Control-Request-Method": {"GET", "OPTIONS"}, + "Origin": {"https://foo.bar.org"}, + }, + expected: map[string][]string{ + "Access-Control-Allow-Origin": {"*"}, + "Access-Control-Max-Age": {"600"}, + "Access-Control-Allow-Methods": {"GET,OPTIONS,PUT"}, + "Access-Control-Allow-Credentials": {"true"}, + }, + }, + { + desc: "Allow Headers Preflight", + cfg: dynamic.Headers{ + AccessControlAllowMethods: []string{"GET", "OPTIONS", "PUT"}, + AccessControlAllowOriginList: []string{"*"}, + AccessControlAllowHeaders: []string{"origin", "X-Forwarded-For"}, + AccessControlMaxAge: 600, + }, + requestHeaders: map[string][]string{ + "Access-Control-Request-Headers": {"origin"}, + "Access-Control-Request-Method": {"GET", "OPTIONS"}, + "Origin": {"https://foo.bar.org"}, + }, + expected: map[string][]string{ + "Access-Control-Allow-Origin": {"*"}, + "Access-Control-Max-Age": {"600"}, + "Access-Control-Allow-Methods": {"GET,OPTIONS,PUT"}, + "Access-Control-Allow-Headers": {"origin,X-Forwarded-For"}, + }, + }, + { + desc: "No Request Headers Preflight", + cfg: dynamic.Headers{ + AccessControlAllowMethods: []string{"GET", "OPTIONS", "PUT"}, + AccessControlAllowOriginList: []string{"*"}, + AccessControlAllowHeaders: []string{"origin", "X-Forwarded-For"}, + AccessControlMaxAge: 600, + }, + requestHeaders: map[string][]string{ + "Access-Control-Request-Method": {"GET", "OPTIONS"}, + "Origin": {"https://foo.bar.org"}, + }, + expected: map[string][]string{ + "Access-Control-Allow-Origin": {"*"}, + "Access-Control-Max-Age": {"600"}, + "Access-Control-Allow-Methods": {"GET,OPTIONS,PUT"}, + "Access-Control-Allow-Headers": {"origin,X-Forwarded-For"}, + }, + }, + } + + emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + mid := NewHeader(emptyHandler, test.cfg) + + req := httptest.NewRequest(http.MethodOptions, "/foo", nil) + req.Header = test.requestHeaders + + rw := httptest.NewRecorder() + + mid.ServeHTTP(rw, req) + + assert.Equal(t, test.expected, rw.Result().Header) + }) + } +} + +func TestNewHeader_CORSResponses(t *testing.T) { + emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) + + testCases := []struct { + desc string + next http.Handler + cfg dynamic.Headers + requestHeaders http.Header + expected http.Header + }{ + { + desc: "Test Simple Request", + next: emptyHandler, + cfg: dynamic.Headers{ + AccessControlAllowOriginList: []string{"https://foo.bar.org"}, + }, + requestHeaders: map[string][]string{ + "Origin": {"https://foo.bar.org"}, + }, + expected: map[string][]string{ + "Access-Control-Allow-Origin": {"https://foo.bar.org"}, + }, + }, + { + desc: "Wildcard origin Request", + next: emptyHandler, + cfg: dynamic.Headers{ + AccessControlAllowOriginList: []string{"*"}, + }, + requestHeaders: map[string][]string{ + "Origin": {"https://foo.bar.org"}, + }, + expected: map[string][]string{ + "Access-Control-Allow-Origin": {"*"}, + }, + }, + { + desc: "Empty origin Request", + next: emptyHandler, + cfg: dynamic.Headers{ + AccessControlAllowOriginList: []string{"https://foo.bar.org"}, + }, + requestHeaders: map[string][]string{}, + expected: map[string][]string{}, + }, + { + desc: "Not Defined origin Request", + next: emptyHandler, + requestHeaders: map[string][]string{}, + expected: map[string][]string{}, + }, + { + desc: "Allow Credentials Request", + next: emptyHandler, + cfg: dynamic.Headers{ + AccessControlAllowOriginList: []string{"*"}, + AccessControlAllowCredentials: true, + }, + requestHeaders: map[string][]string{ + "Origin": {"https://foo.bar.org"}, + }, + expected: map[string][]string{ + "Access-Control-Allow-Origin": {"*"}, + "Access-Control-Allow-Credentials": {"true"}, + }, + }, + { + desc: "Expose Headers Request", + next: emptyHandler, + cfg: dynamic.Headers{ + AccessControlAllowOriginList: []string{"*"}, + AccessControlExposeHeaders: []string{"origin", "X-Forwarded-For"}, + }, + requestHeaders: map[string][]string{ + "Origin": {"https://foo.bar.org"}, + }, + expected: map[string][]string{ + "Access-Control-Allow-Origin": {"*"}, + "Access-Control-Expose-Headers": {"origin,X-Forwarded-For"}, + }, + }, + { + desc: "Test Simple Request with Vary Headers", + next: emptyHandler, + cfg: dynamic.Headers{ + AccessControlAllowOriginList: []string{"https://foo.bar.org"}, + AddVaryHeader: true, + }, + requestHeaders: map[string][]string{ + "Origin": {"https://foo.bar.org"}, + }, + expected: map[string][]string{ + "Access-Control-Allow-Origin": {"https://foo.bar.org"}, + "Vary": {"Origin"}, + }, + }, + { + desc: "Test Simple Request with Vary Headers and non-empty response", + next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // nonEmptyHandler + w.Header().Set("Vary", "Testing") + w.WriteHeader(http.StatusOK) + }), + cfg: dynamic.Headers{ + AccessControlAllowOriginList: []string{"https://foo.bar.org"}, + AddVaryHeader: true, + }, + requestHeaders: map[string][]string{ + "Origin": {"https://foo.bar.org"}, + }, + expected: map[string][]string{ + "Access-Control-Allow-Origin": {"https://foo.bar.org"}, + "Vary": {"Testing,Origin"}, + }, + }, + { + desc: "Test Simple Request with Vary Headers and existing vary:origin response", + next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // existingOriginHandler + w.Header().Set("Vary", "Origin") + w.WriteHeader(http.StatusOK) + }), + cfg: dynamic.Headers{ + AccessControlAllowOriginList: []string{"https://foo.bar.org"}, + AddVaryHeader: true, + }, + requestHeaders: map[string][]string{ + "Origin": {"https://foo.bar.org"}, + }, + expected: map[string][]string{ + "Access-Control-Allow-Origin": {"https://foo.bar.org"}, + "Vary": {"Origin"}, + }, + }, + { + desc: "Test Simple Request with non-empty response: set ACAO", + next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // existingAccessControlAllowOriginHandlerSet + w.Header().Set("Access-Control-Allow-Origin", "http://foo.bar.org") + w.WriteHeader(http.StatusOK) + }), + cfg: dynamic.Headers{ + AccessControlAllowOriginList: []string{"*"}, + }, + requestHeaders: map[string][]string{ + "Origin": {"https://foo.bar.org"}, + }, + expected: map[string][]string{ + "Access-Control-Allow-Origin": {"*"}, + }, + }, + { + desc: "Test Simple Request with non-empty response: add ACAO", + next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // existingAccessControlAllowOriginHandlerAdd + w.Header().Add("Access-Control-Allow-Origin", "http://foo.bar.org") + w.WriteHeader(http.StatusOK) + }), + cfg: dynamic.Headers{ + AccessControlAllowOriginList: []string{"*"}, + }, + requestHeaders: map[string][]string{ + "Origin": {"https://foo.bar.org"}, + }, + expected: map[string][]string{ + "Access-Control-Allow-Origin": {"*"}, + }, + }, + { + desc: "Test Simple CustomRequestHeaders Not Hijacked by CORS", + next: emptyHandler, + cfg: dynamic.Headers{ + CustomRequestHeaders: map[string]string{"foo": "bar"}, + }, + requestHeaders: map[string][]string{ + "Access-Control-Request-Headers": {"origin"}, + "Access-Control-Request-Method": {"GET", "OPTIONS"}, + "Origin": {"https://foo.bar.org"}, + }, + expected: map[string][]string{}, + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + mid := NewHeader(test.next, test.cfg) + + req := httptest.NewRequest(http.MethodGet, "/foo", nil) + req.Header = test.requestHeaders + + rw := httptest.NewRecorder() + + mid.ServeHTTP(rw, req) + + assert.Equal(t, test.expected, rw.Result().Header) + }) + } +} + +func TestNewHeader_customResponseHeaders(t *testing.T) { + testCases := []struct { + desc string + config map[string]string + expected http.Header + }{ + { + desc: "Test Simple Response", + config: map[string]string{ + "Testing": "foo", + "Testing2": "bar", + }, + expected: map[string][]string{ + "Foo": {"bar"}, + "Testing": {"foo"}, + "Testing2": {"bar"}, + }, + }, + { + desc: "empty Custom Header", + config: map[string]string{ + "Testing": "foo", + "Testing2": "", + }, + expected: map[string][]string{ + "Foo": {"bar"}, + "Testing": {"foo"}, + }, + }, + { + desc: "Deleting Custom Header", + config: map[string]string{ + "Testing": "foo", + "Foo": "", + }, + expected: map[string][]string{ + "Testing": {"foo"}, + }, + }, + } + + emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Foo", "bar") + w.WriteHeader(http.StatusOK) + }) + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + mid := NewHeader(emptyHandler, dynamic.Headers{CustomResponseHeaders: test.config}) + + req := httptest.NewRequest(http.MethodGet, "/foo", nil) + + rw := httptest.NewRecorder() + + mid.ServeHTTP(rw, req) + + assert.Equal(t, test.expected, rw.Result().Header) + }) + } +} diff --git a/pkg/middlewares/headers/headers.go b/pkg/middlewares/headers/headers.go index 53cd34163..881ad1565 100644 --- a/pkg/middlewares/headers/headers.go +++ b/pkg/middlewares/headers/headers.go @@ -5,15 +5,12 @@ import ( "context" "errors" "net/http" - "strconv" - "strings" "github.com/containous/traefik/v2/pkg/config/dynamic" "github.com/containous/traefik/v2/pkg/log" "github.com/containous/traefik/v2/pkg/middlewares" "github.com/containous/traefik/v2/pkg/tracing" "github.com/opentracing/opentracing-go/ext" - "github.com/unrolled/secure" ) const ( @@ -55,7 +52,7 @@ func New(ctx context.Context, next http.Handler, cfg dynamic.Headers, name strin if hasSecureHeaders { logger.Debugf("Setting up secureHeaders from %v", cfg) - handler = newSecure(next, cfg) + handler = newSecure(next, cfg, name) nextHandler = handler } @@ -77,203 +74,3 @@ func (h *headers) GetTracingInformation() (string, ext.SpanKindEnum) { func (h *headers) ServeHTTP(rw http.ResponseWriter, req *http.Request) { h.handler.ServeHTTP(rw, req) } - -type secureHeader struct { - next http.Handler - secure *secure.Secure -} - -// newSecure constructs a new secure instance with supplied options. -func newSecure(next http.Handler, cfg dynamic.Headers) *secureHeader { - opt := secure.Options{ - BrowserXssFilter: cfg.BrowserXSSFilter, - ContentTypeNosniff: cfg.ContentTypeNosniff, - ForceSTSHeader: cfg.ForceSTSHeader, - FrameDeny: cfg.FrameDeny, - IsDevelopment: cfg.IsDevelopment, - SSLRedirect: cfg.SSLRedirect, - SSLForceHost: cfg.SSLForceHost, - SSLTemporaryRedirect: cfg.SSLTemporaryRedirect, - STSIncludeSubdomains: cfg.STSIncludeSubdomains, - STSPreload: cfg.STSPreload, - ContentSecurityPolicy: cfg.ContentSecurityPolicy, - CustomBrowserXssValue: cfg.CustomBrowserXSSValue, - CustomFrameOptionsValue: cfg.CustomFrameOptionsValue, - PublicKey: cfg.PublicKey, - ReferrerPolicy: cfg.ReferrerPolicy, - SSLHost: cfg.SSLHost, - AllowedHosts: cfg.AllowedHosts, - HostsProxyHeaders: cfg.HostsProxyHeaders, - SSLProxyHeaders: cfg.SSLProxyHeaders, - STSSeconds: cfg.STSSeconds, - FeaturePolicy: cfg.FeaturePolicy, - } - - return &secureHeader{ - next: next, - secure: secure.New(opt), - } -} - -func (s secureHeader) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - s.secure.HandlerFuncWithNextForRequestOnly(rw, req, s.next.ServeHTTP) -} - -// Header is a middleware that helps setup a few basic security features. -// A single headerOptions struct can be provided to configure which features should be enabled, -// and the ability to override a few of the default values. -type Header struct { - next http.Handler - hasCustomHeaders bool - hasCorsHeaders bool - headers *dynamic.Headers -} - -// NewHeader constructs a new header instance from supplied frontend header struct. -func NewHeader(next http.Handler, cfg dynamic.Headers) *Header { - hasCustomHeaders := cfg.HasCustomHeadersDefined() - hasCorsHeaders := cfg.HasCorsHeadersDefined() - - ctx := log.With(context.Background(), log.Str(log.MiddlewareType, typeName)) - handleDeprecation(ctx, &cfg) - - return &Header{ - next: next, - headers: &cfg, - hasCustomHeaders: hasCustomHeaders, - hasCorsHeaders: hasCorsHeaders, - } -} - -func (s *Header) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - // Handle Cors headers and preflight if configured. - if isPreflight := s.processCorsHeaders(rw, req); isPreflight { - return - } - - if s.hasCustomHeaders { - s.modifyCustomRequestHeaders(req) - } - - // If there is a next, call it. - if s.next != nil { - s.next.ServeHTTP(rw, req) - } -} - -// modifyCustomRequestHeaders sets or deletes custom request headers. -func (s *Header) modifyCustomRequestHeaders(req *http.Request) { - // Loop through Custom request headers - for header, value := range s.headers.CustomRequestHeaders { - switch { - case value == "": - req.Header.Del(header) - - case strings.EqualFold(header, "Host"): - req.Host = value - - default: - req.Header.Set(header, value) - } - } -} - -// PostRequestModifyResponseHeaders set or delete response headers. -// This method is called AFTER the response is generated from the backend -// and can merge/override headers from the backend response. -func (s *Header) PostRequestModifyResponseHeaders(res *http.Response) error { - // Loop through Custom response headers - for header, value := range s.headers.CustomResponseHeaders { - if value == "" { - res.Header.Del(header) - } else { - res.Header.Set(header, value) - } - } - - if res != nil && res.Request != nil { - originHeader := res.Request.Header.Get("Origin") - allowed, match := s.isOriginAllowed(originHeader) - - if allowed { - res.Header.Set("Access-Control-Allow-Origin", match) - } - } - - if s.headers.AccessControlAllowCredentials { - res.Header.Set("Access-Control-Allow-Credentials", "true") - } - - if len(s.headers.AccessControlExposeHeaders) > 0 { - exposeHeaders := strings.Join(s.headers.AccessControlExposeHeaders, ",") - res.Header.Set("Access-Control-Expose-Headers", exposeHeaders) - } - - if !s.headers.AddVaryHeader { - return nil - } - - varyHeader := res.Header.Get("Vary") - if varyHeader == "Origin" { - return nil - } - - if varyHeader != "" { - varyHeader += "," - } - varyHeader += "Origin" - - res.Header.Set("Vary", varyHeader) - return nil -} - -// processCorsHeaders processes the incoming request, -// and returns if it is a preflight request. -// If not a preflight, it handles the preRequestModifyCorsResponseHeaders. -func (s *Header) processCorsHeaders(rw http.ResponseWriter, req *http.Request) bool { - if !s.hasCorsHeaders { - return false - } - - reqAcMethod := req.Header.Get("Access-Control-Request-Method") - originHeader := req.Header.Get("Origin") - - if reqAcMethod != "" && originHeader != "" && req.Method == http.MethodOptions { - // If the request is an OPTIONS request with an Access-Control-Request-Method header, - // and Origin headers, then it is a CORS preflight request, - // and we need to build a custom response: https://www.w3.org/TR/cors/#preflight-request - if s.headers.AccessControlAllowCredentials { - rw.Header().Set("Access-Control-Allow-Credentials", "true") - } - - allowHeaders := strings.Join(s.headers.AccessControlAllowHeaders, ",") - if allowHeaders != "" { - rw.Header().Set("Access-Control-Allow-Headers", allowHeaders) - } - - allowMethods := strings.Join(s.headers.AccessControlAllowMethods, ",") - if allowMethods != "" { - rw.Header().Set("Access-Control-Allow-Methods", allowMethods) - } - - allowed, match := s.isOriginAllowed(originHeader) - if allowed { - rw.Header().Set("Access-Control-Allow-Origin", match) - } - - rw.Header().Set("Access-Control-Max-Age", strconv.Itoa(int(s.headers.AccessControlMaxAge))) - return true - } - - return false -} - -func (s *Header) isOriginAllowed(origin string) (bool, string) { - for _, item := range s.headers.AccessControlAllowOriginList { - if item == "*" || item == origin { - return true, item - } - } - - return false, "" -} diff --git a/pkg/middlewares/headers/headers_test.go b/pkg/middlewares/headers/headers_test.go index eef384a5c..a7be06a97 100644 --- a/pkg/middlewares/headers/headers_test.go +++ b/pkg/middlewares/headers/headers_test.go @@ -9,104 +9,21 @@ import ( "testing" "github.com/containous/traefik/v2/pkg/config/dynamic" - "github.com/containous/traefik/v2/pkg/testhelpers" "github.com/containous/traefik/v2/pkg/tracing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestCustomRequestHeader(t *testing.T) { - emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) +func TestNew_withoutOptions(t *testing.T) { + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) - header := NewHeader(emptyHandler, dynamic.Headers{ - CustomRequestHeaders: map[string]string{ - "X-Custom-Request-Header": "test_request", - }, - }) + mid, err := New(context.Background(), next, dynamic.Headers{}, "testing") + require.Errorf(t, err, "headers configuration not valid") - res := httptest.NewRecorder() - req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil) - - header.ServeHTTP(res, req) - - assert.Equal(t, http.StatusOK, res.Code) - assert.Equal(t, "test_request", req.Header.Get("X-Custom-Request-Header")) + assert.Nil(t, mid) } -func TestCustomRequestHeader_Host(t *testing.T) { - emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) - - testCases := []struct { - desc string - customHeaders map[string]string - expectedHost string - expectedURLHost string - }{ - { - desc: "standard Host header", - customHeaders: map[string]string{}, - expectedHost: "example.org", - expectedURLHost: "example.org", - }, - { - desc: "custom Host header", - customHeaders: map[string]string{ - "Host": "example.com", - }, - expectedHost: "example.com", - expectedURLHost: "example.org", - }, - } - - for _, test := range testCases { - t.Run(test.desc, func(t *testing.T) { - header := NewHeader(emptyHandler, dynamic.Headers{ - CustomRequestHeaders: test.customHeaders, - }) - - res := httptest.NewRecorder() - req, err := http.NewRequest(http.MethodGet, "http://example.org/foo", nil) - require.NoError(t, err) - - header.ServeHTTP(res, req) - - assert.Equal(t, http.StatusOK, res.Code) - assert.Equal(t, test.expectedHost, req.Host) - assert.Equal(t, test.expectedURLHost, req.URL.Host) - }) - } -} - -func TestCustomRequestHeaderEmptyValue(t *testing.T) { - emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) - - header := NewHeader(emptyHandler, dynamic.Headers{ - CustomRequestHeaders: map[string]string{ - "X-Custom-Request-Header": "test_request", - }, - }) - - res := httptest.NewRecorder() - req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil) - - header.ServeHTTP(res, req) - - assert.Equal(t, http.StatusOK, res.Code) - assert.Equal(t, "test_request", req.Header.Get("X-Custom-Request-Header")) - - header = NewHeader(emptyHandler, dynamic.Headers{ - CustomRequestHeaders: map[string]string{ - "X-Custom-Request-Header": "", - }, - }) - - header.ServeHTTP(res, req) - - assert.Equal(t, http.StatusOK, res.Code) - assert.Equal(t, "", req.Header.Get("X-Custom-Request-Header")) -} - -func TestSecureHeader(t *testing.T) { +func TestNew_allowedHosts(t *testing.T) { testCases := []struct { desc string fromHost string @@ -129,10 +46,13 @@ func TestSecureHeader(t *testing.T) { }, } - emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) - header, err := New(context.Background(), emptyHandler, dynamic.Headers{ + emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) + + cfg := dynamic.Headers{ AllowedHosts: []string{"foo.com", "bar.com"}, - }, "foo") + } + + mid, err := New(context.Background(), emptyHandler, cfg, "foo") require.NoError(t, err) for _, test := range testCases { @@ -140,479 +60,54 @@ func TestSecureHeader(t *testing.T) { t.Run(test.desc, func(t *testing.T) { t.Parallel() - res := httptest.NewRecorder() - req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil) + req := httptest.NewRequest(http.MethodGet, "/foo", nil) req.Host = test.fromHost - header.ServeHTTP(res, req) - assert.Equal(t, test.expected, res.Code) - }) - } -} - -func TestSSLForceHost(t *testing.T) { - next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - _, _ = rw.Write([]byte("OK")) - }) - - testCases := []struct { - desc string - host string - secureMiddleware *secureHeader - expected int - }{ - { - desc: "http should return a 301", - host: "http://powpow.example.com", - secureMiddleware: newSecure(next, dynamic.Headers{ - SSLRedirect: true, - SSLForceHost: true, - SSLHost: "powpow.example.com", - }), - expected: http.StatusMovedPermanently, - }, - { - desc: "http sub domain should return a 301", - host: "http://www.powpow.example.com", - secureMiddleware: newSecure(next, dynamic.Headers{ - SSLRedirect: true, - SSLForceHost: true, - SSLHost: "powpow.example.com", - }), - expected: http.StatusMovedPermanently, - }, - { - desc: "https should return a 200", - host: "https://powpow.example.com", - secureMiddleware: newSecure(next, dynamic.Headers{ - SSLRedirect: true, - SSLForceHost: true, - SSLHost: "powpow.example.com", - }), - expected: http.StatusOK, - }, - { - desc: "https sub domain should return a 301", - host: "https://www.powpow.example.com", - secureMiddleware: newSecure(next, dynamic.Headers{ - SSLRedirect: true, - SSLForceHost: true, - SSLHost: "powpow.example.com", - }), - expected: http.StatusMovedPermanently, - }, - { - desc: "http without force host and sub domain should return a 301", - host: "http://www.powpow.example.com", - secureMiddleware: newSecure(next, dynamic.Headers{ - SSLRedirect: true, - SSLForceHost: false, - SSLHost: "powpow.example.com", - }), - expected: http.StatusMovedPermanently, - }, - { - desc: "https without force host and sub domain should return a 301", - host: "https://www.powpow.example.com", - secureMiddleware: newSecure(next, dynamic.Headers{ - SSLRedirect: true, - SSLForceHost: false, - SSLHost: "powpow.example.com", - }), - expected: http.StatusOK, - }, - } - - for _, test := range testCases { - t.Run(test.desc, func(t *testing.T) { - req := testhelpers.MustNewRequest(http.MethodGet, test.host, nil) rw := httptest.NewRecorder() - test.secureMiddleware.ServeHTTP(rw, req) - assert.Equal(t, test.expected, rw.Result().StatusCode) + mid.ServeHTTP(rw, req) + + assert.Equal(t, test.expected, rw.Code) }) } } -func TestCORSPreflights(t *testing.T) { - emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) +func TestNew_customHeaders(t *testing.T) { + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) - testCases := []struct { - desc string - header *Header - requestHeaders http.Header - expected http.Header - }{ - { - desc: "Test Simple Preflight", - header: NewHeader(emptyHandler, dynamic.Headers{ - AccessControlAllowMethods: []string{"GET", "OPTIONS", "PUT"}, - AccessControlAllowOriginList: []string{"https://foo.bar.org"}, - AccessControlMaxAge: 600, - }), - requestHeaders: map[string][]string{ - "Access-Control-Request-Headers": {"origin"}, - "Access-Control-Request-Method": {"GET", "OPTIONS"}, - "Origin": {"https://foo.bar.org"}, - }, - expected: map[string][]string{ - "Access-Control-Allow-Origin": {"https://foo.bar.org"}, - "Access-Control-Max-Age": {"600"}, - "Access-Control-Allow-Methods": {"GET,OPTIONS,PUT"}, - }, - }, - { - desc: "Wildcard origin Preflight", - header: NewHeader(emptyHandler, dynamic.Headers{ - AccessControlAllowMethods: []string{"GET", "OPTIONS", "PUT"}, - AccessControlAllowOriginList: []string{"*"}, - AccessControlMaxAge: 600, - }), - requestHeaders: map[string][]string{ - "Access-Control-Request-Headers": {"origin"}, - "Access-Control-Request-Method": {"GET", "OPTIONS"}, - "Origin": {"https://foo.bar.org"}, - }, - expected: map[string][]string{ - "Access-Control-Allow-Origin": {"*"}, - "Access-Control-Max-Age": {"600"}, - "Access-Control-Allow-Methods": {"GET,OPTIONS,PUT"}, - }, - }, - { - desc: "Allow Credentials Preflight", - header: NewHeader(emptyHandler, dynamic.Headers{ - AccessControlAllowMethods: []string{"GET", "OPTIONS", "PUT"}, - AccessControlAllowOriginList: []string{"*"}, - AccessControlAllowCredentials: true, - AccessControlMaxAge: 600, - }), - requestHeaders: map[string][]string{ - "Access-Control-Request-Headers": {"origin"}, - "Access-Control-Request-Method": {"GET", "OPTIONS"}, - "Origin": {"https://foo.bar.org"}, - }, - expected: map[string][]string{ - "Access-Control-Allow-Origin": {"*"}, - "Access-Control-Max-Age": {"600"}, - "Access-Control-Allow-Methods": {"GET,OPTIONS,PUT"}, - "Access-Control-Allow-Credentials": {"true"}, - }, - }, - { - desc: "Allow Headers Preflight", - header: NewHeader(emptyHandler, dynamic.Headers{ - AccessControlAllowMethods: []string{"GET", "OPTIONS", "PUT"}, - AccessControlAllowOriginList: []string{"*"}, - AccessControlAllowHeaders: []string{"origin", "X-Forwarded-For"}, - AccessControlMaxAge: 600, - }), - requestHeaders: map[string][]string{ - "Access-Control-Request-Headers": {"origin"}, - "Access-Control-Request-Method": {"GET", "OPTIONS"}, - "Origin": {"https://foo.bar.org"}, - }, - expected: map[string][]string{ - "Access-Control-Allow-Origin": {"*"}, - "Access-Control-Max-Age": {"600"}, - "Access-Control-Allow-Methods": {"GET,OPTIONS,PUT"}, - "Access-Control-Allow-Headers": {"origin,X-Forwarded-For"}, - }, - }, - { - desc: "No Request Headers Preflight", - header: NewHeader(emptyHandler, dynamic.Headers{ - AccessControlAllowMethods: []string{"GET", "OPTIONS", "PUT"}, - AccessControlAllowOriginList: []string{"*"}, - AccessControlAllowHeaders: []string{"origin", "X-Forwarded-For"}, - AccessControlMaxAge: 600, - }), - requestHeaders: map[string][]string{ - "Access-Control-Request-Method": {"GET", "OPTIONS"}, - "Origin": {"https://foo.bar.org"}, - }, - expected: map[string][]string{ - "Access-Control-Allow-Origin": {"*"}, - "Access-Control-Max-Age": {"600"}, - "Access-Control-Allow-Methods": {"GET,OPTIONS,PUT"}, - "Access-Control-Allow-Headers": {"origin,X-Forwarded-For"}, - }, - }, - } - - for _, test := range testCases { - t.Run(test.desc, func(t *testing.T) { - req := testhelpers.MustNewRequest(http.MethodOptions, "/foo", nil) - req.Header = test.requestHeaders - - rw := httptest.NewRecorder() - test.header.ServeHTTP(rw, req) - - assert.Equal(t, test.expected, rw.Result().Header) - }) - } -} - -func TestEmptyHeaderObject(t *testing.T) { - next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) - - _, err := New(context.Background(), next, dynamic.Headers{}, "testing") - require.Errorf(t, err, "headers configuration not valid") -} - -func TestCustomHeaderHandler(t *testing.T) { - next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) - - header, _ := New(context.Background(), next, dynamic.Headers{ + cfg := dynamic.Headers{ CustomRequestHeaders: map[string]string{ "X-Custom-Request-Header": "test_request", }, - }, "testing") + CustomResponseHeaders: map[string]string{ + "X-Custom-Response-Header": "test_response", + }, + } - res := httptest.NewRecorder() - req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil) + mid, err := New(context.Background(), next, cfg, "testing") + require.NoError(t, err) - header.ServeHTTP(res, req) + req := httptest.NewRequest(http.MethodGet, "/foo", nil) - assert.Equal(t, http.StatusOK, res.Code) + rw := httptest.NewRecorder() + + mid.ServeHTTP(rw, req) + + assert.Equal(t, http.StatusOK, rw.Code) assert.Equal(t, "test_request", req.Header.Get("X-Custom-Request-Header")) + assert.Equal(t, "test_response", rw.Header().Get("X-Custom-Response-Header")) } -func TestGetTracingInformation(t *testing.T) { +func Test_headers_getTracingInformation(t *testing.T) { next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) - header := &headers{ + mid := &headers{ handler: next, name: "testing", } - name, trace := header.GetTracingInformation() + name, trace := mid.GetTracingInformation() assert.Equal(t, "testing", name) assert.Equal(t, tracing.SpanKindNoneEnum, trace) } - -func TestCORSResponses(t *testing.T) { - emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) - nonEmptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Vary", "Testing") }) - existingOriginHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Vary", "Origin") }) - existingAccessControlAllowOriginHandlerSet := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", "http://foo.bar.org") - }) - existingAccessControlAllowOriginHandlerAdd := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Access-Control-Allow-Origin", "http://foo.bar.org") - }) - - testCases := []struct { - desc string - header *Header - requestHeaders http.Header - expected http.Header - }{ - { - desc: "Test Simple Request", - header: NewHeader(emptyHandler, dynamic.Headers{ - AccessControlAllowOriginList: []string{"https://foo.bar.org"}, - }), - requestHeaders: map[string][]string{ - "Origin": {"https://foo.bar.org"}, - }, - expected: map[string][]string{ - "Access-Control-Allow-Origin": {"https://foo.bar.org"}, - }, - }, - { - desc: "Wildcard origin Request", - header: NewHeader(emptyHandler, dynamic.Headers{ - AccessControlAllowOriginList: []string{"*"}, - }), - requestHeaders: map[string][]string{ - "Origin": {"https://foo.bar.org"}, - }, - expected: map[string][]string{ - "Access-Control-Allow-Origin": {"*"}, - }, - }, - { - desc: "Empty origin Request", - header: NewHeader(emptyHandler, dynamic.Headers{ - AccessControlAllowOriginList: []string{"https://foo.bar.org"}, - }), - requestHeaders: map[string][]string{}, - expected: map[string][]string{}, - }, - { - desc: "Not Defined origin Request", - header: NewHeader(emptyHandler, dynamic.Headers{}), - requestHeaders: map[string][]string{}, - expected: map[string][]string{}, - }, - { - desc: "Allow Credentials Request", - header: NewHeader(emptyHandler, dynamic.Headers{ - AccessControlAllowOriginList: []string{"*"}, - AccessControlAllowCredentials: true, - }), - requestHeaders: map[string][]string{ - "Origin": {"https://foo.bar.org"}, - }, - expected: map[string][]string{ - "Access-Control-Allow-Origin": {"*"}, - "Access-Control-Allow-Credentials": {"true"}, - }, - }, - { - desc: "Expose Headers Request", - header: NewHeader(emptyHandler, dynamic.Headers{ - AccessControlAllowOriginList: []string{"*"}, - AccessControlExposeHeaders: []string{"origin", "X-Forwarded-For"}, - }), - requestHeaders: map[string][]string{ - "Origin": {"https://foo.bar.org"}, - }, - expected: map[string][]string{ - "Access-Control-Allow-Origin": {"*"}, - "Access-Control-Expose-Headers": {"origin,X-Forwarded-For"}, - }, - }, - { - desc: "Test Simple Request with Vary Headers", - header: NewHeader(emptyHandler, dynamic.Headers{ - AccessControlAllowOriginList: []string{"https://foo.bar.org"}, - AddVaryHeader: true, - }), - requestHeaders: map[string][]string{ - "Origin": {"https://foo.bar.org"}, - }, - expected: map[string][]string{ - "Access-Control-Allow-Origin": {"https://foo.bar.org"}, - "Vary": {"Origin"}, - }, - }, - { - desc: "Test Simple Request with Vary Headers and non-empty response", - header: NewHeader(nonEmptyHandler, dynamic.Headers{ - AccessControlAllowOriginList: []string{"https://foo.bar.org"}, - AddVaryHeader: true, - }), - requestHeaders: map[string][]string{ - "Origin": {"https://foo.bar.org"}, - }, - expected: map[string][]string{ - "Access-Control-Allow-Origin": {"https://foo.bar.org"}, - "Vary": {"Testing,Origin"}, - }, - }, - { - desc: "Test Simple Request with Vary Headers and existing vary:origin response", - header: NewHeader(existingOriginHandler, dynamic.Headers{ - AccessControlAllowOriginList: []string{"https://foo.bar.org"}, - AddVaryHeader: true, - }), - requestHeaders: map[string][]string{ - "Origin": {"https://foo.bar.org"}, - }, - expected: map[string][]string{ - "Access-Control-Allow-Origin": {"https://foo.bar.org"}, - "Vary": {"Origin"}, - }, - }, - { - desc: "Test Simple Request with non-empty response: set ACAO", - header: NewHeader(existingAccessControlAllowOriginHandlerSet, dynamic.Headers{ - AccessControlAllowOriginList: []string{"*"}, - }), - requestHeaders: map[string][]string{ - "Origin": {"https://foo.bar.org"}, - }, - expected: map[string][]string{ - "Access-Control-Allow-Origin": {"*"}, - }, - }, - { - desc: "Test Simple Request with non-empty response: add ACAO", - header: NewHeader(existingAccessControlAllowOriginHandlerAdd, dynamic.Headers{ - AccessControlAllowOriginList: []string{"*"}, - }), - requestHeaders: map[string][]string{ - "Origin": {"https://foo.bar.org"}, - }, - expected: map[string][]string{ - "Access-Control-Allow-Origin": {"*"}, - }, - }, - { - desc: "Test Simple CustomRequestHeaders Not Hijacked by CORS", - header: NewHeader(emptyHandler, dynamic.Headers{ - CustomRequestHeaders: map[string]string{"foo": "bar"}, - }), - requestHeaders: map[string][]string{ - "Access-Control-Request-Headers": {"origin"}, - "Access-Control-Request-Method": {"GET", "OPTIONS"}, - "Origin": {"https://foo.bar.org"}, - }, - expected: map[string][]string{}, - }, - } - - for _, test := range testCases { - t.Run(test.desc, func(t *testing.T) { - req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil) - req.Header = test.requestHeaders - rw := httptest.NewRecorder() - test.header.ServeHTTP(rw, req) - res := rw.Result() - res.Request = req - err := test.header.PostRequestModifyResponseHeaders(res) - require.NoError(t, err) - assert.Equal(t, test.expected, rw.Result().Header) - }) - } -} - -func TestCustomResponseHeaders(t *testing.T) { - emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) - - testCases := []struct { - desc string - header *Header - expected http.Header - }{ - { - desc: "Test Simple Response", - header: NewHeader(emptyHandler, dynamic.Headers{ - CustomResponseHeaders: map[string]string{ - "Testing": "foo", - "Testing2": "bar", - }, - }), - expected: map[string][]string{ - "Testing": {"foo"}, - "Testing2": {"bar"}, - }, - }, - { - desc: "Deleting Custom Header", - header: NewHeader(emptyHandler, dynamic.Headers{ - CustomResponseHeaders: map[string]string{ - "Testing": "foo", - "Testing2": "", - }, - }), - expected: map[string][]string{ - "Testing": {"foo"}, - }, - }, - } - - for _, test := range testCases { - t.Run(test.desc, func(t *testing.T) { - req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil) - rw := httptest.NewRecorder() - test.header.ServeHTTP(rw, req) - err := test.header.PostRequestModifyResponseHeaders(rw.Result()) - require.NoError(t, err) - assert.Equal(t, test.expected, rw.Result().Header) - }) - } -} diff --git a/pkg/middlewares/headers/responsewriter.go b/pkg/middlewares/headers/responsewriter.go new file mode 100644 index 000000000..274a58181 --- /dev/null +++ b/pkg/middlewares/headers/responsewriter.go @@ -0,0 +1,75 @@ +package headers + +import ( + "net/http" + + "github.com/containous/traefik/v2/pkg/log" +) + +type responseModifier struct { + r *http.Request + w http.ResponseWriter + + headersSent bool // whether headers have already been sent + code int // status code, must default to 200 + + modifier func(*http.Response) error // can be nil + modified bool // whether modifier has already been called for the current request + modifierErr error // returned by modifier call +} + +// modifier can be nil. +func newResponseModifier(w http.ResponseWriter, r *http.Request, modifier func(*http.Response) error) *responseModifier { + return &responseModifier{ + r: r, + w: w, + modifier: modifier, + code: http.StatusOK, + } +} + +func (w *responseModifier) WriteHeader(code int) { + if w.headersSent { + return + } + defer func() { + w.code = code + w.headersSent = true + }() + + if w.modifier == nil || w.modified { + w.w.WriteHeader(code) + return + } + + resp := http.Response{ + Header: w.w.Header(), + Request: w.r, + } + + if err := w.modifier(&resp); err != nil { + w.modifierErr = err + // we are propagating when we are called in Write, but we're logging anyway, + // because we could be called from another place which does not take care of + // checking w.modifierErr. + log.WithoutContext().Errorf("Error when applying response modifier: %v", err) + w.w.WriteHeader(http.StatusInternalServerError) + return + } + + w.modified = true + w.w.WriteHeader(code) +} + +func (w *responseModifier) Header() http.Header { + return w.w.Header() +} + +func (w *responseModifier) Write(b []byte) (int, error) { + w.WriteHeader(w.code) + if w.modifierErr != nil { + return 0, w.modifierErr + } + + return w.w.Write(b) +} diff --git a/pkg/middlewares/headers/secure.go b/pkg/middlewares/headers/secure.go new file mode 100644 index 000000000..92cbe20d7 --- /dev/null +++ b/pkg/middlewares/headers/secure.go @@ -0,0 +1,54 @@ +package headers + +import ( + "net/http" + + "github.com/containous/traefik/v2/pkg/config/dynamic" + "github.com/unrolled/secure" +) + +type secureHeader struct { + next http.Handler + secure *secure.Secure + cfg dynamic.Headers +} + +// newSecure constructs a new secure instance with supplied options. +func newSecure(next http.Handler, cfg dynamic.Headers, contextKey string) *secureHeader { + opt := secure.Options{ + BrowserXssFilter: cfg.BrowserXSSFilter, + ContentTypeNosniff: cfg.ContentTypeNosniff, + ForceSTSHeader: cfg.ForceSTSHeader, + FrameDeny: cfg.FrameDeny, + IsDevelopment: cfg.IsDevelopment, + SSLRedirect: cfg.SSLRedirect, + SSLForceHost: cfg.SSLForceHost, + SSLTemporaryRedirect: cfg.SSLTemporaryRedirect, + STSIncludeSubdomains: cfg.STSIncludeSubdomains, + STSPreload: cfg.STSPreload, + ContentSecurityPolicy: cfg.ContentSecurityPolicy, + CustomBrowserXssValue: cfg.CustomBrowserXSSValue, + CustomFrameOptionsValue: cfg.CustomFrameOptionsValue, + PublicKey: cfg.PublicKey, + ReferrerPolicy: cfg.ReferrerPolicy, + SSLHost: cfg.SSLHost, + AllowedHosts: cfg.AllowedHosts, + HostsProxyHeaders: cfg.HostsProxyHeaders, + SSLProxyHeaders: cfg.SSLProxyHeaders, + STSSeconds: cfg.STSSeconds, + FeaturePolicy: cfg.FeaturePolicy, + SecureContextKey: contextKey, + } + + return &secureHeader{ + next: next, + secure: secure.New(opt), + cfg: cfg, + } +} + +func (s secureHeader) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + s.secure.HandlerFuncWithNextForRequestOnly(rw, req, func(writer http.ResponseWriter, request *http.Request) { + s.next.ServeHTTP(newResponseModifier(writer, request, s.secure.ModifyResponseHeaders), request) + }) +} diff --git a/pkg/middlewares/headers/secure_test.go b/pkg/middlewares/headers/secure_test.go new file mode 100644 index 000000000..66cc564d9 --- /dev/null +++ b/pkg/middlewares/headers/secure_test.go @@ -0,0 +1,191 @@ +package headers + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/containous/traefik/v2/pkg/config/dynamic" + "github.com/stretchr/testify/assert" +) + +// Middleware tests based on https://github.com/unrolled/secure + +func Test_newSecure_sslForceHost(t *testing.T) { + type expected struct { + statusCode int + location string + } + + testCases := []struct { + desc string + host string + cfg dynamic.Headers + expected + }{ + { + desc: "http should return a 301", + host: "http://powpow.example.com", + cfg: dynamic.Headers{ + SSLRedirect: true, + SSLForceHost: true, + SSLHost: "powpow.example.com", + }, + expected: expected{ + statusCode: http.StatusMovedPermanently, + location: "https://powpow.example.com", + }, + }, + { + desc: "http sub domain should return a 301", + host: "http://www.powpow.example.com", + cfg: dynamic.Headers{ + SSLRedirect: true, + SSLForceHost: true, + SSLHost: "powpow.example.com", + }, + expected: expected{ + statusCode: http.StatusMovedPermanently, + location: "https://powpow.example.com", + }, + }, + { + desc: "https should return a 200", + host: "https://powpow.example.com", + cfg: dynamic.Headers{ + SSLRedirect: true, + SSLForceHost: true, + SSLHost: "powpow.example.com", + }, + expected: expected{statusCode: http.StatusOK}, + }, + { + desc: "https sub domain should return a 301", + host: "https://www.powpow.example.com", + cfg: dynamic.Headers{ + SSLRedirect: true, + SSLForceHost: true, + SSLHost: "powpow.example.com", + }, + expected: expected{ + statusCode: http.StatusMovedPermanently, + location: "https://powpow.example.com", + }, + }, + { + desc: "http without force host and sub domain should return a 301", + host: "http://www.powpow.example.com", + cfg: dynamic.Headers{ + SSLRedirect: true, + SSLForceHost: false, + SSLHost: "powpow.example.com", + }, + expected: expected{ + statusCode: http.StatusMovedPermanently, + location: "https://powpow.example.com", + }, + }, + { + desc: "https without force host and sub domain should return a 301", + host: "https://www.powpow.example.com", + cfg: dynamic.Headers{ + SSLRedirect: true, + SSLForceHost: false, + SSLHost: "powpow.example.com", + }, + expected: expected{statusCode: http.StatusOK}, + }, + } + + next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + _, _ = rw.Write([]byte("OK")) + }) + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + mid := newSecure(next, test.cfg, "mymiddleware") + + req := httptest.NewRequest(http.MethodGet, test.host, nil) + + rw := httptest.NewRecorder() + + mid.ServeHTTP(rw, req) + + assert.Equal(t, test.expected.statusCode, rw.Result().StatusCode) + assert.Equal(t, test.expected.location, rw.Header().Get("Location")) + }) + } +} + +func Test_newSecure_modifyResponse(t *testing.T) { + testCases := []struct { + desc string + cfg dynamic.Headers + expected http.Header + }{ + { + desc: "FeaturePolicy", + cfg: dynamic.Headers{ + FeaturePolicy: "vibrate 'none';", + }, + expected: http.Header{"Feature-Policy": []string{"vibrate 'none';"}}, + }, + { + desc: "STSSeconds", + cfg: dynamic.Headers{ + STSSeconds: 1, + ForceSTSHeader: true, + }, + expected: http.Header{"Strict-Transport-Security": []string{"max-age=1"}}, + }, + { + desc: "STSSeconds and STSPreload", + cfg: dynamic.Headers{ + STSSeconds: 1, + ForceSTSHeader: true, + STSPreload: true, + }, + expected: http.Header{"Strict-Transport-Security": []string{"max-age=1; preload"}}, + }, + { + desc: "CustomFrameOptionsValue", + cfg: dynamic.Headers{ + CustomFrameOptionsValue: "foo", + }, + expected: http.Header{"X-Frame-Options": []string{"foo"}}, + }, + { + desc: "FrameDeny", + cfg: dynamic.Headers{ + FrameDeny: true, + }, + expected: http.Header{"X-Frame-Options": []string{"DENY"}}, + }, + { + desc: "ContentTypeNosniff", + cfg: dynamic.Headers{ + ContentTypeNosniff: true, + }, + expected: http.Header{"X-Content-Type-Options": []string{"nosniff"}}, + }, + } + + emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + secure := newSecure(emptyHandler, test.cfg, "mymiddleware") + + req := httptest.NewRequest(http.MethodGet, "/foo", nil) + + rw := httptest.NewRecorder() + + secure.ServeHTTP(rw, req) + + assert.Equal(t, test.expected, rw.Result().Header) + }) + } +} diff --git a/pkg/responsemodifiers/headers.go b/pkg/responsemodifiers/headers.go deleted file mode 100644 index 403785eea..000000000 --- a/pkg/responsemodifiers/headers.go +++ /dev/null @@ -1,53 +0,0 @@ -package responsemodifiers - -import ( - "net/http" - - "github.com/containous/traefik/v2/pkg/config/dynamic" - "github.com/containous/traefik/v2/pkg/middlewares/headers" - "github.com/unrolled/secure" -) - -func buildHeaders(hdrs *dynamic.Headers) func(*http.Response) error { - opt := secure.Options{ - BrowserXssFilter: hdrs.BrowserXSSFilter, - ContentTypeNosniff: hdrs.ContentTypeNosniff, - ForceSTSHeader: hdrs.ForceSTSHeader, - FrameDeny: hdrs.FrameDeny, - IsDevelopment: hdrs.IsDevelopment, - SSLRedirect: hdrs.SSLRedirect, - SSLForceHost: hdrs.SSLForceHost, - SSLTemporaryRedirect: hdrs.SSLTemporaryRedirect, - STSIncludeSubdomains: hdrs.STSIncludeSubdomains, - STSPreload: hdrs.STSPreload, - ContentSecurityPolicy: hdrs.ContentSecurityPolicy, - CustomBrowserXssValue: hdrs.CustomBrowserXSSValue, - CustomFrameOptionsValue: hdrs.CustomFrameOptionsValue, - PublicKey: hdrs.PublicKey, - ReferrerPolicy: hdrs.ReferrerPolicy, - SSLHost: hdrs.SSLHost, - AllowedHosts: hdrs.AllowedHosts, - HostsProxyHeaders: hdrs.HostsProxyHeaders, - SSLProxyHeaders: hdrs.SSLProxyHeaders, - STSSeconds: hdrs.STSSeconds, - FeaturePolicy: hdrs.FeaturePolicy, - } - - return func(resp *http.Response) error { - if hdrs.HasCustomHeadersDefined() || hdrs.HasCorsHeadersDefined() { - err := headers.NewHeader(nil, *hdrs).PostRequestModifyResponseHeaders(resp) - if err != nil { - return err - } - } - - if hdrs.HasSecureHeadersDefined() { - err := secure.New(opt).ModifyResponseHeaders(resp) - if err != nil { - return err - } - } - - return nil - } -} diff --git a/pkg/responsemodifiers/log.go b/pkg/responsemodifiers/log.go deleted file mode 100644 index 43eac2a1b..000000000 --- a/pkg/responsemodifiers/log.go +++ /dev/null @@ -1,13 +0,0 @@ -package responsemodifiers - -import ( - "context" - - "github.com/containous/traefik/v2/pkg/log" - "github.com/sirupsen/logrus" -) - -// getLogger creates a logger configured with the middleware fields. -func getLogger(ctx context.Context, middleware, middlewareType string) logrus.FieldLogger { - return log.FromContext(ctx).WithField(log.MiddlewareName, middleware).WithField(log.MiddlewareType, middlewareType) -} diff --git a/pkg/responsemodifiers/response_modifier.go b/pkg/responsemodifiers/response_modifier.go deleted file mode 100644 index b6bd351be..000000000 --- a/pkg/responsemodifiers/response_modifier.go +++ /dev/null @@ -1,68 +0,0 @@ -package responsemodifiers - -import ( - "context" - "net/http" - - "github.com/containous/traefik/v2/pkg/config/runtime" - "github.com/containous/traefik/v2/pkg/server/provider" -) - -// NewBuilder creates a builder. -func NewBuilder(configs map[string]*runtime.MiddlewareInfo) *Builder { - return &Builder{configs: configs} -} - -// Builder holds builder configuration. -type Builder struct { - configs map[string]*runtime.MiddlewareInfo -} - -// Build Builds the response modifier. -// It returns nil if there is no modifier to apply. -func (f *Builder) Build(ctx context.Context, names []string) func(*http.Response) error { - var modifiers []func(*http.Response) error - - for _, middleName := range names { - conf, ok := f.configs[middleName] - if !ok { - getLogger(ctx, middleName, "undefined").Debug("Middleware name not found in config (ResponseModifier)") - continue - } - if conf == nil || conf.Middleware == nil { - getLogger(ctx, middleName, "undefined").Error("Invalid Middleware configuration (ResponseModifier)") - continue - } - - if conf.Headers != nil { - getLogger(ctx, middleName, "Headers").Debug("Creating Middleware (ResponseModifier)") - - modifiers = append(modifiers, buildHeaders(conf.Headers)) - } else if conf.Chain != nil { - chainCtx := provider.AddInContext(ctx, middleName) - getLogger(chainCtx, middleName, "Chain").Debug("Creating Middleware (ResponseModifier)") - var qualifiedNames []string - for _, name := range conf.Chain.Middlewares { - qualifiedNames = append(qualifiedNames, provider.GetQualifiedName(chainCtx, name)) - } - - if rm := f.Build(ctx, qualifiedNames); rm != nil { - modifiers = append(modifiers, rm) - } - } - } - - if len(modifiers) > 0 { - return func(resp *http.Response) error { - for i := len(modifiers); i > 0; i-- { - err := modifiers[i-1](resp) - if err != nil { - return err - } - } - return nil - } - } - - return nil -} diff --git a/pkg/responsemodifiers/response_modifier_test.go b/pkg/responsemodifiers/response_modifier_test.go deleted file mode 100644 index 7b3822831..000000000 --- a/pkg/responsemodifiers/response_modifier_test.go +++ /dev/null @@ -1,214 +0,0 @@ -package responsemodifiers - -import ( - "context" - "net/http" - "net/http/httptest" - "testing" - - "github.com/containous/traefik/v2/pkg/config/dynamic" - "github.com/containous/traefik/v2/pkg/config/runtime" - "github.com/containous/traefik/v2/pkg/middlewares/headers" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func stubResponse(_ map[string]*dynamic.Middleware) *http.Response { - return &http.Response{Header: make(http.Header)} -} - -func TestBuilderBuild(t *testing.T) { - testCases := []struct { - desc string - middlewares []string - // buildResponse is needed because secure use a private context key - buildResponse func(map[string]*dynamic.Middleware) *http.Response - conf map[string]*dynamic.Middleware - assertResponse func(*testing.T, *http.Response) - }{ - { - desc: "no configuration", - middlewares: []string{"foo", "bar"}, - buildResponse: stubResponse, - conf: map[string]*dynamic.Middleware{}, - assertResponse: func(t *testing.T, resp *http.Response) {}, - }, - { - desc: "one modifier", - middlewares: []string{"foo", "bar"}, - buildResponse: stubResponse, - conf: map[string]*dynamic.Middleware{ - "foo": { - Headers: &dynamic.Headers{ - CustomResponseHeaders: map[string]string{"X-Foo": "foo"}, - }, - }, - }, - assertResponse: func(t *testing.T, resp *http.Response) { - t.Helper() - - assert.Equal(t, "foo", resp.Header.Get("X-Foo")) - }, - }, - { - desc: "secure: one modifier", - middlewares: []string{"foo", "bar"}, - buildResponse: func(middlewares map[string]*dynamic.Middleware) *http.Response { - ctx := context.Background() - - var request *http.Request - next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - request = req - }) - - headerM := *middlewares["foo"].Headers - handler, err := headers.New(ctx, next, headerM, "secure") - require.NoError(t, err) - - handler.ServeHTTP(httptest.NewRecorder(), - httptest.NewRequest(http.MethodGet, "http://foo.com", nil)) - - return &http.Response{Header: make(http.Header), Request: request} - }, - conf: map[string]*dynamic.Middleware{ - "foo": { - Headers: &dynamic.Headers{ - ReferrerPolicy: "no-referrer", - }, - }, - "bar": { - Headers: &dynamic.Headers{ - CustomResponseHeaders: map[string]string{"X-Bar": "bar"}, - }, - }, - }, - assertResponse: func(t *testing.T, resp *http.Response) { - t.Helper() - - assert.Equal(t, "no-referrer", resp.Header.Get("Referrer-Policy")) - }, - }, - { - desc: "two modifiers", - middlewares: []string{"foo", "bar"}, - buildResponse: stubResponse, - conf: map[string]*dynamic.Middleware{ - "foo": { - Headers: &dynamic.Headers{ - CustomResponseHeaders: map[string]string{"X-Foo": "foo"}, - }, - }, - "bar": { - Headers: &dynamic.Headers{ - CustomResponseHeaders: map[string]string{"X-Bar": "bar"}, - }, - }, - }, - assertResponse: func(t *testing.T, resp *http.Response) { - t.Helper() - - assert.Equal(t, "foo", resp.Header.Get("X-Foo")) - assert.Equal(t, "bar", resp.Header.Get("X-Bar")) - }, - }, - { - desc: "modifier order", - middlewares: []string{"foo", "bar"}, - buildResponse: stubResponse, - conf: map[string]*dynamic.Middleware{ - "foo": { - Headers: &dynamic.Headers{ - CustomResponseHeaders: map[string]string{"X-Foo": "foo"}, - }, - }, - "bar": { - Headers: &dynamic.Headers{ - CustomResponseHeaders: map[string]string{"X-Foo": "bar"}, - }, - }, - }, - assertResponse: func(t *testing.T, resp *http.Response) { - t.Helper() - - assert.Equal(t, "foo", resp.Header.Get("X-Foo")) - }, - }, - { - desc: "chain", - middlewares: []string{"chain"}, - buildResponse: stubResponse, - conf: map[string]*dynamic.Middleware{ - "foo": { - Headers: &dynamic.Headers{ - CustomResponseHeaders: map[string]string{"X-Foo": "foo"}, - }, - }, - "bar": { - Headers: &dynamic.Headers{ - CustomResponseHeaders: map[string]string{"X-Foo": "bar"}, - }, - }, - "chain": { - Chain: &dynamic.Chain{ - Middlewares: []string{"foo", "bar"}, - }, - }, - }, - assertResponse: func(t *testing.T, resp *http.Response) { - t.Helper() - - assert.Equal(t, "foo", resp.Header.Get("X-Foo")) - }, - }, - { - desc: "nil middleware", - middlewares: []string{"foo"}, - buildResponse: stubResponse, - conf: map[string]*dynamic.Middleware{ - "foo": nil, - }, - assertResponse: func(t *testing.T, resp *http.Response) {}, - }, - - { - desc: "chain without headers", - middlewares: []string{"chain"}, - buildResponse: stubResponse, - conf: map[string]*dynamic.Middleware{ - "foo": {IPWhiteList: &dynamic.IPWhiteList{}}, - "chain": { - Chain: &dynamic.Chain{ - Middlewares: []string{"foo"}, - }, - }, - }, - assertResponse: func(t *testing.T, resp *http.Response) {}, - }, - } - - for _, test := range testCases { - test := test - t.Run(test.desc, func(t *testing.T) { - t.Parallel() - - rtConf := runtime.NewConfig(dynamic.Configuration{ - HTTP: &dynamic.HTTPConfiguration{ - Middlewares: test.conf, - }, - }) - builder := NewBuilder(rtConf.Middlewares) - - rm := builder.Build(context.Background(), test.middlewares) - if rm == nil { - return - } - - resp := test.buildResponse(test.conf) - - err := rm(resp) - require.NoError(t, err) - - test.assertResponse(t, resp) - }) - } -} diff --git a/pkg/server/middleware/middlewares.go b/pkg/server/middleware/middlewares.go index c1398d831..2fc0ccc3d 100644 --- a/pkg/server/middleware/middlewares.go +++ b/pkg/server/middleware/middlewares.go @@ -44,7 +44,7 @@ type Builder struct { } type serviceBuilder interface { - BuildHTTP(ctx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error) + BuildHTTP(ctx context.Context, serviceName string) (http.Handler, error) } // NewBuilder creates a new Builder. diff --git a/pkg/server/router/router.go b/pkg/server/router/router.go index 3392e94ff..994e0a595 100644 --- a/pkg/server/router/router.go +++ b/pkg/server/router/router.go @@ -24,12 +24,8 @@ type middlewareBuilder interface { BuildChain(ctx context.Context, names []string) *alice.Chain } -type responseModifierBuilder interface { - Build(ctx context.Context, names []string) func(*http.Response) error -} - type serviceManager interface { - BuildHTTP(rootCtx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error) + BuildHTTP(rootCtx context.Context, serviceName string) (http.Handler, error) LaunchHealthCheck() } @@ -39,22 +35,15 @@ type Manager struct { serviceManager serviceManager middlewaresBuilder middlewareBuilder chainBuilder *middleware.ChainBuilder - modifierBuilder responseModifierBuilder conf *runtime.Configuration } // NewManager Creates a new Manager. -func NewManager(conf *runtime.Configuration, - serviceManager serviceManager, - middlewaresBuilder middlewareBuilder, - modifierBuilder responseModifierBuilder, - chainBuilder *middleware.ChainBuilder, -) *Manager { +func NewManager(conf *runtime.Configuration, serviceManager serviceManager, middlewaresBuilder middlewareBuilder, chainBuilder *middleware.ChainBuilder) *Manager { return &Manager{ routerHandlers: make(map[string]http.Handler), serviceManager: serviceManager, middlewaresBuilder: middlewaresBuilder, - modifierBuilder: modifierBuilder, chainBuilder: chainBuilder, conf: conf, } @@ -176,13 +165,12 @@ func (m *Manager) buildHTTPHandler(ctx context.Context, router *runtime.RouterIn qualifiedNames = append(qualifiedNames, provider.GetQualifiedName(ctx, name)) } router.Middlewares = qualifiedNames - rm := m.modifierBuilder.Build(ctx, qualifiedNames) if router.Service == "" { return nil, errors.New("the service is missing on the router") } - sHandler, err := m.serviceManager.BuildHTTP(ctx, router.Service, rm) + sHandler, err := m.serviceManager.BuildHTTP(ctx, router.Service) if err != nil { return nil, err } diff --git a/pkg/server/router/router_test.go b/pkg/server/router/router_test.go index 113335086..8febc0e67 100644 --- a/pkg/server/router/router_test.go +++ b/pkg/server/router/router_test.go @@ -13,7 +13,6 @@ import ( "github.com/containous/traefik/v2/pkg/config/static" "github.com/containous/traefik/v2/pkg/middlewares/accesslog" "github.com/containous/traefik/v2/pkg/middlewares/requestdecorator" - "github.com/containous/traefik/v2/pkg/responsemodifiers" "github.com/containous/traefik/v2/pkg/server/middleware" "github.com/containous/traefik/v2/pkg/server/service" "github.com/containous/traefik/v2/pkg/testhelpers" @@ -290,10 +289,9 @@ func TestRouterManager_Get(t *testing.T) { serviceManager := service.NewManager(rtConf.Services, http.DefaultTransport, nil, nil) middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager) - responseModifierFactory := responsemodifiers.NewBuilder(rtConf.Middlewares) chainBuilder := middleware.NewChainBuilder(static.Configuration{}, nil, nil) - routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, responseModifierFactory, chainBuilder) + routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, chainBuilder) handlers := routerManager.BuildHandlers(context.Background(), test.entryPoints, false) @@ -395,10 +393,9 @@ func TestAccessLog(t *testing.T) { serviceManager := service.NewManager(rtConf.Services, http.DefaultTransport, nil, nil) middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager) - responseModifierFactory := responsemodifiers.NewBuilder(rtConf.Middlewares) chainBuilder := middleware.NewChainBuilder(static.Configuration{}, nil, nil) - routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, responseModifierFactory, chainBuilder) + routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, chainBuilder) handlers := routerManager.BuildHandlers(context.Background(), test.entryPoints, false) @@ -683,10 +680,9 @@ func TestRuntimeConfiguration(t *testing.T) { serviceManager := service.NewManager(rtConf.Services, http.DefaultTransport, nil, nil) middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager) - responseModifierFactory := responsemodifiers.NewBuilder(map[string]*runtime.MiddlewareInfo{}) chainBuilder := middleware.NewChainBuilder(static.Configuration{}, nil, nil) - routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, responseModifierFactory, chainBuilder) + routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, chainBuilder) _ = routerManager.BuildHandlers(context.Background(), entryPoints, false) @@ -765,10 +761,9 @@ func TestProviderOnMiddlewares(t *testing.T) { serviceManager := service.NewManager(rtConf.Services, http.DefaultTransport, nil, nil) middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager) - responseModifierFactory := responsemodifiers.NewBuilder(map[string]*runtime.MiddlewareInfo{}) chainBuilder := middleware.NewChainBuilder(staticCfg, nil, nil) - routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, responseModifierFactory, chainBuilder) + routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, chainBuilder) _ = routerManager.BuildHandlers(context.Background(), entryPoints, false) @@ -826,10 +821,9 @@ func BenchmarkRouterServe(b *testing.B) { serviceManager := service.NewManager(rtConf.Services, &staticTransport{res}, nil, nil) middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager) - responseModifierFactory := responsemodifiers.NewBuilder(rtConf.Middlewares) chainBuilder := middleware.NewChainBuilder(static.Configuration{}, nil, nil) - routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, responseModifierFactory, chainBuilder) + routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, chainBuilder) handlers := routerManager.BuildHandlers(context.Background(), entryPoints, false) @@ -871,7 +865,7 @@ func BenchmarkService(b *testing.B) { w := httptest.NewRecorder() req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil) - handler, _ := serviceManager.BuildHTTP(context.Background(), "foo-service", nil) + handler, _ := serviceManager.BuildHTTP(context.Background(), "foo-service") b.ReportAllocs() for i := 0; i < b.N; i++ { handler.ServeHTTP(w, req) diff --git a/pkg/server/routerfactory.go b/pkg/server/routerfactory.go index 1a2dfdc1c..00d17b46c 100644 --- a/pkg/server/routerfactory.go +++ b/pkg/server/routerfactory.go @@ -7,7 +7,6 @@ import ( "github.com/containous/traefik/v2/pkg/config/runtime" "github.com/containous/traefik/v2/pkg/config/static" "github.com/containous/traefik/v2/pkg/log" - "github.com/containous/traefik/v2/pkg/responsemodifiers" "github.com/containous/traefik/v2/pkg/server/middleware" "github.com/containous/traefik/v2/pkg/server/router" routertcp "github.com/containous/traefik/v2/pkg/server/router/tcp" @@ -67,9 +66,8 @@ func (f *RouterFactory) CreateRouters(conf dynamic.Configuration) (map[string]*t serviceManager := f.managerFactory.Build(rtConf) middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager) - responseModifierFactory := responsemodifiers.NewBuilder(rtConf.Middlewares) - routerManager := router.NewManager(rtConf, serviceManager, middlewaresBuilder, responseModifierFactory, f.chainBuilder) + routerManager := router.NewManager(rtConf, serviceManager, middlewaresBuilder, f.chainBuilder) handlersNonTLS := routerManager.BuildHandlers(ctx, f.entryPointsTCP, false) handlersTLS := routerManager.BuildHandlers(ctx, f.entryPointsTCP, true) diff --git a/pkg/server/service/internalhandler.go b/pkg/server/service/internalhandler.go index 6ff7e0a1e..f50e37065 100644 --- a/pkg/server/service/internalhandler.go +++ b/pkg/server/service/internalhandler.go @@ -8,11 +8,10 @@ import ( "strings" "github.com/containous/traefik/v2/pkg/config/runtime" - "github.com/containous/traefik/v2/pkg/log" ) type serviceManager interface { - BuildHTTP(rootCtx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error) + BuildHTTP(rootCtx context.Context, serviceName string) (http.Handler, error) LaunchHealthCheck() } @@ -43,87 +42,18 @@ func NewInternalHandlers(api func(configuration *runtime.Configuration) http.Han } } -type responseModifier struct { - r *http.Request - w http.ResponseWriter - - headersSent bool // whether headers have already been sent - code int // status code, must default to 200 - - modifier func(*http.Response) error // can be nil - modified bool // whether modifier has already been called for the current request - modifierErr error // returned by modifier call -} - -// modifier can be nil. -func newResponseModifier(w http.ResponseWriter, r *http.Request, modifier func(*http.Response) error) *responseModifier { - return &responseModifier{ - r: r, - w: w, - modifier: modifier, - code: http.StatusOK, - } -} - -func (w *responseModifier) WriteHeader(code int) { - if w.headersSent { - return - } - defer func() { - w.code = code - w.headersSent = true - }() - - if w.modifier == nil || w.modified { - w.w.WriteHeader(code) - return - } - - resp := http.Response{ - Header: w.w.Header(), - Request: w.r, - } - - if err := w.modifier(&resp); err != nil { - w.modifierErr = err - // we are propagating when we are called in Write, but we're logging anyway, - // because we could be called from another place which does not take care of - // checking w.modifierErr. - log.Errorf("Error when applying response modifier: %v", err) - w.w.WriteHeader(http.StatusInternalServerError) - return - } - - w.modified = true - w.w.WriteHeader(code) -} - -func (w *responseModifier) Header() http.Header { - return w.w.Header() -} - -func (w *responseModifier) Write(b []byte) (int, error) { - w.WriteHeader(w.code) - if w.modifierErr != nil { - return 0, w.modifierErr - } - - return w.w.Write(b) -} - // BuildHTTP builds an HTTP handler. -func (m *InternalHandlers) BuildHTTP(rootCtx context.Context, serviceName string, respModifier func(*http.Response) error) (http.Handler, error) { +func (m *InternalHandlers) BuildHTTP(rootCtx context.Context, serviceName string) (http.Handler, error) { if !strings.HasSuffix(serviceName, "@internal") { - return m.serviceManager.BuildHTTP(rootCtx, serviceName, respModifier) + return m.serviceManager.BuildHTTP(rootCtx, serviceName) } internalHandler, err := m.get(serviceName) if err != nil { return nil, err } - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - internalHandler.ServeHTTP(newResponseModifier(w, r, respModifier), r) - }), nil + + return internalHandler, nil } func (m *InternalHandlers) get(serviceName string) (http.Handler, error) { diff --git a/pkg/server/service/proxy.go b/pkg/server/service/proxy.go index 21a67329f..5e8c85407 100644 --- a/pkg/server/service/proxy.go +++ b/pkg/server/service/proxy.go @@ -21,7 +21,7 @@ const StatusClientClosedRequest = 499 // StatusClientClosedRequestText non-standard HTTP status for client disconnection. const StatusClientClosedRequestText = "Client Closed Request" -func buildProxy(passHostHeader *bool, responseForwarding *dynamic.ResponseForwarding, defaultRoundTripper http.RoundTripper, bufferPool httputil.BufferPool, responseModifier func(*http.Response) error) (http.Handler, error) { +func buildProxy(passHostHeader *bool, responseForwarding *dynamic.ResponseForwarding, defaultRoundTripper http.RoundTripper, bufferPool httputil.BufferPool) (http.Handler, error) { var flushInterval types.Duration if responseForwarding != nil { err := flushInterval.Set(responseForwarding.FlushInterval) @@ -76,10 +76,9 @@ func buildProxy(passHostHeader *bool, responseForwarding *dynamic.ResponseForwar delete(outReq.Header, "Sec-Websocket-Protocol") delete(outReq.Header, "Sec-Websocket-Version") }, - Transport: defaultRoundTripper, - FlushInterval: time.Duration(flushInterval), - ModifyResponse: responseModifier, - BufferPool: bufferPool, + Transport: defaultRoundTripper, + FlushInterval: time.Duration(flushInterval), + BufferPool: bufferPool, ErrorHandler: func(w http.ResponseWriter, request *http.Request, err error) { statusCode := http.StatusInternalServerError diff --git a/pkg/server/service/proxy_test.go b/pkg/server/service/proxy_test.go index 38966b509..5318a0a67 100644 --- a/pkg/server/service/proxy_test.go +++ b/pkg/server/service/proxy_test.go @@ -28,7 +28,7 @@ func BenchmarkProxy(b *testing.B) { req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil) pool := newBufferPool() - handler, _ := buildProxy(Bool(false), nil, &staticTransport{res}, pool, nil) + handler, _ := buildProxy(Bool(false), nil, &staticTransport{res}, pool) b.ReportAllocs() for i := 0; i < b.N; i++ { diff --git a/pkg/server/service/proxy_websocket_test.go b/pkg/server/service/proxy_websocket_test.go index d309e5379..fe169e1c5 100644 --- a/pkg/server/service/proxy_websocket_test.go +++ b/pkg/server/service/proxy_websocket_test.go @@ -20,7 +20,7 @@ import ( func Bool(v bool) *bool { return &v } func TestWebSocketTCPClose(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil, nil) + f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) require.NoError(t, err) errChan := make(chan error, 1) @@ -59,7 +59,7 @@ func TestWebSocketTCPClose(t *testing.T) { } func TestWebSocketPingPong(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil, nil) + f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) require.NoError(t, err) @@ -125,7 +125,7 @@ func TestWebSocketPingPong(t *testing.T) { } func TestWebSocketEcho(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil, nil) + f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) require.NoError(t, err) mux := http.NewServeMux() @@ -191,7 +191,7 @@ func TestWebSocketPassHost(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - f, err := buildProxy(Bool(test.passHost), nil, http.DefaultTransport, nil, nil) + f, err := buildProxy(Bool(test.passHost), nil, http.DefaultTransport, nil) require.NoError(t, err) @@ -250,7 +250,7 @@ func TestWebSocketPassHost(t *testing.T) { } func TestWebSocketServerWithoutCheckOrigin(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil, nil) + f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) require.NoError(t, err) upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(r *http.Request) bool { @@ -291,7 +291,7 @@ func TestWebSocketServerWithoutCheckOrigin(t *testing.T) { } func TestWebSocketRequestWithOrigin(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil, nil) + f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) require.NoError(t, err) upgrader := gorillawebsocket.Upgrader{} @@ -337,7 +337,7 @@ func TestWebSocketRequestWithOrigin(t *testing.T) { } func TestWebSocketRequestWithQueryParams(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil, nil) + f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) require.NoError(t, err) upgrader := gorillawebsocket.Upgrader{} @@ -377,7 +377,7 @@ func TestWebSocketRequestWithQueryParams(t *testing.T) { } func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil, nil) + f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) require.NoError(t, err) mux := http.NewServeMux() @@ -409,7 +409,7 @@ func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) { } func TestWebSocketRequestWithEncodedChar(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil, nil) + f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) require.NoError(t, err) upgrader := gorillawebsocket.Upgrader{} @@ -449,7 +449,7 @@ func TestWebSocketRequestWithEncodedChar(t *testing.T) { } func TestWebSocketUpgradeFailed(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil, nil) + f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) require.NoError(t, err) mux := http.NewServeMux() @@ -499,7 +499,7 @@ func TestWebSocketUpgradeFailed(t *testing.T) { } func TestForwardsWebsocketTraffic(t *testing.T) { - f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil, nil) + f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) require.NoError(t, err) mux := http.NewServeMux() @@ -555,7 +555,7 @@ func TestWebSocketTransferTLSConfig(t *testing.T) { srv := createTLSWebsocketServer() defer srv.Close() - forwarderWithoutTLSConfig, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil, nil) + forwarderWithoutTLSConfig, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) require.NoError(t, err) proxyWithoutTLSConfig := createProxyWithForwarder(t, forwarderWithoutTLSConfig, srv.URL) @@ -574,7 +574,7 @@ func TestWebSocketTransferTLSConfig(t *testing.T) { transport := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } - forwarderWithTLSConfig, err := buildProxy(Bool(true), nil, transport, nil, nil) + forwarderWithTLSConfig, err := buildProxy(Bool(true), nil, transport, nil) require.NoError(t, err) proxyWithTLSConfig := createProxyWithForwarder(t, forwarderWithTLSConfig, srv.URL) @@ -593,7 +593,7 @@ func TestWebSocketTransferTLSConfig(t *testing.T) { http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - forwarderWithTLSConfigFromDefaultTransport, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil, nil) + forwarderWithTLSConfigFromDefaultTransport, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil) require.NoError(t, err) proxyWithTLSConfigFromDefaultTransport := createProxyWithForwarder(t, forwarderWithTLSConfigFromDefaultTransport, srv.URL) diff --git a/pkg/server/service/service.go b/pkg/server/service/service.go index 2c641cb0e..37f02c4db 100644 --- a/pkg/server/service/service.go +++ b/pkg/server/service/service.go @@ -62,7 +62,7 @@ type Manager struct { } // BuildHTTP Creates a http.Handler for a service configuration. -func (m *Manager) BuildHTTP(rootCtx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error) { +func (m *Manager) BuildHTTP(rootCtx context.Context, serviceName string) (http.Handler, error) { ctx := log.With(rootCtx, log.Str(log.ServiceName, serviceName)) serviceName = provider.GetQualifiedName(ctx, serviceName) @@ -91,21 +91,21 @@ func (m *Manager) BuildHTTP(rootCtx context.Context, serviceName string, respons switch { case conf.LoadBalancer != nil: var err error - lb, err = m.getLoadBalancerServiceHandler(ctx, serviceName, conf.LoadBalancer, responseModifier) + lb, err = m.getLoadBalancerServiceHandler(ctx, serviceName, conf.LoadBalancer) if err != nil { conf.AddError(err, true) return nil, err } case conf.Weighted != nil: var err error - lb, err = m.getWRRServiceHandler(ctx, serviceName, conf.Weighted, responseModifier) + lb, err = m.getWRRServiceHandler(ctx, serviceName, conf.Weighted) if err != nil { conf.AddError(err, true) return nil, err } case conf.Mirroring != nil: var err error - lb, err = m.getMirrorServiceHandler(ctx, conf.Mirroring, responseModifier) + lb, err = m.getMirrorServiceHandler(ctx, conf.Mirroring) if err != nil { conf.AddError(err, true) return nil, err @@ -119,8 +119,8 @@ func (m *Manager) BuildHTTP(rootCtx context.Context, serviceName string, respons return lb, nil } -func (m *Manager) getMirrorServiceHandler(ctx context.Context, config *dynamic.Mirroring, responseModifier func(*http.Response) error) (http.Handler, error) { - serviceHandler, err := m.BuildHTTP(ctx, config.Service, responseModifier) +func (m *Manager) getMirrorServiceHandler(ctx context.Context, config *dynamic.Mirroring) (http.Handler, error) { + serviceHandler, err := m.BuildHTTP(ctx, config.Service) if err != nil { return nil, err } @@ -131,7 +131,7 @@ func (m *Manager) getMirrorServiceHandler(ctx context.Context, config *dynamic.M } handler := mirror.New(serviceHandler, m.routinePool, maxBodySize) for _, mirrorConfig := range config.Mirrors { - mirrorHandler, err := m.BuildHTTP(ctx, mirrorConfig.Name, responseModifier) + mirrorHandler, err := m.BuildHTTP(ctx, mirrorConfig.Name) if err != nil { return nil, err } @@ -144,7 +144,7 @@ func (m *Manager) getMirrorServiceHandler(ctx context.Context, config *dynamic.M return handler, nil } -func (m *Manager) getWRRServiceHandler(ctx context.Context, serviceName string, config *dynamic.WeightedRoundRobin, responseModifier func(*http.Response) error) (http.Handler, error) { +func (m *Manager) getWRRServiceHandler(ctx context.Context, serviceName string, config *dynamic.WeightedRoundRobin) (http.Handler, error) { // TODO Handle accesslog and metrics with multiple service name if config.Sticky != nil && config.Sticky.Cookie != nil { config.Sticky.Cookie.Name = cookie.GetName(config.Sticky.Cookie.Name, serviceName) @@ -152,7 +152,7 @@ func (m *Manager) getWRRServiceHandler(ctx context.Context, serviceName string, balancer := wrr.New(config.Sticky) for _, service := range config.Services { - serviceHandler, err := m.BuildHTTP(ctx, service.Name, responseModifier) + serviceHandler, err := m.BuildHTTP(ctx, service.Name) if err != nil { return nil, err } @@ -162,18 +162,13 @@ func (m *Manager) getWRRServiceHandler(ctx context.Context, serviceName string, return balancer, nil } -func (m *Manager) getLoadBalancerServiceHandler( - ctx context.Context, - serviceName string, - service *dynamic.ServersLoadBalancer, - responseModifier func(*http.Response) error, -) (http.Handler, error) { +func (m *Manager) getLoadBalancerServiceHandler(ctx context.Context, serviceName string, service *dynamic.ServersLoadBalancer) (http.Handler, error) { if service.PassHostHeader == nil { defaultPassHostHeader := true service.PassHostHeader = &defaultPassHostHeader } - fwd, err := buildProxy(service.PassHostHeader, service.ResponseForwarding, m.defaultRoundTripper, m.bufferPool, responseModifier) + fwd, err := buildProxy(service.PassHostHeader, service.ResponseForwarding, m.defaultRoundTripper, m.bufferPool) if err != nil { return nil, err } diff --git a/pkg/server/service/service_test.go b/pkg/server/service/service_test.go index 522f5308b..23b5f6b35 100644 --- a/pkg/server/service/service_test.go +++ b/pkg/server/service/service_test.go @@ -259,7 +259,7 @@ func TestGetLoadBalancerServiceHandler(t *testing.T) { for _, test := range testCases { test := test t.Run(test.desc, func(t *testing.T) { - handler, err := sm.getLoadBalancerServiceHandler(context.Background(), test.serviceName, test.service, test.responseModifier) + handler, err := sm.getLoadBalancerServiceHandler(context.Background(), test.serviceName, test.service) assert.NoError(t, err) assert.NotNil(t, handler) @@ -339,7 +339,7 @@ func TestManager_Build(t *testing.T) { ctx = provider.AddInContext(ctx, "foobar@"+test.providerName) } - _, err := manager.BuildHTTP(ctx, test.serviceName, nil) + _, err := manager.BuildHTTP(ctx, test.serviceName) require.NoError(t, err) }) } @@ -357,7 +357,7 @@ func TestMultipleTypeOnBuildHTTP(t *testing.T) { manager := NewManager(services, http.DefaultTransport, nil, nil) - _, err := manager.BuildHTTP(context.Background(), "test@file", nil) + _, err := manager.BuildHTTP(context.Background(), "test@file") assert.Error(t, err, "cannot create service: multi-types service not supported, consider declaring two different pieces of service instead") }