diff --git a/pkg/middlewares/customerrors/custom_errors.go b/pkg/middlewares/customerrors/custom_errors.go index 72c70a631..ee50f6dda 100644 --- a/pkg/middlewares/customerrors/custom_errors.go +++ b/pkg/middlewares/customerrors/custom_errors.go @@ -22,7 +22,10 @@ import ( ) // Compile time validation that the response recorder implements http interfaces correctly. -var _ middlewares.Stateful = &responseRecorderWithCloseNotify{} +var ( + _ middlewares.Stateful = &responseRecorderWithCloseNotify{} + _ middlewares.Stateful = &codeCatcherWithCloseNotify{} +) const ( typeName = "customError" @@ -80,25 +83,29 @@ func (c *customErrors) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } - recorder := newResponseRecorder(ctx, rw) - c.next.ServeHTTP(recorder, req) + catcher := newCodeCatcher(rw, c.httpCodeRanges) + c.next.ServeHTTP(catcher, req) + if !catcher.isFilteredCode() { + return + } // check the recorder code against the configured http status code ranges + code := catcher.getCode() for _, block := range c.httpCodeRanges { - if recorder.GetCode() >= block[0] && recorder.GetCode() <= block[1] { - logger.Errorf("Caught HTTP Status Code %d, returning error page", recorder.GetCode()) + if code >= block[0] && code <= block[1] { + logger.Errorf("Caught HTTP Status Code %d, returning error page", code) var query string if len(c.backendQuery) > 0 { query = "/" + strings.TrimPrefix(c.backendQuery, "/") - query = strings.Replace(query, "{status}", strconv.Itoa(recorder.GetCode()), -1) + query = strings.Replace(query, "{status}", strconv.Itoa(code), -1) } pageReq, err := newRequest(backendURL + query) if err != nil { logger.Error(err) - rw.WriteHeader(recorder.GetCode()) - _, err = fmt.Fprint(rw, http.StatusText(recorder.GetCode())) + rw.WriteHeader(code) + _, err = fmt.Fprint(rw, http.StatusText(code)) if err != nil { http.Error(rw, err.Error(), http.StatusInternalServerError) } @@ -111,7 +118,7 @@ func (c *customErrors) ServeHTTP(rw http.ResponseWriter, req *http.Request) { c.backendHandler.ServeHTTP(recorderErrorPage, pageReq.WithContext(req.Context())) utils.CopyHeaders(rw.Header(), recorderErrorPage.Header()) - rw.WriteHeader(recorder.GetCode()) + rw.WriteHeader(code) if _, err = rw.Write(recorderErrorPage.GetBody().Bytes()); err != nil { logger.Error(err) @@ -119,14 +126,6 @@ func (c *customErrors) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } } - - // did not catch a configured status code so proceed with the request - utils.CopyHeaders(rw.Header(), recorder.Header()) - rw.WriteHeader(recorder.GetCode()) - _, err := rw.Write(recorder.GetBody().Bytes()) - if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - } } func newRequest(baseURL string) (*http.Request, error) { @@ -144,6 +143,132 @@ func newRequest(baseURL string) (*http.Request, error) { return req, nil } +type responseInterceptor interface { + http.ResponseWriter + http.Flusher + getCode() int + isFilteredCode() bool +} + +// codeCatcher is a response writer that detects as soon as possible whether the +// response is a code within the ranges of codes it watches for. If it is, it +// simply drops the data from the response. Otherwise, it forwards it directly to +// the original client (its responseWriter) without any buffering. +type codeCatcher struct { + headerMap http.Header + code int + httpCodeRanges types.HTTPCodeRanges + firstWrite bool + caughtFilteredCode bool + responseWriter http.ResponseWriter + headersSent bool +} + +type codeCatcherWithCloseNotify struct { + *codeCatcher +} + +// CloseNotify returns a channel that receives at most a +// single value (true) when the client connection has gone away. +func (cc *codeCatcherWithCloseNotify) CloseNotify() <-chan bool { + return cc.responseWriter.(http.CloseNotifier).CloseNotify() +} + +func newCodeCatcher(rw http.ResponseWriter, httpCodeRanges types.HTTPCodeRanges) responseInterceptor { + catcher := &codeCatcher{ + headerMap: make(http.Header), + code: http.StatusOK, // If backend does not call WriteHeader on us, we consider it's a 200. + responseWriter: rw, + httpCodeRanges: httpCodeRanges, + firstWrite: true, + } + if _, ok := rw.(http.CloseNotifier); ok { + return &codeCatcherWithCloseNotify{catcher} + } + return catcher +} + +func (cc *codeCatcher) Header() http.Header { + if cc.headerMap == nil { + cc.headerMap = make(http.Header) + } + + return cc.headerMap +} + +func (cc *codeCatcher) getCode() int { + return cc.code +} + +// isFilteredCode returns whether the codeCatcher received a response code among the ones it is watching, +// and for which the response should be deferred to the error handler. +func (cc *codeCatcher) isFilteredCode() bool { + return cc.caughtFilteredCode +} + +func (cc *codeCatcher) Write(buf []byte) (int, error) { + if !cc.firstWrite { + if cc.caughtFilteredCode { + // We don't care about the contents of the response, + // since we want to serve the ones from the error page, + // so we just drop them. + return len(buf), nil + } + return cc.responseWriter.Write(buf) + } + cc.firstWrite = false + + // If WriteHeader was already called from the caller, this is a NOOP. + // Otherwise, cc.code is actually a 200 here. + cc.WriteHeader(cc.code) + + if cc.caughtFilteredCode { + return len(buf), nil + } + return cc.responseWriter.Write(buf) +} + +func (cc *codeCatcher) WriteHeader(code int) { + if cc.headersSent || cc.caughtFilteredCode { + return + } + + cc.code = code + for _, block := range cc.httpCodeRanges { + if cc.code >= block[0] && cc.code <= block[1] { + cc.caughtFilteredCode = true + break + } + } + // it will be up to the other response recorder to send the headers, + // so it is out of our hands now. + if cc.caughtFilteredCode { + return + } + utils.CopyHeaders(cc.responseWriter.Header(), cc.Header()) + cc.responseWriter.WriteHeader(cc.code) + cc.headersSent = true +} + +// Hijack hijacks the connection +func (cc *codeCatcher) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hj, ok := cc.responseWriter.(http.Hijacker); ok { + return hj.Hijack() + } + return nil, nil, fmt.Errorf("%T is not a http.Hijacker", cc.responseWriter) +} + +// Flush sends any buffered data to the client. +func (cc *codeCatcher) Flush() { + // If WriteHeader was already called from the caller, this is a NOOP. + // Otherwise, cc.code is actually a 200 here. + cc.WriteHeader(cc.code) + + if flusher, ok := cc.responseWriter.(http.Flusher); ok { + flusher.Flush() + } +} + type responseRecorder interface { http.ResponseWriter http.Flusher diff --git a/pkg/middlewares/customerrors/custom_errors_test.go b/pkg/middlewares/customerrors/custom_errors_test.go index 5dbb22f30..58a3a1673 100644 --- a/pkg/middlewares/customerrors/custom_errors_test.go +++ b/pkg/middlewares/customerrors/custom_errors_test.go @@ -33,6 +33,30 @@ func TestHandler(t *testing.T) { assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusOK)) }, }, + { + desc: "no error, but not a 200", + errorPage: &dynamic.ErrorPage{Service: "error", Query: "/test", Status: []string{"500-501", "503-599"}}, + backendCode: http.StatusPartialContent, + backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "My error page.") + }), + validate: func(t *testing.T, recorder *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusPartialContent, recorder.Code, "HTTP status") + assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusPartialContent)) + }, + }, + { + desc: "a 304, so no Write called", + errorPage: &dynamic.ErrorPage{Service: "error", Query: "/test", Status: []string{"500-501", "503-599"}}, + backendCode: http.StatusNotModified, + backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "whatever, should not be called") + }), + validate: func(t *testing.T, recorder *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusNotModified, recorder.Code, "HTTP status") + assert.Contains(t, recorder.Body.String(), "") + }, + }, { desc: "in the range", errorPage: &dynamic.ErrorPage{Service: "error", Query: "/test", Status: []string{"500-501", "503-599"}}, @@ -104,6 +128,9 @@ func TestHandler(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(test.backendCode) + if test.backendCode == http.StatusNotModified { + return + } fmt.Fprintln(w, http.StatusText(test.backendCode)) }) errorPageHandler, err := New(context.Background(), handler, *test.errorPage, serviceBuilderMock, "test")