diff --git a/middlewares/errorpages/error_pages.go b/middlewares/errorpages/error_pages.go index fb262081e..9fbe84706 100644 --- a/middlewares/errorpages/error_pages.go +++ b/middlewares/errorpages/error_pages.go @@ -101,12 +101,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request, next http. h.backendHandler.ServeHTTP(recorderErrorPage, pageReq.WithContext(req.Context())) - utils.CopyHeaders(w.Header(), recorder.Header()) - for key := range recorderErrorPage.Header() { - w.Header().Del(key) - } utils.CopyHeaders(w.Header(), recorderErrorPage.Header()) - w.WriteHeader(recorder.GetCode()) w.Write(recorderErrorPage.GetBody().Bytes()) return @@ -174,64 +169,65 @@ type responseRecorderWithCloseNotify struct { // CloseNotify returns a channel that receives at most a // single value (true) when the client connection has gone away. -func (rw *responseRecorderWithCloseNotify) CloseNotify() <-chan bool { - return rw.responseWriter.(http.CloseNotifier).CloseNotify() +func (r *responseRecorderWithCloseNotify) CloseNotify() <-chan bool { + return r.responseWriter.(http.CloseNotifier).CloseNotify() } // Header returns the response headers. -func (rw *responseRecorderWithoutCloseNotify) Header() http.Header { - if rw.HeaderMap == nil { - rw.HeaderMap = make(http.Header) +func (r *responseRecorderWithoutCloseNotify) Header() http.Header { + if r.HeaderMap == nil { + r.HeaderMap = make(http.Header) } - return rw.HeaderMap + + return r.HeaderMap } -func (rw *responseRecorderWithoutCloseNotify) GetCode() int { - return rw.Code +func (r *responseRecorderWithoutCloseNotify) GetCode() int { + return r.Code } -func (rw *responseRecorderWithoutCloseNotify) GetBody() *bytes.Buffer { - return rw.Body +func (r *responseRecorderWithoutCloseNotify) GetBody() *bytes.Buffer { + return r.Body } -func (rw *responseRecorderWithoutCloseNotify) IsStreamingResponseStarted() bool { - return rw.streamingResponseStarted +func (r *responseRecorderWithoutCloseNotify) IsStreamingResponseStarted() bool { + return r.streamingResponseStarted } // Write always succeeds and writes to rw.Body, if not nil. -func (rw *responseRecorderWithoutCloseNotify) Write(buf []byte) (int, error) { - if rw.err != nil { - return 0, rw.err +func (r *responseRecorderWithoutCloseNotify) Write(buf []byte) (int, error) { + if r.err != nil { + return 0, r.err } - return rw.Body.Write(buf) + return r.Body.Write(buf) } // WriteHeader sets rw.Code. -func (rw *responseRecorderWithoutCloseNotify) WriteHeader(code int) { - rw.Code = code +func (r *responseRecorderWithoutCloseNotify) WriteHeader(code int) { + r.Code = code } // Hijack hijacks the connection -func (rw *responseRecorderWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return rw.responseWriter.(http.Hijacker).Hijack() +func (r *responseRecorderWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return r.responseWriter.(http.Hijacker).Hijack() } // Flush sends any buffered data to the client. -func (rw *responseRecorderWithoutCloseNotify) Flush() { - if !rw.streamingResponseStarted { - utils.CopyHeaders(rw.responseWriter.Header(), rw.Header()) - rw.responseWriter.WriteHeader(rw.Code) - rw.streamingResponseStarted = true +func (r *responseRecorderWithoutCloseNotify) Flush() { + if !r.streamingResponseStarted { + utils.CopyHeaders(r.responseWriter.Header(), r.Header()) + r.responseWriter.WriteHeader(r.Code) + r.streamingResponseStarted = true } - _, err := rw.responseWriter.Write(rw.Body.Bytes()) + _, err := r.responseWriter.Write(r.Body.Bytes()) if err != nil { - log.Errorf("Error writing response in responseRecorder: %s", err) - rw.err = err + log.Errorf("Error writing response in responseRecorder: %v", err) + r.err = err } - rw.Body.Reset() + r.Body.Reset() - if flusher, ok := rw.responseWriter.(http.Flusher); ok { + if flusher, ok := r.responseWriter.(http.Flusher); ok { flusher.Flush() } } diff --git a/middlewares/errorpages/error_pages_test.go b/middlewares/errorpages/error_pages_test.go index 2264dc336..9cf19d87d 100644 --- a/middlewares/errorpages/error_pages_test.go +++ b/middlewares/errorpages/error_pages_test.go @@ -318,7 +318,6 @@ func TestHandlerOldWayIntegration(t *testing.T) { require.NoError(t, err) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("X-Foo", "bar") w.WriteHeader(test.backendCode) fmt.Fprintln(w, http.StatusText(test.backendCode)) }) @@ -331,7 +330,6 @@ func TestHandlerOldWayIntegration(t *testing.T) { n.ServeHTTP(recorder, req) test.validate(t, recorder) - assert.Equal(t, "bar", recorder.Header().Get("X-Foo"), "missing header") }) } }