Create a new capture instance for each incoming request

Co-authored-by: Romain <rtribotte@users.noreply.github.com>
This commit is contained in:
Simon Delicata 2022-11-17 10:26:06 +01:00 committed by GitHub
parent 35d8281f4d
commit f1b91a119d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -39,11 +39,14 @@ type key string
const capturedData key = "capturedData" const capturedData key = "capturedData"
// Wrap returns a new handler that inserts a Capture into the given handler. // Wrap returns a new handler that inserts a Capture into the given handler for each incoming request.
// It satisfies the alice.Constructor type. // It satisfies the alice.Constructor type.
func Wrap(handler http.Handler) (http.Handler, error) { func Wrap(next http.Handler) (http.Handler, error) {
c := Capture{} return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
return c.Reset(handler), nil c := &Capture{}
newRW, newReq := c.renew(rw, req)
next.ServeHTTP(newRW, newReq)
}), nil
} }
// FromContext returns the Capture value found in ctx, or an empty Capture otherwise. // FromContext returns the Capture value found in ctx, or an empty Capture otherwise.
@ -68,6 +71,7 @@ type Capture struct {
// NeedsReset returns whether the given http.ResponseWriter is the capture's probe. // NeedsReset returns whether the given http.ResponseWriter is the capture's probe.
func (c *Capture) NeedsReset(rw http.ResponseWriter) bool { func (c *Capture) NeedsReset(rw http.ResponseWriter) bool {
// This comparison is naive.
return c.rw != rw return c.rw != rw
} }
@ -75,20 +79,25 @@ func (c *Capture) NeedsReset(rw http.ResponseWriter) bool {
// them when deferring to next. // them when deferring to next.
func (c *Capture) Reset(next http.Handler) http.Handler { func (c *Capture) Reset(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
ctx := context.WithValue(req.Context(), capturedData, c) newRW, newReq := c.renew(rw, req)
newReq := req.WithContext(ctx) next.ServeHTTP(newRW, newReq)
if newReq.Body != nil {
readCounter := &readCounter{source: newReq.Body}
c.rr = readCounter
newReq.Body = readCounter
}
c.rw = newResponseWriter(rw)
next.ServeHTTP(c.rw, newReq)
}) })
} }
func (c *Capture) renew(rw http.ResponseWriter, req *http.Request) (http.ResponseWriter, *http.Request) {
ctx := context.WithValue(req.Context(), capturedData, c)
newReq := req.WithContext(ctx)
if newReq.Body != nil {
readCounter := &readCounter{source: newReq.Body}
c.rr = readCounter
newReq.Body = readCounter
}
c.rw = newResponseWriter(rw)
return c.rw, newReq
}
func (c *Capture) ResponseSize() int64 { func (c *Capture) ResponseSize() int64 {
return c.rw.Size() return c.rw.Size()
} }