Do not copy hop-by-hop headers to forward auth request
This commit is contained in:
parent
1f1ecb15f6
commit
29473ef356
2 changed files with 40 additions and 7 deletions
|
@ -105,6 +105,7 @@ func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next
|
||||||
|
|
||||||
func writeHeader(req *http.Request, forwardReq *http.Request, trustForwardHeader bool) {
|
func writeHeader(req *http.Request, forwardReq *http.Request, trustForwardHeader bool) {
|
||||||
utils.CopyHeaders(forwardReq.Header, req.Header)
|
utils.CopyHeaders(forwardReq.Header, req.Header)
|
||||||
|
utils.RemoveHeaders(forwardReq.Header, forward.HopHeaders...)
|
||||||
|
|
||||||
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
|
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
|
||||||
if trustForwardHeader {
|
if trustForwardHeader {
|
||||||
|
|
|
@ -231,11 +231,12 @@ func TestForwardAuthFailResponseHeaders(t *testing.T) {
|
||||||
|
|
||||||
func Test_writeHeader(t *testing.T) {
|
func Test_writeHeader(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
headers map[string]string
|
headers map[string]string
|
||||||
trustForwardHeader bool
|
trustForwardHeader bool
|
||||||
emptyHost bool
|
emptyHost bool
|
||||||
expectedHeaders map[string]string
|
expectedHeaders map[string]string
|
||||||
|
checkForUnexpectedHeaders bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "trust Forward Header",
|
name: "trust Forward Header",
|
||||||
|
@ -334,6 +335,29 @@ func Test_writeHeader(t *testing.T) {
|
||||||
"X-Forwarded-Method": "GET",
|
"X-Forwarded-Method": "GET",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "remove hop-by-hop headers",
|
||||||
|
headers: map[string]string{
|
||||||
|
forward.Connection: "Connection",
|
||||||
|
forward.KeepAlive: "KeepAlive",
|
||||||
|
forward.ProxyAuthenticate: "ProxyAuthenticate",
|
||||||
|
forward.ProxyAuthorization: "ProxyAuthorization",
|
||||||
|
forward.Te: "Te",
|
||||||
|
forward.Trailers: "Trailers",
|
||||||
|
forward.TransferEncoding: "TransferEncoding",
|
||||||
|
forward.Upgrade: "Upgrade",
|
||||||
|
"X-CustomHeader": "CustomHeader",
|
||||||
|
},
|
||||||
|
trustForwardHeader: false,
|
||||||
|
expectedHeaders: map[string]string{
|
||||||
|
"X-CustomHeader": "CustomHeader",
|
||||||
|
"X-Forwarded-Proto": "http",
|
||||||
|
"X-Forwarded-Host": "foo.bar",
|
||||||
|
"X-Forwarded-Uri": "/path?q=1",
|
||||||
|
"X-Forwarded-Method": "GET",
|
||||||
|
},
|
||||||
|
checkForUnexpectedHeaders: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range testCases {
|
for _, test := range testCases {
|
||||||
|
@ -352,8 +376,16 @@ func Test_writeHeader(t *testing.T) {
|
||||||
|
|
||||||
writeHeader(req, forwardReq, test.trustForwardHeader)
|
writeHeader(req, forwardReq, test.trustForwardHeader)
|
||||||
|
|
||||||
for key, value := range test.expectedHeaders {
|
actualHeaders := forwardReq.Header
|
||||||
assert.Equal(t, value, forwardReq.Header.Get(key))
|
expectedHeaders := test.expectedHeaders
|
||||||
|
for key, value := range expectedHeaders {
|
||||||
|
assert.Equal(t, value, actualHeaders.Get(key))
|
||||||
|
actualHeaders.Del(key)
|
||||||
|
}
|
||||||
|
if test.checkForUnexpectedHeaders {
|
||||||
|
for key := range actualHeaders {
|
||||||
|
assert.Fail(t, "Unexpected header found", key)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue