diff --git a/pkg/middlewares/headers/responsewriter.go b/pkg/middlewares/headers/responsewriter.go index a50643eae..39a171dc2 100644 --- a/pkg/middlewares/headers/responsewriter.go +++ b/pkg/middlewares/headers/responsewriter.go @@ -10,8 +10,8 @@ import ( ) type responseModifier struct { - r *http.Request - w http.ResponseWriter + req *http.Request + rw http.ResponseWriter headersSent bool // whether headers have already been sent code int // status code, must default to 200 @@ -24,71 +24,76 @@ type responseModifier struct { // modifier can be nil. func newResponseModifier(w http.ResponseWriter, r *http.Request, modifier func(*http.Response) error) *responseModifier { return &responseModifier{ - r: r, - w: w, + req: r, + rw: w, modifier: modifier, code: http.StatusOK, } } -func (w *responseModifier) WriteHeader(code int) { - if w.headersSent { +func (r *responseModifier) WriteHeader(code int) { + if r.headersSent { return } defer func() { - w.code = code - w.headersSent = true + r.code = code + r.headersSent = true }() - if w.modifier == nil || w.modified { - w.w.WriteHeader(code) + if r.modifier == nil || r.modified { + r.rw.WriteHeader(code) return } resp := http.Response{ - Header: w.w.Header(), - Request: w.r, + Header: r.rw.Header(), + Request: r.req, } - if err := w.modifier(&resp); err != nil { - w.modifierErr = err + if err := r.modifier(&resp); err != nil { + r.modifierErr = err // we are propagating when we are called in Write, but we're logging anyway, // because we could be called from another place which does not take care of // checking w.modifierErr. log.WithoutContext().Errorf("Error when applying response modifier: %v", err) - w.w.WriteHeader(http.StatusInternalServerError) + r.rw.WriteHeader(http.StatusInternalServerError) return } - w.modified = true - w.w.WriteHeader(code) + r.modified = true + r.rw.WriteHeader(code) } -func (w *responseModifier) Header() http.Header { - return w.w.Header() +func (r *responseModifier) Header() http.Header { + return r.rw.Header() } -func (w *responseModifier) Write(b []byte) (int, error) { - w.WriteHeader(w.code) - if w.modifierErr != nil { - return 0, w.modifierErr +func (r *responseModifier) Write(b []byte) (int, error) { + r.WriteHeader(r.code) + if r.modifierErr != nil { + return 0, r.modifierErr } - return w.w.Write(b) + return r.rw.Write(b) } // Hijack hijacks the connection. -func (w *responseModifier) Hijack() (net.Conn, *bufio.ReadWriter, error) { - if h, ok := w.w.(http.Hijacker); ok { +func (r *responseModifier) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if h, ok := r.rw.(http.Hijacker); ok { return h.Hijack() } - return nil, nil, fmt.Errorf("not a hijacker: %T", w.w) + return nil, nil, fmt.Errorf("not a hijacker: %T", r.rw) } // Flush sends any buffered data to the client. -func (w *responseModifier) Flush() { - if flusher, ok := w.w.(http.Flusher); ok { +func (r *responseModifier) Flush() { + if flusher, ok := r.rw.(http.Flusher); ok { flusher.Flush() } } + +// CloseNotify implements http.CloseNotifier. +func (r *responseModifier) CloseNotify() <-chan bool { + return r.rw.(http.CloseNotifier).CloseNotify() +}