Do not copy hop-by-hop headers to forward auth request

This commit is contained in:
stffabi 2018-09-24 10:42:03 +02:00 committed by Traefiker Bot
parent 1f1ecb15f6
commit 29473ef356
2 changed files with 40 additions and 7 deletions

View file

@ -105,6 +105,7 @@ func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next
func writeHeader(req *http.Request, forwardReq *http.Request, trustForwardHeader bool) {
utils.CopyHeaders(forwardReq.Header, req.Header)
utils.RemoveHeaders(forwardReq.Header, forward.HopHeaders...)
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
if trustForwardHeader {

View file

@ -231,11 +231,12 @@ func TestForwardAuthFailResponseHeaders(t *testing.T) {
func Test_writeHeader(t *testing.T) {
testCases := []struct {
name string
headers map[string]string
trustForwardHeader bool
emptyHost bool
expectedHeaders map[string]string
name string
headers map[string]string
trustForwardHeader bool
emptyHost bool
expectedHeaders map[string]string
checkForUnexpectedHeaders bool
}{
{
name: "trust Forward Header",
@ -334,6 +335,29 @@ func Test_writeHeader(t *testing.T) {
"X-Forwarded-Method": "GET",
},
},
{
name: "remove hop-by-hop headers",
headers: map[string]string{
forward.Connection: "Connection",
forward.KeepAlive: "KeepAlive",
forward.ProxyAuthenticate: "ProxyAuthenticate",
forward.ProxyAuthorization: "ProxyAuthorization",
forward.Te: "Te",
forward.Trailers: "Trailers",
forward.TransferEncoding: "TransferEncoding",
forward.Upgrade: "Upgrade",
"X-CustomHeader": "CustomHeader",
},
trustForwardHeader: false,
expectedHeaders: map[string]string{
"X-CustomHeader": "CustomHeader",
"X-Forwarded-Proto": "http",
"X-Forwarded-Host": "foo.bar",
"X-Forwarded-Uri": "/path?q=1",
"X-Forwarded-Method": "GET",
},
checkForUnexpectedHeaders: true,
},
}
for _, test := range testCases {
@ -352,8 +376,16 @@ func Test_writeHeader(t *testing.T) {
writeHeader(req, forwardReq, test.trustForwardHeader)
for key, value := range test.expectedHeaders {
assert.Equal(t, value, forwardReq.Header.Get(key))
actualHeaders := forwardReq.Header
expectedHeaders := test.expectedHeaders
for key, value := range expectedHeaders {
assert.Equal(t, value, actualHeaders.Get(key))
actualHeaders.Del(key)
}
if test.checkForUnexpectedHeaders {
for key := range actualHeaders {
assert.Fail(t, "Unexpected header found", key)
}
}
})
}