From f6181ef3e2e7ae65f04fc3c600cdd198c28f0fba Mon Sep 17 00:00:00 2001 From: Michael Date: Thu, 23 Nov 2017 17:40:03 +0100 Subject: [PATCH] Fix custom headers replacement --- docs/basics.md | 20 +++++++++- middlewares/headers.go | 53 +++++++++++--------------- middlewares/headers_test.go | 75 ++++++++++++++++++++++++++++++++----- server/server.go | 11 +++++- 4 files changed, 115 insertions(+), 44 deletions(-) diff --git a/docs/basics.md b/docs/basics.md index 1c9dde913..cddabeb78 100644 --- a/docs/basics.md +++ b/docs/basics.md @@ -261,6 +261,11 @@ Here, `frontend1` will be matched before `frontend2` (`10 > 5`). Custom headers can be configured through the frontends, to add headers to either requests or responses that match the frontend's rules. This allows for setting headers such as `X-Script-Name` to be added to the request, or custom headers to be added to the response. +!!! warning + If the custom header name is the same as one header name of the request or response, it will be replaced. + +In this example, all matches to the path `/cheese` will have the `X-Script-Name` header added to the proxied request, and the `X-Custom-Response-Header` added to the response. + ```toml [frontends] [frontends.frontend1] @@ -273,7 +278,20 @@ This allows for setting headers such as `X-Script-Name` to be added to the reque rule = "PathPrefixStrip:/cheese" ``` -In this example, all matches to the path `/cheese` will have the `X-Script-Name` header added to the proxied request, and the `X-Custom-Response-Header` added to the response. +In this second example, all matches to the path `/cheese` will have the `X-Script-Name` header added to the proxied request, the `X-Custom-Request-Header` removed to the request and the `X-Custom-Response-Header` removed to the response. + +```toml +[frontends] + [frontends.frontend1] + backend = "backend1" + [frontends.frontend1.headers.customresponseheaders] + X-Custom-Response-Header = "" + [frontends.frontend1.headers.customrequestheaders] + X-Script-Name = "test" + X-Custom-Request-Header = "" + [frontends.frontend1.routes.test_1] + rule = "PathPrefixStrip:/cheese" +``` #### Security headers diff --git a/middlewares/headers.go b/middlewares/headers.go index de4a251fb..6ddef1ca2 100644 --- a/middlewares/headers.go +++ b/middlewares/headers.go @@ -35,46 +35,35 @@ func NewHeaderFromStruct(headers types.Headers) *HeaderStruct { } } -// NewHeader constructs a new header instance with supplied options. -func NewHeader(options ...HeaderOptions) *HeaderStruct { - var o HeaderOptions - if len(options) == 0 { - o = HeaderOptions{} - } else { - o = options[0] - } - - return &HeaderStruct{ - opt: o, - } -} - -// Handler implements the http.HandlerFunc for integration with the standard net/http lib. -func (s *HeaderStruct) Handler(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Let headers process the request. - s.Process(w, r) - h.ServeHTTP(w, r) - }) -} - func (s *HeaderStruct) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { - s.Process(w, r) + s.ModifyRequestHeaders(r) // If there is a next, call it. if next != nil { next(w, r) } } -// Process runs the actual checks and returns an error if the middleware chain should stop. -func (s *HeaderStruct) Process(w http.ResponseWriter, r *http.Request) { +// ModifyRequestHeaders set or delete request headers +func (s *HeaderStruct) ModifyRequestHeaders(r *http.Request) { // Loop through Custom request headers for header, value := range s.opt.CustomRequestHeaders { - r.Header.Set(header, value) - } - - // Loop through Custom response headers - for header, value := range s.opt.CustomResponseHeaders { - w.Header().Add(header, value) + if value == "" { + r.Header.Del(header) + } else { + r.Header.Set(header, value) + } } } + +// ModifyResponseHeaders set or delete response headers +func (s *HeaderStruct) ModifyResponseHeaders(res *http.Response) error { + // Loop through Custom response headers + for header, value := range s.opt.CustomResponseHeaders { + if value == "" { + res.Header.Del(header) + } else { + res.Header.Set(header, value) + } + } + return nil +} diff --git a/middlewares/headers_test.go b/middlewares/headers_test.go index a78a0c2ca..011859837 100644 --- a/middlewares/headers_test.go +++ b/middlewares/headers_test.go @@ -1,6 +1,6 @@ package middlewares -//Middleware tests based on https://github.com/unrolled/secure +// Middleware tests based on https://github.com/unrolled/secure import ( "net/http" @@ -15,36 +15,66 @@ var myHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("bar")) }) +// newHeader constructs a new header instance with supplied options. +func newHeader(options ...HeaderOptions) *HeaderStruct { + var o HeaderOptions + if len(options) == 0 { + o = HeaderOptions{} + } else { + o = options[0] + } + + return &HeaderStruct{ + opt: o, + } +} + func TestNoConfig(t *testing.T) { - s := NewHeader() + header := newHeader() res := httptest.NewRecorder() req := testhelpers.MustNewRequest(http.MethodGet, "http://example.com/foo", nil) - s.Handler(myHandler).ServeHTTP(res, req) + header.ServeHTTP(res, req, myHandler) assert.Equal(t, http.StatusOK, res.Code, "Status not OK") assert.Equal(t, "bar", res.Body.String(), "Body not the expected") } -func TestCustomResponseHeader(t *testing.T) { - s := NewHeader(HeaderOptions{ +func TestModifyResponseHeaders(t *testing.T) { + header := newHeader(HeaderOptions{ CustomResponseHeaders: map[string]string{ "X-Custom-Response-Header": "test_response", }, }) res := httptest.NewRecorder() - req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil) + res.HeaderMap.Add("X-Custom-Response-Header", "test_response") - s.Handler(myHandler).ServeHTTP(res, req) + header.ModifyResponseHeaders(res.Result()) assert.Equal(t, http.StatusOK, res.Code, "Status not OK") assert.Equal(t, "test_response", res.Header().Get("X-Custom-Response-Header"), "Did not get expected header") + + res = httptest.NewRecorder() + res.HeaderMap.Add("X-Custom-Response-Header", "") + + header.ModifyResponseHeaders(res.Result()) + + assert.Equal(t, http.StatusOK, res.Code, "Status not OK") + assert.Equal(t, "", res.Header().Get("X-Custom-Response-Header"), "Did not get expected header") + + res = httptest.NewRecorder() + res.HeaderMap.Add("X-Custom-Response-Header", "test_override") + + header.ModifyResponseHeaders(res.Result()) + + assert.Equal(t, http.StatusOK, res.Code, "Status not OK") + assert.Equal(t, "test_override", res.Header().Get("X-Custom-Response-Header"), "Did not get expected header") } func TestCustomRequestHeader(t *testing.T) { - s := NewHeader(HeaderOptions{ + header := newHeader(HeaderOptions{ CustomRequestHeaders: map[string]string{ "X-Custom-Request-Header": "test_request", }, @@ -53,8 +83,35 @@ func TestCustomRequestHeader(t *testing.T) { res := httptest.NewRecorder() req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil) - s.Handler(myHandler).ServeHTTP(res, req) + header.ServeHTTP(res, req, nil) assert.Equal(t, http.StatusOK, res.Code, "Status not OK") assert.Equal(t, "test_request", req.Header.Get("X-Custom-Request-Header"), "Did not get expected header") } + +func TestCustomRequestHeaderEmptyValue(t *testing.T) { + header := newHeader(HeaderOptions{ + CustomRequestHeaders: map[string]string{ + "X-Custom-Request-Header": "test_request", + }, + }) + + res := httptest.NewRecorder() + req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil) + + header.ServeHTTP(res, req, nil) + + assert.Equal(t, http.StatusOK, res.Code, "Status not OK") + assert.Equal(t, "test_request", req.Header.Get("X-Custom-Request-Header"), "Did not get expected header") + + header = newHeader(HeaderOptions{ + CustomRequestHeaders: map[string]string{ + "X-Custom-Request-Header": "", + }, + }) + + header.ServeHTTP(res, req, nil) + + assert.Equal(t, http.StatusOK, res.Code, "Status not OK") + assert.Equal(t, "", req.Header.Get("X-Custom-Request-Header"), "This header is not expected") +} diff --git a/server/server.go b/server/server.go index e35dd205f..58914e710 100644 --- a/server/server.go +++ b/server/server.go @@ -960,12 +960,20 @@ func (server *Server) loadConfig(configurations types.Configurations, globalConf continue frontend } + var headerMiddleware *middlewares.HeaderStruct + var responseModifier func(res *http.Response) error + if frontend.Headers.HasCustomHeadersDefined() { + headerMiddleware = middlewares.NewHeaderFromStruct(frontend.Headers) + responseModifier = headerMiddleware.ModifyResponseHeaders + } + fwd, err := forward.New( forward.Stream(true), forward.PassHostHeader(frontend.PassHostHeader), forward.RoundTripper(roundTripper), forward.ErrorHandler(errorHandler), forward.Rewriter(rewriter), + forward.ResponseModifier(responseModifier), ) if err != nil { @@ -1140,8 +1148,7 @@ func (server *Server) loadConfig(configurations types.Configurations, globalConf } } - if frontend.Headers.HasCustomHeadersDefined() { - headerMiddleware := middlewares.NewHeaderFromStruct(frontend.Headers) + if headerMiddleware != nil { log.Debugf("Adding header middleware for frontend %s", frontendName) n.Use(headerMiddleware) }