diff --git a/middlewares/tracing/entrypoint.go b/middlewares/tracing/entrypoint.go index c8049c758..eb2f6aca6 100644 --- a/middlewares/tracing/entrypoint.go +++ b/middlewares/tracing/entrypoint.go @@ -32,11 +32,11 @@ func (e *entryPointMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, LogRequest(span, r) ext.SpanKindRPCServer.Set(span) - w = &statusCodeTracker{w, 200} r = r.WithContext(opentracing.ContextWithSpan(r.Context(), span)) - next(w, r) + recorder := newStatusCodeRecoder(w, 200) + next(recorder, r) - LogResponseCode(span, w.(*statusCodeTracker).status) + LogResponseCode(span, recorder.Status()) span.Finish() } diff --git a/middlewares/tracing/forwarder.go b/middlewares/tracing/forwarder.go index 15445c9fe..881302cb4 100644 --- a/middlewares/tracing/forwarder.go +++ b/middlewares/tracing/forwarder.go @@ -38,9 +38,9 @@ func (f *forwarderMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, InjectRequestHeaders(r) - w = &statusCodeTracker{w, 200} + recorder := newStatusCodeRecoder(w, 200) - next(w, r) + next(recorder, r) - LogResponseCode(span, w.(*statusCodeTracker).status) + LogResponseCode(span, recorder.Status()) } diff --git a/middlewares/tracing/status_code.go b/middlewares/tracing/status_code.go new file mode 100644 index 000000000..ec1802467 --- /dev/null +++ b/middlewares/tracing/status_code.go @@ -0,0 +1,57 @@ +package tracing + +import ( + "bufio" + "net" + "net/http" +) + +type statusCodeRecoder interface { + http.ResponseWriter + Status() int +} + +type statusCodeWithoutCloseNotify struct { + http.ResponseWriter + status int +} + +// WriteHeader captures the status code for later retrieval. +func (s *statusCodeWithoutCloseNotify) WriteHeader(status int) { + s.status = status + s.ResponseWriter.WriteHeader(status) +} + +// Status get response status +func (s *statusCodeWithoutCloseNotify) Status() int { + return s.status +} + +// Hijack hijacks the connection +func (s *statusCodeWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return s.ResponseWriter.(http.Hijacker).Hijack() +} + +// Flush sends any buffered data to the client. +func (s *statusCodeWithoutCloseNotify) Flush() { + if flusher, ok := s.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } +} + +type statusCodeWithCloseNotify struct { + *statusCodeWithoutCloseNotify +} + +func (s *statusCodeWithCloseNotify) CloseNotify() <-chan bool { + return s.ResponseWriter.(http.CloseNotifier).CloseNotify() +} + +// newStatusCodeRecoder returns an initialized statusCodeRecoder. +func newStatusCodeRecoder(rw http.ResponseWriter, status int) statusCodeRecoder { + recorder := &statusCodeWithoutCloseNotify{rw, status} + if _, ok := rw.(http.CloseNotifier); ok { + return &statusCodeWithCloseNotify{recorder} + } + return recorder +} diff --git a/middlewares/tracing/tracing.go b/middlewares/tracing/tracing.go index 0715010de..78195ad35 100644 --- a/middlewares/tracing/tracing.go +++ b/middlewares/tracing/tracing.go @@ -28,16 +28,6 @@ type Backend interface { Setup(serviceName string) (opentracing.Tracer, io.Closer, error) } -type statusCodeTracker struct { - http.ResponseWriter - status int -} - -func (s *statusCodeTracker) WriteHeader(status int) { - s.status = status - s.ResponseWriter.WriteHeader(status) -} - // Setup Tracing middleware func (t *Tracing) Setup() { var err error