error pages: do not buffer response when it's not an error

This commit is contained in:
mpl 2019-09-12 16:20:05 +02:00 committed by Traefiker Bot
parent 743d772a80
commit a239e3fba6
2 changed files with 169 additions and 17 deletions

View file

@ -22,7 +22,10 @@ import (
) )
// Compile time validation that the response recorder implements http interfaces correctly. // Compile time validation that the response recorder implements http interfaces correctly.
var _ middlewares.Stateful = &responseRecorderWithCloseNotify{} var (
_ middlewares.Stateful = &responseRecorderWithCloseNotify{}
_ middlewares.Stateful = &codeCatcherWithCloseNotify{}
)
const ( const (
typeName = "customError" typeName = "customError"
@ -80,25 +83,29 @@ func (c *customErrors) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return return
} }
recorder := newResponseRecorder(ctx, rw) catcher := newCodeCatcher(rw, c.httpCodeRanges)
c.next.ServeHTTP(recorder, req) c.next.ServeHTTP(catcher, req)
if !catcher.isFilteredCode() {
return
}
// check the recorder code against the configured http status code ranges // check the recorder code against the configured http status code ranges
code := catcher.getCode()
for _, block := range c.httpCodeRanges { for _, block := range c.httpCodeRanges {
if recorder.GetCode() >= block[0] && recorder.GetCode() <= block[1] { if code >= block[0] && code <= block[1] {
logger.Errorf("Caught HTTP Status Code %d, returning error page", recorder.GetCode()) logger.Errorf("Caught HTTP Status Code %d, returning error page", code)
var query string var query string
if len(c.backendQuery) > 0 { if len(c.backendQuery) > 0 {
query = "/" + strings.TrimPrefix(c.backendQuery, "/") query = "/" + strings.TrimPrefix(c.backendQuery, "/")
query = strings.Replace(query, "{status}", strconv.Itoa(recorder.GetCode()), -1) query = strings.Replace(query, "{status}", strconv.Itoa(code), -1)
} }
pageReq, err := newRequest(backendURL + query) pageReq, err := newRequest(backendURL + query)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
rw.WriteHeader(recorder.GetCode()) rw.WriteHeader(code)
_, err = fmt.Fprint(rw, http.StatusText(recorder.GetCode())) _, err = fmt.Fprint(rw, http.StatusText(code))
if err != nil { if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError) http.Error(rw, err.Error(), http.StatusInternalServerError)
} }
@ -111,7 +118,7 @@ func (c *customErrors) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
c.backendHandler.ServeHTTP(recorderErrorPage, pageReq.WithContext(req.Context())) c.backendHandler.ServeHTTP(recorderErrorPage, pageReq.WithContext(req.Context()))
utils.CopyHeaders(rw.Header(), recorderErrorPage.Header()) utils.CopyHeaders(rw.Header(), recorderErrorPage.Header())
rw.WriteHeader(recorder.GetCode()) rw.WriteHeader(code)
if _, err = rw.Write(recorderErrorPage.GetBody().Bytes()); err != nil { if _, err = rw.Write(recorderErrorPage.GetBody().Bytes()); err != nil {
logger.Error(err) logger.Error(err)
@ -119,14 +126,6 @@ func (c *customErrors) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return return
} }
} }
// did not catch a configured status code so proceed with the request
utils.CopyHeaders(rw.Header(), recorder.Header())
rw.WriteHeader(recorder.GetCode())
_, err := rw.Write(recorder.GetBody().Bytes())
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
} }
func newRequest(baseURL string) (*http.Request, error) { func newRequest(baseURL string) (*http.Request, error) {
@ -144,6 +143,132 @@ func newRequest(baseURL string) (*http.Request, error) {
return req, nil return req, nil
} }
type responseInterceptor interface {
http.ResponseWriter
http.Flusher
getCode() int
isFilteredCode() bool
}
// codeCatcher is a response writer that detects as soon as possible whether the
// response is a code within the ranges of codes it watches for. If it is, it
// simply drops the data from the response. Otherwise, it forwards it directly to
// the original client (its responseWriter) without any buffering.
type codeCatcher struct {
headerMap http.Header
code int
httpCodeRanges types.HTTPCodeRanges
firstWrite bool
caughtFilteredCode bool
responseWriter http.ResponseWriter
headersSent bool
}
type codeCatcherWithCloseNotify struct {
*codeCatcher
}
// CloseNotify returns a channel that receives at most a
// single value (true) when the client connection has gone away.
func (cc *codeCatcherWithCloseNotify) CloseNotify() <-chan bool {
return cc.responseWriter.(http.CloseNotifier).CloseNotify()
}
func newCodeCatcher(rw http.ResponseWriter, httpCodeRanges types.HTTPCodeRanges) responseInterceptor {
catcher := &codeCatcher{
headerMap: make(http.Header),
code: http.StatusOK, // If backend does not call WriteHeader on us, we consider it's a 200.
responseWriter: rw,
httpCodeRanges: httpCodeRanges,
firstWrite: true,
}
if _, ok := rw.(http.CloseNotifier); ok {
return &codeCatcherWithCloseNotify{catcher}
}
return catcher
}
func (cc *codeCatcher) Header() http.Header {
if cc.headerMap == nil {
cc.headerMap = make(http.Header)
}
return cc.headerMap
}
func (cc *codeCatcher) getCode() int {
return cc.code
}
// isFilteredCode returns whether the codeCatcher received a response code among the ones it is watching,
// and for which the response should be deferred to the error handler.
func (cc *codeCatcher) isFilteredCode() bool {
return cc.caughtFilteredCode
}
func (cc *codeCatcher) Write(buf []byte) (int, error) {
if !cc.firstWrite {
if cc.caughtFilteredCode {
// We don't care about the contents of the response,
// since we want to serve the ones from the error page,
// so we just drop them.
return len(buf), nil
}
return cc.responseWriter.Write(buf)
}
cc.firstWrite = false
// If WriteHeader was already called from the caller, this is a NOOP.
// Otherwise, cc.code is actually a 200 here.
cc.WriteHeader(cc.code)
if cc.caughtFilteredCode {
return len(buf), nil
}
return cc.responseWriter.Write(buf)
}
func (cc *codeCatcher) WriteHeader(code int) {
if cc.headersSent || cc.caughtFilteredCode {
return
}
cc.code = code
for _, block := range cc.httpCodeRanges {
if cc.code >= block[0] && cc.code <= block[1] {
cc.caughtFilteredCode = true
break
}
}
// it will be up to the other response recorder to send the headers,
// so it is out of our hands now.
if cc.caughtFilteredCode {
return
}
utils.CopyHeaders(cc.responseWriter.Header(), cc.Header())
cc.responseWriter.WriteHeader(cc.code)
cc.headersSent = true
}
// Hijack hijacks the connection
func (cc *codeCatcher) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hj, ok := cc.responseWriter.(http.Hijacker); ok {
return hj.Hijack()
}
return nil, nil, fmt.Errorf("%T is not a http.Hijacker", cc.responseWriter)
}
// Flush sends any buffered data to the client.
func (cc *codeCatcher) Flush() {
// If WriteHeader was already called from the caller, this is a NOOP.
// Otherwise, cc.code is actually a 200 here.
cc.WriteHeader(cc.code)
if flusher, ok := cc.responseWriter.(http.Flusher); ok {
flusher.Flush()
}
}
type responseRecorder interface { type responseRecorder interface {
http.ResponseWriter http.ResponseWriter
http.Flusher http.Flusher

View file

@ -33,6 +33,30 @@ func TestHandler(t *testing.T) {
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusOK)) assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusOK))
}, },
}, },
{
desc: "no error, but not a 200",
errorPage: &dynamic.ErrorPage{Service: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusPartialContent,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "My error page.")
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusPartialContent, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusPartialContent))
},
},
{
desc: "a 304, so no Write called",
errorPage: &dynamic.ErrorPage{Service: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusNotModified,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "whatever, should not be called")
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusNotModified, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "")
},
},
{ {
desc: "in the range", desc: "in the range",
errorPage: &dynamic.ErrorPage{Service: "error", Query: "/test", Status: []string{"500-501", "503-599"}}, errorPage: &dynamic.ErrorPage{Service: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
@ -104,6 +128,9 @@ func TestHandler(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(test.backendCode) w.WriteHeader(test.backendCode)
if test.backendCode == http.StatusNotModified {
return
}
fmt.Fprintln(w, http.StatusText(test.backendCode)) fmt.Fprintln(w, http.StatusText(test.backendCode))
}) })
errorPageHandler, err := New(context.Background(), handler, *test.errorPage, serviceBuilderMock, "test") errorPageHandler, err := New(context.Background(), handler, *test.errorPage, serviceBuilderMock, "test")