diff --git a/pkg/middlewares/customerrors/custom_errors.go b/pkg/middlewares/customerrors/custom_errors.go index 3cc0f494d..8a4431e8d 100644 --- a/pkg/middlewares/customerrors/custom_errors.go +++ b/pkg/middlewares/customerrors/custom_errors.go @@ -138,7 +138,6 @@ type codeCatcher struct { headerMap http.Header code int httpCodeRanges types.HTTPCodeRanges - firstWrite bool caughtFilteredCode bool responseWriter http.ResponseWriter headersSent bool @@ -160,7 +159,6 @@ func newCodeCatcher(rw http.ResponseWriter, httpCodeRanges types.HTTPCodeRanges) code: http.StatusOK, // If backend does not call WriteHeader on us, we consider it's a 200. responseWriter: rw, httpCodeRanges: httpCodeRanges, - firstWrite: true, } if _, ok := rw.(http.CloseNotifier); ok { return &codeCatcherWithCloseNotify{catcher} @@ -187,22 +185,14 @@ func (cc *codeCatcher) isFilteredCode() bool { } func (cc *codeCatcher) Write(buf []byte) (int, error) { - if !cc.firstWrite { - if cc.caughtFilteredCode { - // We don't care about the contents of the response, - // since we want to serve the ones from the error page, - // so we just drop them. - return len(buf), nil - } - return cc.responseWriter.Write(buf) - } - cc.firstWrite = false - // If WriteHeader was already called from the caller, this is a NOOP. // Otherwise, cc.code is actually a 200 here. cc.WriteHeader(cc.code) if cc.caughtFilteredCode { + // We don't care about the contents of the response, + // since we want to serve the ones from the error page, + // so we just drop them. return len(buf), nil } return cc.responseWriter.Write(buf) @@ -217,14 +207,12 @@ func (cc *codeCatcher) WriteHeader(code int) { for _, block := range cc.httpCodeRanges { if cc.code >= block[0] && cc.code <= block[1] { cc.caughtFilteredCode = true - break + // it will be up to the caller to send the headers, + // so it is out of our hands now. + return } } - // it will be up to the other response recorder to send the headers, - // so it is out of our hands now. - if cc.caughtFilteredCode { - return - } + utils.CopyHeaders(cc.responseWriter.Header(), cc.Header()) cc.responseWriter.WriteHeader(cc.code) cc.headersSent = true @@ -328,6 +316,8 @@ func (r *codeModifierWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, // Flush sends any buffered data to the client. func (r *codeModifierWithoutCloseNotify) Flush() { + r.WriteHeader(r.code) + if flusher, ok := r.responseWriter.(http.Flusher); ok { flusher.Flush() }