From 27948493aa047a5fdc2eb59c438ce3bc37cb756c Mon Sep 17 00:00:00 2001 From: Anton Bartsits <8313309+tonybart1337@users.noreply.github.com> Date: Fri, 25 Oct 2024 15:44:04 +0200 Subject: [PATCH] Panic on aborted requests to properly close the connection --- pkg/middlewares/accesslog/logger_test.go | 9 ++- pkg/middlewares/recovery/recovery.go | 93 +++++++++++++++++++++-- pkg/middlewares/recovery/recovery_test.go | 57 +++++++++++--- 3 files changed, 143 insertions(+), 16 deletions(-) diff --git a/pkg/middlewares/accesslog/logger_test.go b/pkg/middlewares/accesslog/logger_test.go index 01433c5f6..dc48c3593 100644 --- a/pkg/middlewares/accesslog/logger_test.go +++ b/pkg/middlewares/accesslog/logger_test.go @@ -23,7 +23,6 @@ import ( "github.com/stretchr/testify/require" ptypes "github.com/traefik/paerser/types" "github.com/traefik/traefik/v2/pkg/middlewares/capture" - "github.com/traefik/traefik/v2/pkg/middlewares/recovery" "github.com/traefik/traefik/v2/pkg/types" ) @@ -948,8 +947,14 @@ func doLoggingWithAbortedStream(t *testing.T, config *types.AccessLog) { req = req.WithContext(reqContext) chain := alice.New() + chain = chain.Append(func(next http.Handler) (http.Handler, error) { - return recovery.New(context.Background(), next) + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + defer func() { + _ = recover() // ignore the stream backend panic to avoid the test to fail. + }() + next.ServeHTTP(rw, req) + }), nil }) chain = chain.Append(capture.Wrap) chain = chain.Append(WrapHandler(logger)) diff --git a/pkg/middlewares/recovery/recovery.go b/pkg/middlewares/recovery/recovery.go index 753a5801b..2415c3ea3 100644 --- a/pkg/middlewares/recovery/recovery.go +++ b/pkg/middlewares/recovery/recovery.go @@ -1,7 +1,10 @@ package recovery import ( + "bufio" "context" + "fmt" + "net" "net/http" "runtime" @@ -28,12 +31,16 @@ func New(ctx context.Context, next http.Handler) (http.Handler, error) { } func (re *recovery) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - defer recoverFunc(rw, req) - re.next.ServeHTTP(rw, req) + recoveryRW := newRecoveryResponseWriter(rw) + defer recoverFunc(recoveryRW, req) + + re.next.ServeHTTP(recoveryRW, req) } -func recoverFunc(rw http.ResponseWriter, r *http.Request) { +func recoverFunc(rw recoveryResponseWriter, r *http.Request) { if err := recover(); err != nil { + defer rw.finalizeResponse() + logger := log.FromContext(middlewares.GetLoggerCtx(r.Context(), middlewareName, typeName)) if !shouldLogPanic(err) { logger.Debugf("Request has been aborted [%s - %s]: %v", r.RemoteAddr, r.URL, err) @@ -45,8 +52,6 @@ func recoverFunc(rw http.ResponseWriter, r *http.Request) { buf := make([]byte, size) buf = buf[:runtime.Stack(buf, false)] logger.Errorf("Stack: %s", buf) - - http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) } } @@ -56,3 +61,81 @@ func shouldLogPanic(panicValue interface{}) bool { //nolint:errorlint // false-positive because panicValue is an interface. return panicValue != nil && panicValue != http.ErrAbortHandler } + +type recoveryResponseWriter interface { + http.ResponseWriter + + finalizeResponse() +} + +func newRecoveryResponseWriter(rw http.ResponseWriter) recoveryResponseWriter { + wrapper := &responseWriterWrapper{rw: rw} + if _, ok := rw.(http.CloseNotifier); !ok { + return wrapper + } + + return &responseWriterWrapperWithCloseNotify{wrapper} +} + +type responseWriterWrapper struct { + rw http.ResponseWriter + headersSent bool +} + +func (r *responseWriterWrapper) Header() http.Header { + return r.rw.Header() +} + +func (r *responseWriterWrapper) Write(bytes []byte) (int, error) { + r.headersSent = true + return r.rw.Write(bytes) +} + +func (r *responseWriterWrapper) WriteHeader(code int) { + if r.headersSent { + return + } + + // Handling informational headers. + if code >= 100 && code <= 199 { + r.rw.WriteHeader(code) + return + } + + r.headersSent = true + r.rw.WriteHeader(code) +} + +func (r *responseWriterWrapper) Flush() { + if f, ok := r.rw.(http.Flusher); ok { + f.Flush() + } +} + +func (r *responseWriterWrapper) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if h, ok := r.rw.(http.Hijacker); ok { + return h.Hijack() + } + + return nil, nil, fmt.Errorf("not a hijacker: %T", r.rw) +} + +func (r *responseWriterWrapper) finalizeResponse() { + // If headers have been sent this is not possible to respond with an HTTP error, + // and we let the server abort the response silently thanks to the http.ErrAbortHandler sentinel panic value. + if r.headersSent { + panic(http.ErrAbortHandler) + } + + // The response has not yet started to be written, + // we can safely return a fresh new error response. + http.Error(r.rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) +} + +type responseWriterWrapperWithCloseNotify struct { + *responseWriterWrapper +} + +func (r *responseWriterWrapperWithCloseNotify) CloseNotify() <-chan bool { + return r.rw.(http.CloseNotifier).CloseNotify() +} diff --git a/pkg/middlewares/recovery/recovery_test.go b/pkg/middlewares/recovery/recovery_test.go index 570717ebc..1929f0b54 100644 --- a/pkg/middlewares/recovery/recovery_test.go +++ b/pkg/middlewares/recovery/recovery_test.go @@ -2,6 +2,8 @@ package recovery import ( "context" + "errors" + "io" "net/http" "net/http/httptest" "testing" @@ -11,17 +13,54 @@ import ( ) func TestRecoverHandler(t *testing.T) { - fn := func(w http.ResponseWriter, r *http.Request) { - panic("I love panicking!") + tests := []struct { + desc string + panicErr error + headersSent bool + }{ + { + desc: "headers sent and custom panic error", + panicErr: errors.New("foo"), + headersSent: true, + }, + { + desc: "headers sent and error abort handler", + panicErr: http.ErrAbortHandler, + headersSent: true, + }, + { + desc: "custom panic error", + panicErr: errors.New("foo"), + }, + { + desc: "error abort handler", + panicErr: http.ErrAbortHandler, + }, } - recovery, err := New(context.Background(), http.HandlerFunc(fn)) - require.NoError(t, err) + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + t.Parallel() - server := httptest.NewServer(recovery) - defer server.Close() + fn := func(rw http.ResponseWriter, req *http.Request) { + if test.headersSent { + rw.WriteHeader(http.StatusTeapot) + } + panic(test.panicErr) + } + recovery, err := New(context.Background(), http.HandlerFunc(fn)) + require.NoError(t, err) - resp, err := http.Get(server.URL) - require.NoError(t, err) + server := httptest.NewServer(recovery) + t.Cleanup(server.Close) - assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + res, err := http.Get(server.URL) + if test.headersSent { + require.Nil(t, res) + assert.ErrorIs(t, err, io.EOF) + } else { + require.NoError(t, err) + assert.Equal(t, http.StatusInternalServerError, res.StatusCode) + } + }) + } }