From 5b24403c8ed815c8f702c07768aba86ba94d5938 Mon Sep 17 00:00:00 2001 From: SALLEYRON Julien Date: Thu, 4 Jan 2018 11:18:03 +0100 Subject: [PATCH] Don't panic if ResponseWriter does not implement CloseNotify --- middlewares/error_pages.go | 16 ++++---- middlewares/retry.go | 82 +++++++++++++++++++++++++------------- middlewares/retry_test.go | 50 +++++++++++++++++++++++ 3 files changed, 113 insertions(+), 35 deletions(-) diff --git a/middlewares/error_pages.go b/middlewares/error_pages.go index 3ef49f11d..086192194 100644 --- a/middlewares/error_pages.go +++ b/middlewares/error_pages.go @@ -52,18 +52,18 @@ func NewErrorPagesHandler(errorPage types.ErrorPage, backendURL string) (*ErrorP } func (ep *ErrorPagesHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, next http.HandlerFunc) { - recorder := newRetryResponseRecorder() - recorder.responseWriter = w + recorder := newRetryResponseRecorder(w) + next.ServeHTTP(recorder, req) - w.WriteHeader(recorder.Code) + w.WriteHeader(recorder.GetCode()) //check the recorder code against the configured http status code ranges for _, block := range ep.HTTPCodeRanges { - if recorder.Code >= block[0] && recorder.Code <= block[1] { - log.Errorf("Caught HTTP Status Code %d, returning error page", recorder.Code) - finalURL := strings.Replace(ep.BackendURL, "{status}", strconv.Itoa(recorder.Code), -1) + if recorder.GetCode() >= block[0] && recorder.GetCode() <= block[1] { + log.Errorf("Caught HTTP Status Code %d, returning error page", recorder.GetCode()) + finalURL := strings.Replace(ep.BackendURL, "{status}", strconv.Itoa(recorder.GetCode()), -1) if newReq, err := http.NewRequest(http.MethodGet, finalURL, nil); err != nil { - w.Write([]byte(http.StatusText(recorder.Code))) + w.Write([]byte(http.StatusText(recorder.GetCode()))) } else { ep.errorPageForwarder.ServeHTTP(w, newReq) } @@ -73,5 +73,5 @@ func (ep *ErrorPagesHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, //did not catch a configured status code so proceed with the request utils.CopyHeaders(w.Header(), recorder.Header()) - w.Write(recorder.Body.Bytes()) + w.Write(recorder.GetBody().Bytes()) } diff --git a/middlewares/retry.go b/middlewares/retry.go index eee243a29..9b57ed0a6 100644 --- a/middlewares/retry.go +++ b/middlewares/retry.go @@ -14,7 +14,7 @@ import ( // Compile time validation responseRecorder implements http interfaces correctly. var ( - _ Stateful = &retryResponseRecorder{} + _ Stateful = &retryResponseRecorderWithCloseNotify{} ) // Retry is a middleware that retries requests @@ -48,22 +48,21 @@ func (retry *Retry) ServeHTTP(rw http.ResponseWriter, r *http.Request) { // when proxying the HTTP requests to the backends. This happens in the custom RecordingErrorHandler. newCtx := context.WithValue(r.Context(), defaultNetErrCtxKey, &netErrorOccurred) - recorder := newRetryResponseRecorder() - recorder.responseWriter = rw + recorder := newRetryResponseRecorder(rw) retry.next.ServeHTTP(recorder, r.WithContext(newCtx)) // It's a stream request and the body gets already sent to the client. // Therefore we should not send the response a second time. - if recorder.streamingResponseStarted { + if recorder.IsStreamingResponseStarted() { recorder.Flush() break } if !netErrorOccurred || attempts >= retry.attempts { utils.CopyHeaders(rw.Header(), recorder.Header()) - rw.WriteHeader(recorder.Code) - rw.Write(recorder.Body.Bytes()) + rw.WriteHeader(recorder.GetCode()) + rw.Write(recorder.GetBody().Bytes()) break } attempts++ @@ -115,9 +114,31 @@ func (l RetryListeners) Retried(req *http.Request, attempt int) { } } -// retryResponseRecorder is an implementation of http.ResponseWriter that +type retryResponseRecorder interface { + http.ResponseWriter + http.Flusher + GetCode() int + GetBody() *bytes.Buffer + IsStreamingResponseStarted() bool +} + +// newRetryResponseRecorder returns an initialized retryResponseRecorder. +func newRetryResponseRecorder(rw http.ResponseWriter) retryResponseRecorder { + recorder := &retryResponseRecorderWithoutCloseNotify{ + HeaderMap: make(http.Header), + Body: new(bytes.Buffer), + Code: http.StatusOK, + responseWriter: rw, + } + if _, ok := rw.(http.CloseNotifier); ok { + return &retryResponseRecorderWithCloseNotify{recorder} + } + return recorder +} + +// retryResponseRecorderWithoutCloseNotify is an implementation of http.ResponseWriter that // records its mutations for later inspection. -type retryResponseRecorder struct { +type retryResponseRecorderWithoutCloseNotify struct { Code int // the HTTP response code from WriteHeader HeaderMap http.Header // the HTTP response headers Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to @@ -127,17 +148,19 @@ type retryResponseRecorder struct { streamingResponseStarted bool } -// newRetryResponseRecorder returns an initialized retryResponseRecorder. -func newRetryResponseRecorder() *retryResponseRecorder { - return &retryResponseRecorder{ - HeaderMap: make(http.Header), - Body: new(bytes.Buffer), - Code: http.StatusOK, - } +type retryResponseRecorderWithCloseNotify struct { + *retryResponseRecorderWithoutCloseNotify +} + +// CloseNotify returns a channel that receives at most a +// single value (true) when the client connection has gone +// away. +func (rw *retryResponseRecorderWithCloseNotify) CloseNotify() <-chan bool { + return rw.responseWriter.(http.CloseNotifier).CloseNotify() } // Header returns the response headers. -func (rw *retryResponseRecorder) Header() http.Header { +func (rw *retryResponseRecorderWithoutCloseNotify) Header() http.Header { m := rw.HeaderMap if m == nil { m = make(http.Header) @@ -146,8 +169,20 @@ func (rw *retryResponseRecorder) Header() http.Header { return m } +func (rw *retryResponseRecorderWithoutCloseNotify) GetCode() int { + return rw.Code +} + +func (rw *retryResponseRecorderWithoutCloseNotify) GetBody() *bytes.Buffer { + return rw.Body +} + +func (rw *retryResponseRecorderWithoutCloseNotify) IsStreamingResponseStarted() bool { + return rw.streamingResponseStarted +} + // Write always succeeds and writes to rw.Body, if not nil. -func (rw *retryResponseRecorder) Write(buf []byte) (int, error) { +func (rw *retryResponseRecorderWithoutCloseNotify) Write(buf []byte) (int, error) { if rw.err != nil { return 0, rw.err } @@ -155,24 +190,17 @@ func (rw *retryResponseRecorder) Write(buf []byte) (int, error) { } // WriteHeader sets rw.Code. -func (rw *retryResponseRecorder) WriteHeader(code int) { +func (rw *retryResponseRecorderWithoutCloseNotify) WriteHeader(code int) { rw.Code = code } // Hijack hijacks the connection -func (rw *retryResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { +func (rw *retryResponseRecorderWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) { return rw.responseWriter.(http.Hijacker).Hijack() } -// CloseNotify returns a channel that receives at most a -// single value (true) when the client connection has gone -// away. -func (rw *retryResponseRecorder) CloseNotify() <-chan bool { - return rw.responseWriter.(http.CloseNotifier).CloseNotify() -} - // Flush sends any buffered data to the client. -func (rw *retryResponseRecorder) Flush() { +func (rw *retryResponseRecorderWithoutCloseNotify) Flush() { if !rw.streamingResponseStarted { utils.CopyHeaders(rw.responseWriter.Header(), rw.Header()) rw.responseWriter.WriteHeader(rw.Code) diff --git a/middlewares/retry_test.go b/middlewares/retry_test.go index 4f74efba7..a8ab483fd 100644 --- a/middlewares/retry_test.go +++ b/middlewares/retry_test.go @@ -7,6 +7,8 @@ import ( "net/http" "net/http/httptest" "testing" + + "github.com/stretchr/testify/assert" ) func TestRetry(t *testing.T) { @@ -152,3 +154,51 @@ func TestRetryWithFlush(t *testing.T) { t.Errorf("Wrong body %q want %q", responseRecorder.Body.String(), "FULL DATA") } } + +func TestNewRetryResponseRecorder(t *testing.T) { + testCases := []struct { + desc string + rw http.ResponseWriter + expected http.ResponseWriter + }{ + { + desc: "Without Close Notify", + rw: httptest.NewRecorder(), + expected: &retryResponseRecorderWithoutCloseNotify{}, + }, + { + desc: "With Close Notify", + rw: &mockRWCloseNotify{}, + expected: &retryResponseRecorderWithCloseNotify{}, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + rec := newRetryResponseRecorder(test.rw) + + assert.IsType(t, rec, test.expected) + }) + } +} + +type mockRWCloseNotify struct{} + +func (m *mockRWCloseNotify) CloseNotify() <-chan bool { + panic("implement me") +} + +func (m *mockRWCloseNotify) Header() http.Header { + panic("implement me") +} + +func (m *mockRWCloseNotify) Write([]byte) (int, error) { + panic("implement me") +} + +func (m *mockRWCloseNotify) WriteHeader(int) { + panic("implement me") +}