error pages: do not buffer response when it's not an error
This commit is contained in:
parent
743d772a80
commit
a239e3fba6
2 changed files with 169 additions and 17 deletions
|
@ -22,7 +22,10 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// Compile time validation that the response recorder implements http interfaces correctly.
|
// Compile time validation that the response recorder implements http interfaces correctly.
|
||||||
var _ middlewares.Stateful = &responseRecorderWithCloseNotify{}
|
var (
|
||||||
|
_ middlewares.Stateful = &responseRecorderWithCloseNotify{}
|
||||||
|
_ middlewares.Stateful = &codeCatcherWithCloseNotify{}
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
typeName = "customError"
|
typeName = "customError"
|
||||||
|
@ -80,25 +83,29 @@ func (c *customErrors) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
recorder := newResponseRecorder(ctx, rw)
|
catcher := newCodeCatcher(rw, c.httpCodeRanges)
|
||||||
c.next.ServeHTTP(recorder, req)
|
c.next.ServeHTTP(catcher, req)
|
||||||
|
if !catcher.isFilteredCode() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// check the recorder code against the configured http status code ranges
|
// check the recorder code against the configured http status code ranges
|
||||||
|
code := catcher.getCode()
|
||||||
for _, block := range c.httpCodeRanges {
|
for _, block := range c.httpCodeRanges {
|
||||||
if recorder.GetCode() >= block[0] && recorder.GetCode() <= block[1] {
|
if code >= block[0] && code <= block[1] {
|
||||||
logger.Errorf("Caught HTTP Status Code %d, returning error page", recorder.GetCode())
|
logger.Errorf("Caught HTTP Status Code %d, returning error page", code)
|
||||||
|
|
||||||
var query string
|
var query string
|
||||||
if len(c.backendQuery) > 0 {
|
if len(c.backendQuery) > 0 {
|
||||||
query = "/" + strings.TrimPrefix(c.backendQuery, "/")
|
query = "/" + strings.TrimPrefix(c.backendQuery, "/")
|
||||||
query = strings.Replace(query, "{status}", strconv.Itoa(recorder.GetCode()), -1)
|
query = strings.Replace(query, "{status}", strconv.Itoa(code), -1)
|
||||||
}
|
}
|
||||||
|
|
||||||
pageReq, err := newRequest(backendURL + query)
|
pageReq, err := newRequest(backendURL + query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
rw.WriteHeader(recorder.GetCode())
|
rw.WriteHeader(code)
|
||||||
_, err = fmt.Fprint(rw, http.StatusText(recorder.GetCode()))
|
_, err = fmt.Fprint(rw, http.StatusText(code))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
@ -111,7 +118,7 @@ func (c *customErrors) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
c.backendHandler.ServeHTTP(recorderErrorPage, pageReq.WithContext(req.Context()))
|
c.backendHandler.ServeHTTP(recorderErrorPage, pageReq.WithContext(req.Context()))
|
||||||
|
|
||||||
utils.CopyHeaders(rw.Header(), recorderErrorPage.Header())
|
utils.CopyHeaders(rw.Header(), recorderErrorPage.Header())
|
||||||
rw.WriteHeader(recorder.GetCode())
|
rw.WriteHeader(code)
|
||||||
|
|
||||||
if _, err = rw.Write(recorderErrorPage.GetBody().Bytes()); err != nil {
|
if _, err = rw.Write(recorderErrorPage.GetBody().Bytes()); err != nil {
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
|
@ -119,14 +126,6 @@ func (c *customErrors) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// did not catch a configured status code so proceed with the request
|
|
||||||
utils.CopyHeaders(rw.Header(), recorder.Header())
|
|
||||||
rw.WriteHeader(recorder.GetCode())
|
|
||||||
_, err := rw.Write(recorder.GetBody().Bytes())
|
|
||||||
if err != nil {
|
|
||||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRequest(baseURL string) (*http.Request, error) {
|
func newRequest(baseURL string) (*http.Request, error) {
|
||||||
|
@ -144,6 +143,132 @@ func newRequest(baseURL string) (*http.Request, error) {
|
||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type responseInterceptor interface {
|
||||||
|
http.ResponseWriter
|
||||||
|
http.Flusher
|
||||||
|
getCode() int
|
||||||
|
isFilteredCode() bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// codeCatcher is a response writer that detects as soon as possible whether the
|
||||||
|
// response is a code within the ranges of codes it watches for. If it is, it
|
||||||
|
// simply drops the data from the response. Otherwise, it forwards it directly to
|
||||||
|
// the original client (its responseWriter) without any buffering.
|
||||||
|
type codeCatcher struct {
|
||||||
|
headerMap http.Header
|
||||||
|
code int
|
||||||
|
httpCodeRanges types.HTTPCodeRanges
|
||||||
|
firstWrite bool
|
||||||
|
caughtFilteredCode bool
|
||||||
|
responseWriter http.ResponseWriter
|
||||||
|
headersSent bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type codeCatcherWithCloseNotify struct {
|
||||||
|
*codeCatcher
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseNotify returns a channel that receives at most a
|
||||||
|
// single value (true) when the client connection has gone away.
|
||||||
|
func (cc *codeCatcherWithCloseNotify) CloseNotify() <-chan bool {
|
||||||
|
return cc.responseWriter.(http.CloseNotifier).CloseNotify()
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCodeCatcher(rw http.ResponseWriter, httpCodeRanges types.HTTPCodeRanges) responseInterceptor {
|
||||||
|
catcher := &codeCatcher{
|
||||||
|
headerMap: make(http.Header),
|
||||||
|
code: http.StatusOK, // If backend does not call WriteHeader on us, we consider it's a 200.
|
||||||
|
responseWriter: rw,
|
||||||
|
httpCodeRanges: httpCodeRanges,
|
||||||
|
firstWrite: true,
|
||||||
|
}
|
||||||
|
if _, ok := rw.(http.CloseNotifier); ok {
|
||||||
|
return &codeCatcherWithCloseNotify{catcher}
|
||||||
|
}
|
||||||
|
return catcher
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cc *codeCatcher) Header() http.Header {
|
||||||
|
if cc.headerMap == nil {
|
||||||
|
cc.headerMap = make(http.Header)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cc.headerMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cc *codeCatcher) getCode() int {
|
||||||
|
return cc.code
|
||||||
|
}
|
||||||
|
|
||||||
|
// isFilteredCode returns whether the codeCatcher received a response code among the ones it is watching,
|
||||||
|
// and for which the response should be deferred to the error handler.
|
||||||
|
func (cc *codeCatcher) isFilteredCode() bool {
|
||||||
|
return cc.caughtFilteredCode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cc *codeCatcher) Write(buf []byte) (int, error) {
|
||||||
|
if !cc.firstWrite {
|
||||||
|
if cc.caughtFilteredCode {
|
||||||
|
// We don't care about the contents of the response,
|
||||||
|
// since we want to serve the ones from the error page,
|
||||||
|
// so we just drop them.
|
||||||
|
return len(buf), nil
|
||||||
|
}
|
||||||
|
return cc.responseWriter.Write(buf)
|
||||||
|
}
|
||||||
|
cc.firstWrite = false
|
||||||
|
|
||||||
|
// If WriteHeader was already called from the caller, this is a NOOP.
|
||||||
|
// Otherwise, cc.code is actually a 200 here.
|
||||||
|
cc.WriteHeader(cc.code)
|
||||||
|
|
||||||
|
if cc.caughtFilteredCode {
|
||||||
|
return len(buf), nil
|
||||||
|
}
|
||||||
|
return cc.responseWriter.Write(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cc *codeCatcher) WriteHeader(code int) {
|
||||||
|
if cc.headersSent || cc.caughtFilteredCode {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cc.code = code
|
||||||
|
for _, block := range cc.httpCodeRanges {
|
||||||
|
if cc.code >= block[0] && cc.code <= block[1] {
|
||||||
|
cc.caughtFilteredCode = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// it will be up to the other response recorder to send the headers,
|
||||||
|
// so it is out of our hands now.
|
||||||
|
if cc.caughtFilteredCode {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
utils.CopyHeaders(cc.responseWriter.Header(), cc.Header())
|
||||||
|
cc.responseWriter.WriteHeader(cc.code)
|
||||||
|
cc.headersSent = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hijack hijacks the connection
|
||||||
|
func (cc *codeCatcher) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
if hj, ok := cc.responseWriter.(http.Hijacker); ok {
|
||||||
|
return hj.Hijack()
|
||||||
|
}
|
||||||
|
return nil, nil, fmt.Errorf("%T is not a http.Hijacker", cc.responseWriter)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush sends any buffered data to the client.
|
||||||
|
func (cc *codeCatcher) Flush() {
|
||||||
|
// If WriteHeader was already called from the caller, this is a NOOP.
|
||||||
|
// Otherwise, cc.code is actually a 200 here.
|
||||||
|
cc.WriteHeader(cc.code)
|
||||||
|
|
||||||
|
if flusher, ok := cc.responseWriter.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type responseRecorder interface {
|
type responseRecorder interface {
|
||||||
http.ResponseWriter
|
http.ResponseWriter
|
||||||
http.Flusher
|
http.Flusher
|
||||||
|
|
|
@ -33,6 +33,30 @@ func TestHandler(t *testing.T) {
|
||||||
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusOK))
|
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusOK))
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
desc: "no error, but not a 200",
|
||||||
|
errorPage: &dynamic.ErrorPage{Service: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||||
|
backendCode: http.StatusPartialContent,
|
||||||
|
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fmt.Fprintln(w, "My error page.")
|
||||||
|
}),
|
||||||
|
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||||
|
assert.Equal(t, http.StatusPartialContent, recorder.Code, "HTTP status")
|
||||||
|
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusPartialContent))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "a 304, so no Write called",
|
||||||
|
errorPage: &dynamic.ErrorPage{Service: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||||
|
backendCode: http.StatusNotModified,
|
||||||
|
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fmt.Fprintln(w, "whatever, should not be called")
|
||||||
|
}),
|
||||||
|
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||||
|
assert.Equal(t, http.StatusNotModified, recorder.Code, "HTTP status")
|
||||||
|
assert.Contains(t, recorder.Body.String(), "")
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
desc: "in the range",
|
desc: "in the range",
|
||||||
errorPage: &dynamic.ErrorPage{Service: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
errorPage: &dynamic.ErrorPage{Service: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||||
|
@ -104,6 +128,9 @@ func TestHandler(t *testing.T) {
|
||||||
|
|
||||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(test.backendCode)
|
w.WriteHeader(test.backendCode)
|
||||||
|
if test.backendCode == http.StatusNotModified {
|
||||||
|
return
|
||||||
|
}
|
||||||
fmt.Fprintln(w, http.StatusText(test.backendCode))
|
fmt.Fprintln(w, http.StatusText(test.backendCode))
|
||||||
})
|
})
|
||||||
errorPageHandler, err := New(context.Background(), handler, *test.errorPage, serviceBuilderMock, "test")
|
errorPageHandler, err := New(context.Background(), handler, *test.errorPage, serviceBuilderMock, "test")
|
||||||
|
|
Loading…
Reference in a new issue