From 61e59d74e0c4165efb7891a40c07652cba8d4dfe Mon Sep 17 00:00:00 2001 From: mpl Date: Thu, 12 Dec 2019 15:12:05 +0100 Subject: [PATCH] CloseNotifier: return pointer instead of value --- .../accesslog/capture_response_writer.go | 11 +--- .../accesslog/capture_response_writer_test.go | 50 +++++++++++++++++++ pkg/middlewares/metrics/metrics_test.go | 42 ++++++++++++++++ pkg/middlewares/metrics/recorder.go | 9 +--- 4 files changed, 95 insertions(+), 17 deletions(-) create mode 100644 pkg/middlewares/accesslog/capture_response_writer_test.go diff --git a/pkg/middlewares/accesslog/capture_response_writer.go b/pkg/middlewares/accesslog/capture_response_writer.go index da4991a72..4202b0ee9 100644 --- a/pkg/middlewares/accesslog/capture_response_writer.go +++ b/pkg/middlewares/accesslog/capture_response_writer.go @@ -10,7 +10,7 @@ import ( ) var ( - _ middlewares.Stateful = &captureResponseWriter{} + _ middlewares.Stateful = &captureResponseWriterWithCloseNotify{} ) type capturer interface { @@ -24,7 +24,7 @@ func newCaptureResponseWriter(rw http.ResponseWriter) capturer { if _, ok := rw.(http.CloseNotifier); !ok { return capt } - return captureResponseWriterWithCloseNotify{capt} + return &captureResponseWriterWithCloseNotify{capt} } // captureResponseWriter is a wrapper of type http.ResponseWriter @@ -76,13 +76,6 @@ func (crw *captureResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) return nil, nil, fmt.Errorf("not a hijacker: %T", crw.rw) } -func (crw *captureResponseWriter) CloseNotify() <-chan bool { - if c, ok := crw.rw.(http.CloseNotifier); ok { - return c.CloseNotify() - } - return nil -} - func (crw *captureResponseWriter) Status() int { return crw.status } diff --git a/pkg/middlewares/accesslog/capture_response_writer_test.go b/pkg/middlewares/accesslog/capture_response_writer_test.go new file mode 100644 index 000000000..3606fc033 --- /dev/null +++ b/pkg/middlewares/accesslog/capture_response_writer_test.go @@ -0,0 +1,50 @@ +package accesslog + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +type rwWithCloseNotify struct { + *httptest.ResponseRecorder +} + +func (r *rwWithCloseNotify) CloseNotify() <-chan bool { + panic("implement me") +} + +func TestCloseNotifier(t *testing.T) { + testCases := []struct { + rw http.ResponseWriter + desc string + implementsCloseNotifier bool + }{ + { + rw: httptest.NewRecorder(), + desc: "does not implement CloseNotifier", + implementsCloseNotifier: false, + }, + { + rw: &rwWithCloseNotify{httptest.NewRecorder()}, + desc: "implements CloseNotifier", + implementsCloseNotifier: true, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + _, ok := test.rw.(http.CloseNotifier) + assert.Equal(t, test.implementsCloseNotifier, ok) + + rw := newCaptureResponseWriter(test.rw) + _, impl := rw.(http.CloseNotifier) + assert.Equal(t, test.implementsCloseNotifier, impl) + }) + } +} diff --git a/pkg/middlewares/metrics/metrics_test.go b/pkg/middlewares/metrics/metrics_test.go index 9df351649..596ac6abe 100644 --- a/pkg/middlewares/metrics/metrics_test.go +++ b/pkg/middlewares/metrics/metrics_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/go-kit/kit/metrics" + "github.com/stretchr/testify/assert" ) // CollectingCounter is a metrics.Counter implementation that enables access to the CounterValue and LastLabelValues. @@ -56,3 +57,44 @@ func newCollectingRetryMetrics() *collectingRetryMetrics { func (m *collectingRetryMetrics) ServiceRetriesCounter() metrics.Counter { return m.retriesCounter } + +type rwWithCloseNotify struct { + *httptest.ResponseRecorder +} + +func (r *rwWithCloseNotify) CloseNotify() <-chan bool { + panic("implement me") +} + +func TestCloseNotifier(t *testing.T) { + testCases := []struct { + rw http.ResponseWriter + desc string + implementsCloseNotifier bool + }{ + { + rw: httptest.NewRecorder(), + desc: "does not implement CloseNotifier", + implementsCloseNotifier: false, + }, + { + rw: &rwWithCloseNotify{httptest.NewRecorder()}, + desc: "implements CloseNotifier", + implementsCloseNotifier: true, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + _, ok := test.rw.(http.CloseNotifier) + assert.Equal(t, test.implementsCloseNotifier, ok) + + rw := newResponseRecorder(test.rw) + _, impl := rw.(http.CloseNotifier) + assert.Equal(t, test.implementsCloseNotifier, impl) + }) + } +} diff --git a/pkg/middlewares/metrics/recorder.go b/pkg/middlewares/metrics/recorder.go index 4206558c7..b39a79954 100644 --- a/pkg/middlewares/metrics/recorder.go +++ b/pkg/middlewares/metrics/recorder.go @@ -20,7 +20,7 @@ func newResponseRecorder(rw http.ResponseWriter) recorder { if _, ok := rw.(http.CloseNotifier); !ok { return rec } - return responseRecorderWithCloseNotify{rec} + return &responseRecorderWithCloseNotify{rec} } // responseRecorder captures information from the response and preserves it for @@ -55,13 +55,6 @@ func (r *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { return r.ResponseWriter.(http.Hijacker).Hijack() } -// CloseNotify returns a channel that receives at most a -// single value (true) when the client connection has gone -// away. -func (r *responseRecorder) CloseNotify() <-chan bool { - return r.ResponseWriter.(http.CloseNotifier).CloseNotify() -} - // Flush sends any buffered data to the client. func (r *responseRecorder) Flush() { if f, ok := r.ResponseWriter.(http.Flusher); ok {