From 09224e4b04195d4bb5cbea868b768c84c22a92c6 Mon Sep 17 00:00:00 2001 From: Ludovic Fernandez Date: Wed, 18 Mar 2020 00:54:04 +0100 Subject: [PATCH] fix: custom Host header. --- pkg/middlewares/headers/headers.go | 9 +++-- pkg/middlewares/headers/headers_test.go | 44 +++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/pkg/middlewares/headers/headers.go b/pkg/middlewares/headers/headers.go index 4868c8cfd..43faa6dff 100644 --- a/pkg/middlewares/headers/headers.go +++ b/pkg/middlewares/headers/headers.go @@ -165,9 +165,14 @@ func (s *Header) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (s *Header) modifyCustomRequestHeaders(req *http.Request) { // Loop through Custom request headers for header, value := range s.headers.CustomRequestHeaders { - if value == "" { + switch { + case value == "": req.Header.Del(header) - } else { + + case strings.EqualFold(header, "Host"): + req.Host = value + + default: req.Header.Set(header, value) } } diff --git a/pkg/middlewares/headers/headers_test.go b/pkg/middlewares/headers/headers_test.go index f9db565ce..b5294aa59 100644 --- a/pkg/middlewares/headers/headers_test.go +++ b/pkg/middlewares/headers/headers_test.go @@ -33,6 +33,50 @@ func TestCustomRequestHeader(t *testing.T) { assert.Equal(t, "test_request", req.Header.Get("X-Custom-Request-Header")) } +func TestCustomRequestHeader_Host(t *testing.T) { + emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + + testCases := []struct { + desc string + customHeaders map[string]string + expectedHost string + expectedURLHost string + }{ + { + desc: "standard Host header", + customHeaders: map[string]string{}, + expectedHost: "example.org", + expectedURLHost: "example.org", + }, + { + desc: "custom Host header", + customHeaders: map[string]string{ + "Host": "example.com", + }, + expectedHost: "example.com", + expectedURLHost: "example.org", + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + header := NewHeader(emptyHandler, dynamic.Headers{ + CustomRequestHeaders: test.customHeaders, + }) + + res := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "http://example.org/foo", nil) + require.NoError(t, err) + + header.ServeHTTP(res, req) + + assert.Equal(t, http.StatusOK, res.Code) + assert.Equal(t, test.expectedHost, req.Host) + assert.Equal(t, test.expectedURLHost, req.URL.Host) + }) + } +} + func TestCustomRequestHeaderEmptyValue(t *testing.T) { emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})