Panic on aborted requests to properly close the connection
This commit is contained in:
parent
edc0a52b5a
commit
27948493aa
3 changed files with 143 additions and 16 deletions
|
@ -23,7 +23,6 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
ptypes "github.com/traefik/paerser/types"
|
ptypes "github.com/traefik/paerser/types"
|
||||||
"github.com/traefik/traefik/v2/pkg/middlewares/capture"
|
"github.com/traefik/traefik/v2/pkg/middlewares/capture"
|
||||||
"github.com/traefik/traefik/v2/pkg/middlewares/recovery"
|
|
||||||
"github.com/traefik/traefik/v2/pkg/types"
|
"github.com/traefik/traefik/v2/pkg/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -948,8 +947,14 @@ func doLoggingWithAbortedStream(t *testing.T, config *types.AccessLog) {
|
||||||
req = req.WithContext(reqContext)
|
req = req.WithContext(reqContext)
|
||||||
|
|
||||||
chain := alice.New()
|
chain := alice.New()
|
||||||
|
|
||||||
chain = chain.Append(func(next http.Handler) (http.Handler, error) {
|
chain = chain.Append(func(next http.Handler) (http.Handler, error) {
|
||||||
return recovery.New(context.Background(), next)
|
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
defer func() {
|
||||||
|
_ = recover() // ignore the stream backend panic to avoid the test to fail.
|
||||||
|
}()
|
||||||
|
next.ServeHTTP(rw, req)
|
||||||
|
}), nil
|
||||||
})
|
})
|
||||||
chain = chain.Append(capture.Wrap)
|
chain = chain.Append(capture.Wrap)
|
||||||
chain = chain.Append(WrapHandler(logger))
|
chain = chain.Append(WrapHandler(logger))
|
||||||
|
|
|
@ -1,7 +1,10 @@
|
||||||
package recovery
|
package recovery
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
|
@ -28,12 +31,16 @@ func New(ctx context.Context, next http.Handler) (http.Handler, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (re *recovery) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
func (re *recovery) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
defer recoverFunc(rw, req)
|
recoveryRW := newRecoveryResponseWriter(rw)
|
||||||
re.next.ServeHTTP(rw, req)
|
defer recoverFunc(recoveryRW, req)
|
||||||
|
|
||||||
|
re.next.ServeHTTP(recoveryRW, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func recoverFunc(rw http.ResponseWriter, r *http.Request) {
|
func recoverFunc(rw recoveryResponseWriter, r *http.Request) {
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
|
defer rw.finalizeResponse()
|
||||||
|
|
||||||
logger := log.FromContext(middlewares.GetLoggerCtx(r.Context(), middlewareName, typeName))
|
logger := log.FromContext(middlewares.GetLoggerCtx(r.Context(), middlewareName, typeName))
|
||||||
if !shouldLogPanic(err) {
|
if !shouldLogPanic(err) {
|
||||||
logger.Debugf("Request has been aborted [%s - %s]: %v", r.RemoteAddr, r.URL, err)
|
logger.Debugf("Request has been aborted [%s - %s]: %v", r.RemoteAddr, r.URL, err)
|
||||||
|
@ -45,8 +52,6 @@ func recoverFunc(rw http.ResponseWriter, r *http.Request) {
|
||||||
buf := make([]byte, size)
|
buf := make([]byte, size)
|
||||||
buf = buf[:runtime.Stack(buf, false)]
|
buf = buf[:runtime.Stack(buf, false)]
|
||||||
logger.Errorf("Stack: %s", buf)
|
logger.Errorf("Stack: %s", buf)
|
||||||
|
|
||||||
http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -56,3 +61,81 @@ func shouldLogPanic(panicValue interface{}) bool {
|
||||||
//nolint:errorlint // false-positive because panicValue is an interface.
|
//nolint:errorlint // false-positive because panicValue is an interface.
|
||||||
return panicValue != nil && panicValue != http.ErrAbortHandler
|
return panicValue != nil && panicValue != http.ErrAbortHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type recoveryResponseWriter interface {
|
||||||
|
http.ResponseWriter
|
||||||
|
|
||||||
|
finalizeResponse()
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRecoveryResponseWriter(rw http.ResponseWriter) recoveryResponseWriter {
|
||||||
|
wrapper := &responseWriterWrapper{rw: rw}
|
||||||
|
if _, ok := rw.(http.CloseNotifier); !ok {
|
||||||
|
return wrapper
|
||||||
|
}
|
||||||
|
|
||||||
|
return &responseWriterWrapperWithCloseNotify{wrapper}
|
||||||
|
}
|
||||||
|
|
||||||
|
type responseWriterWrapper struct {
|
||||||
|
rw http.ResponseWriter
|
||||||
|
headersSent bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *responseWriterWrapper) Header() http.Header {
|
||||||
|
return r.rw.Header()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *responseWriterWrapper) Write(bytes []byte) (int, error) {
|
||||||
|
r.headersSent = true
|
||||||
|
return r.rw.Write(bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *responseWriterWrapper) WriteHeader(code int) {
|
||||||
|
if r.headersSent {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handling informational headers.
|
||||||
|
if code >= 100 && code <= 199 {
|
||||||
|
r.rw.WriteHeader(code)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.headersSent = true
|
||||||
|
r.rw.WriteHeader(code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *responseWriterWrapper) Flush() {
|
||||||
|
if f, ok := r.rw.(http.Flusher); ok {
|
||||||
|
f.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *responseWriterWrapper) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
if h, ok := r.rw.(http.Hijacker); ok {
|
||||||
|
return h.Hijack()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil, fmt.Errorf("not a hijacker: %T", r.rw)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *responseWriterWrapper) finalizeResponse() {
|
||||||
|
// If headers have been sent this is not possible to respond with an HTTP error,
|
||||||
|
// and we let the server abort the response silently thanks to the http.ErrAbortHandler sentinel panic value.
|
||||||
|
if r.headersSent {
|
||||||
|
panic(http.ErrAbortHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The response has not yet started to be written,
|
||||||
|
// we can safely return a fresh new error response.
|
||||||
|
http.Error(r.rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
type responseWriterWrapperWithCloseNotify struct {
|
||||||
|
*responseWriterWrapper
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *responseWriterWrapperWithCloseNotify) CloseNotify() <-chan bool {
|
||||||
|
return r.rw.(http.CloseNotifier).CloseNotify()
|
||||||
|
}
|
||||||
|
|
|
@ -2,6 +2,8 @@ package recovery
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -11,17 +13,54 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRecoverHandler(t *testing.T) {
|
func TestRecoverHandler(t *testing.T) {
|
||||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
tests := []struct {
|
||||||
panic("I love panicking!")
|
desc string
|
||||||
|
panicErr error
|
||||||
|
headersSent bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "headers sent and custom panic error",
|
||||||
|
panicErr: errors.New("foo"),
|
||||||
|
headersSent: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "headers sent and error abort handler",
|
||||||
|
panicErr: http.ErrAbortHandler,
|
||||||
|
headersSent: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "custom panic error",
|
||||||
|
panicErr: errors.New("foo"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "error abort handler",
|
||||||
|
panicErr: http.ErrAbortHandler,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
fn := func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
if test.headersSent {
|
||||||
|
rw.WriteHeader(http.StatusTeapot)
|
||||||
|
}
|
||||||
|
panic(test.panicErr)
|
||||||
}
|
}
|
||||||
recovery, err := New(context.Background(), http.HandlerFunc(fn))
|
recovery, err := New(context.Background(), http.HandlerFunc(fn))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
server := httptest.NewServer(recovery)
|
server := httptest.NewServer(recovery)
|
||||||
defer server.Close()
|
t.Cleanup(server.Close)
|
||||||
|
|
||||||
resp, err := http.Get(server.URL)
|
res, err := http.Get(server.URL)
|
||||||
|
if test.headersSent {
|
||||||
|
require.Nil(t, res)
|
||||||
|
assert.ErrorIs(t, err, io.EOF)
|
||||||
|
} else {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
|
||||||
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue