From f1b91a119d2760f9008f9ce9165490040b24ce7e Mon Sep 17 00:00:00 2001 From: Simon Delicata Date: Thu, 17 Nov 2022 10:26:06 +0100 Subject: [PATCH] Create a new capture instance for each incoming request Co-authored-by: Romain --- pkg/middlewares/capture/capture.go | 39 ++++++++++++++++++------------ 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/pkg/middlewares/capture/capture.go b/pkg/middlewares/capture/capture.go index 845c2e983..c541cde29 100644 --- a/pkg/middlewares/capture/capture.go +++ b/pkg/middlewares/capture/capture.go @@ -39,11 +39,14 @@ type key string const capturedData key = "capturedData" -// Wrap returns a new handler that inserts a Capture into the given handler. +// Wrap returns a new handler that inserts a Capture into the given handler for each incoming request. // It satisfies the alice.Constructor type. -func Wrap(handler http.Handler) (http.Handler, error) { - c := Capture{} - return c.Reset(handler), nil +func Wrap(next http.Handler) (http.Handler, error) { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + c := &Capture{} + newRW, newReq := c.renew(rw, req) + next.ServeHTTP(newRW, newReq) + }), nil } // FromContext returns the Capture value found in ctx, or an empty Capture otherwise. @@ -68,6 +71,7 @@ type Capture struct { // NeedsReset returns whether the given http.ResponseWriter is the capture's probe. func (c *Capture) NeedsReset(rw http.ResponseWriter) bool { + // This comparison is naive. return c.rw != rw } @@ -75,20 +79,25 @@ func (c *Capture) NeedsReset(rw http.ResponseWriter) bool { // them when deferring to next. func (c *Capture) Reset(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - ctx := context.WithValue(req.Context(), capturedData, c) - newReq := req.WithContext(ctx) - - if newReq.Body != nil { - readCounter := &readCounter{source: newReq.Body} - c.rr = readCounter - newReq.Body = readCounter - } - c.rw = newResponseWriter(rw) - - next.ServeHTTP(c.rw, newReq) + newRW, newReq := c.renew(rw, req) + next.ServeHTTP(newRW, newReq) }) } +func (c *Capture) renew(rw http.ResponseWriter, req *http.Request) (http.ResponseWriter, *http.Request) { + ctx := context.WithValue(req.Context(), capturedData, c) + newReq := req.WithContext(ctx) + + if newReq.Body != nil { + readCounter := &readCounter{source: newReq.Body} + c.rr = readCounter + newReq.Body = readCounter + } + c.rw = newResponseWriter(rw) + + return c.rw, newReq +} + func (c *Capture) ResponseSize() int64 { return c.rw.Size() }