Add *headers.responseModifier CloseNotify()

This commit is contained in:
Daniel Tomcej 2021-07-13 04:28:07 -06:00 committed by GitHub
parent 3072354ca5
commit 10ab39c33b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -10,8 +10,8 @@ import (
) )
type responseModifier struct { type responseModifier struct {
r *http.Request req *http.Request
w http.ResponseWriter rw http.ResponseWriter
headersSent bool // whether headers have already been sent headersSent bool // whether headers have already been sent
code int // status code, must default to 200 code int // status code, must default to 200
@ -24,71 +24,76 @@ type responseModifier struct {
// modifier can be nil. // modifier can be nil.
func newResponseModifier(w http.ResponseWriter, r *http.Request, modifier func(*http.Response) error) *responseModifier { func newResponseModifier(w http.ResponseWriter, r *http.Request, modifier func(*http.Response) error) *responseModifier {
return &responseModifier{ return &responseModifier{
r: r, req: r,
w: w, rw: w,
modifier: modifier, modifier: modifier,
code: http.StatusOK, code: http.StatusOK,
} }
} }
func (w *responseModifier) WriteHeader(code int) { func (r *responseModifier) WriteHeader(code int) {
if w.headersSent { if r.headersSent {
return return
} }
defer func() { defer func() {
w.code = code r.code = code
w.headersSent = true r.headersSent = true
}() }()
if w.modifier == nil || w.modified { if r.modifier == nil || r.modified {
w.w.WriteHeader(code) r.rw.WriteHeader(code)
return return
} }
resp := http.Response{ resp := http.Response{
Header: w.w.Header(), Header: r.rw.Header(),
Request: w.r, Request: r.req,
} }
if err := w.modifier(&resp); err != nil { if err := r.modifier(&resp); err != nil {
w.modifierErr = err r.modifierErr = err
// we are propagating when we are called in Write, but we're logging anyway, // we are propagating when we are called in Write, but we're logging anyway,
// because we could be called from another place which does not take care of // because we could be called from another place which does not take care of
// checking w.modifierErr. // checking w.modifierErr.
log.WithoutContext().Errorf("Error when applying response modifier: %v", err) log.WithoutContext().Errorf("Error when applying response modifier: %v", err)
w.w.WriteHeader(http.StatusInternalServerError) r.rw.WriteHeader(http.StatusInternalServerError)
return return
} }
w.modified = true r.modified = true
w.w.WriteHeader(code) r.rw.WriteHeader(code)
} }
func (w *responseModifier) Header() http.Header { func (r *responseModifier) Header() http.Header {
return w.w.Header() return r.rw.Header()
} }
func (w *responseModifier) Write(b []byte) (int, error) { func (r *responseModifier) Write(b []byte) (int, error) {
w.WriteHeader(w.code) r.WriteHeader(r.code)
if w.modifierErr != nil { if r.modifierErr != nil {
return 0, w.modifierErr return 0, r.modifierErr
} }
return w.w.Write(b) return r.rw.Write(b)
} }
// Hijack hijacks the connection. // Hijack hijacks the connection.
func (w *responseModifier) Hijack() (net.Conn, *bufio.ReadWriter, error) { func (r *responseModifier) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if h, ok := w.w.(http.Hijacker); ok { if h, ok := r.rw.(http.Hijacker); ok {
return h.Hijack() return h.Hijack()
} }
return nil, nil, fmt.Errorf("not a hijacker: %T", w.w) return nil, nil, fmt.Errorf("not a hijacker: %T", r.rw)
} }
// Flush sends any buffered data to the client. // Flush sends any buffered data to the client.
func (w *responseModifier) Flush() { func (r *responseModifier) Flush() {
if flusher, ok := w.w.(http.Flusher); ok { if flusher, ok := r.rw.(http.Flusher); ok {
flusher.Flush() flusher.Flush()
} }
} }
// CloseNotify implements http.CloseNotifier.
func (r *responseModifier) CloseNotify() <-chan bool {
return r.rw.(http.CloseNotifier).CloseNotify()
}