diff --git a/middlewares/auth/forward.go b/middlewares/auth/forward.go index 86c365e8c..5292c282c 100644 --- a/middlewares/auth/forward.go +++ b/middlewares/auth/forward.go @@ -105,6 +105,7 @@ func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next func writeHeader(req *http.Request, forwardReq *http.Request, trustForwardHeader bool) { utils.CopyHeaders(forwardReq.Header, req.Header) + utils.RemoveHeaders(forwardReq.Header, forward.HopHeaders...) if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { if trustForwardHeader { diff --git a/middlewares/auth/forward_test.go b/middlewares/auth/forward_test.go index bf0ada6fb..a0364a030 100644 --- a/middlewares/auth/forward_test.go +++ b/middlewares/auth/forward_test.go @@ -231,11 +231,12 @@ func TestForwardAuthFailResponseHeaders(t *testing.T) { func Test_writeHeader(t *testing.T) { testCases := []struct { - name string - headers map[string]string - trustForwardHeader bool - emptyHost bool - expectedHeaders map[string]string + name string + headers map[string]string + trustForwardHeader bool + emptyHost bool + expectedHeaders map[string]string + checkForUnexpectedHeaders bool }{ { name: "trust Forward Header", @@ -334,6 +335,29 @@ func Test_writeHeader(t *testing.T) { "X-Forwarded-Method": "GET", }, }, + { + name: "remove hop-by-hop headers", + headers: map[string]string{ + forward.Connection: "Connection", + forward.KeepAlive: "KeepAlive", + forward.ProxyAuthenticate: "ProxyAuthenticate", + forward.ProxyAuthorization: "ProxyAuthorization", + forward.Te: "Te", + forward.Trailers: "Trailers", + forward.TransferEncoding: "TransferEncoding", + forward.Upgrade: "Upgrade", + "X-CustomHeader": "CustomHeader", + }, + trustForwardHeader: false, + expectedHeaders: map[string]string{ + "X-CustomHeader": "CustomHeader", + "X-Forwarded-Proto": "http", + "X-Forwarded-Host": "foo.bar", + "X-Forwarded-Uri": "/path?q=1", + "X-Forwarded-Method": "GET", + }, + checkForUnexpectedHeaders: true, + }, } for _, test := range testCases { @@ -352,8 +376,16 @@ func Test_writeHeader(t *testing.T) { writeHeader(req, forwardReq, test.trustForwardHeader) - for key, value := range test.expectedHeaders { - assert.Equal(t, value, forwardReq.Header.Get(key)) + actualHeaders := forwardReq.Header + expectedHeaders := test.expectedHeaders + for key, value := range expectedHeaders { + assert.Equal(t, value, actualHeaders.Get(key)) + actualHeaders.Del(key) + } + if test.checkForUnexpectedHeaders { + for key := range actualHeaders { + assert.Fail(t, "Unexpected header found", key) + } } }) }