diff --git a/CHANGELOG.md b/CHANGELOG.md index b13eecc78..2ad7f1f04 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,12 @@ +## [v2.11.13](https://github.com/traefik/traefik/tree/v2.11.13) (2024-10-28) +[All Commits](https://github.com/traefik/traefik/compare/v2.11.12...v2.11.13) + +**Bug fixes:** +- **[middleware,service]** Panic on aborted requests to properly close the connection ([#11129](https://github.com/traefik/traefik/pull/11129) by [tonybart1337](https://github.com/tonybart1337)) + +**Documentation:** +- Update business callouts ([#11217](https://github.com/traefik/traefik/pull/11217) by [tomatokoolaid](https://github.com/tomatokoolaid)) + ## [v3.1.6](https://github.com/traefik/traefik/tree/v3.1.6) (2024-10-09) [All Commits](https://github.com/traefik/traefik/compare/v3.1.5...v3.1.6) diff --git a/pkg/middlewares/accesslog/logger_test.go b/pkg/middlewares/accesslog/logger_test.go index ba89e344c..d00f159f8 100644 --- a/pkg/middlewares/accesslog/logger_test.go +++ b/pkg/middlewares/accesslog/logger_test.go @@ -25,7 +25,6 @@ import ( "github.com/stretchr/testify/require" ptypes "github.com/traefik/paerser/types" "github.com/traefik/traefik/v3/pkg/middlewares/capture" - "github.com/traefik/traefik/v3/pkg/middlewares/recovery" "github.com/traefik/traefik/v3/pkg/types" ) @@ -954,8 +953,14 @@ func doLoggingWithAbortedStream(t *testing.T, config *types.AccessLog) { req = req.WithContext(reqContext) chain := alice.New() + chain = chain.Append(func(next http.Handler) (http.Handler, error) { - return recovery.New(context.Background(), next) + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + defer func() { + _ = recover() // ignore the stream backend panic to avoid the test to fail. + }() + next.ServeHTTP(rw, req) + }), nil }) chain = chain.Append(capture.Wrap) chain = chain.Append(WrapHandler(logger)) diff --git a/pkg/middlewares/recovery/recovery.go b/pkg/middlewares/recovery/recovery.go index ebd3edcb5..1d825a7ad 100644 --- a/pkg/middlewares/recovery/recovery.go +++ b/pkg/middlewares/recovery/recovery.go @@ -1,7 +1,10 @@ package recovery import ( + "bufio" "context" + "fmt" + "net" "net/http" "runtime" @@ -27,12 +30,16 @@ func New(ctx context.Context, next http.Handler) (http.Handler, error) { } func (re *recovery) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - defer recoverFunc(rw, req) - re.next.ServeHTTP(rw, req) + recoveryRW := newRecoveryResponseWriter(rw) + defer recoverFunc(recoveryRW, req) + + re.next.ServeHTTP(recoveryRW, req) } -func recoverFunc(rw http.ResponseWriter, req *http.Request) { +func recoverFunc(rw recoveryResponseWriter, req *http.Request) { if err := recover(); err != nil { + defer rw.finalizeResponse() + logger := middlewares.GetLogger(req.Context(), middlewareName, typeName) if !shouldLogPanic(err) { logger.Debug().Msgf("Request has been aborted [%s - %s]: %v", req.RemoteAddr, req.URL, err) @@ -44,8 +51,6 @@ func recoverFunc(rw http.ResponseWriter, req *http.Request) { buf := make([]byte, size) buf = buf[:runtime.Stack(buf, false)] logger.Error().Msgf("Stack: %s", buf) - - http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) } } @@ -55,3 +60,81 @@ func shouldLogPanic(panicValue interface{}) bool { //nolint:errorlint // false-positive because panicValue is an interface. return panicValue != nil && panicValue != http.ErrAbortHandler } + +type recoveryResponseWriter interface { + http.ResponseWriter + + finalizeResponse() +} + +func newRecoveryResponseWriter(rw http.ResponseWriter) recoveryResponseWriter { + wrapper := &responseWriterWrapper{rw: rw} + if _, ok := rw.(http.CloseNotifier); !ok { + return wrapper + } + + return &responseWriterWrapperWithCloseNotify{wrapper} +} + +type responseWriterWrapper struct { + rw http.ResponseWriter + headersSent bool +} + +func (r *responseWriterWrapper) Header() http.Header { + return r.rw.Header() +} + +func (r *responseWriterWrapper) Write(bytes []byte) (int, error) { + r.headersSent = true + return r.rw.Write(bytes) +} + +func (r *responseWriterWrapper) WriteHeader(code int) { + if r.headersSent { + return + } + + // Handling informational headers. + if code >= 100 && code <= 199 { + r.rw.WriteHeader(code) + return + } + + r.headersSent = true + r.rw.WriteHeader(code) +} + +func (r *responseWriterWrapper) Flush() { + if f, ok := r.rw.(http.Flusher); ok { + f.Flush() + } +} + +func (r *responseWriterWrapper) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if h, ok := r.rw.(http.Hijacker); ok { + return h.Hijack() + } + + return nil, nil, fmt.Errorf("not a hijacker: %T", r.rw) +} + +func (r *responseWriterWrapper) finalizeResponse() { + // If headers have been sent this is not possible to respond with an HTTP error, + // and we let the server abort the response silently thanks to the http.ErrAbortHandler sentinel panic value. + if r.headersSent { + panic(http.ErrAbortHandler) + } + + // The response has not yet started to be written, + // we can safely return a fresh new error response. + http.Error(r.rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) +} + +type responseWriterWrapperWithCloseNotify struct { + *responseWriterWrapper +} + +func (r *responseWriterWrapperWithCloseNotify) CloseNotify() <-chan bool { + return r.rw.(http.CloseNotifier).CloseNotify() +} diff --git a/pkg/middlewares/recovery/recovery_test.go b/pkg/middlewares/recovery/recovery_test.go index 570717ebc..1929f0b54 100644 --- a/pkg/middlewares/recovery/recovery_test.go +++ b/pkg/middlewares/recovery/recovery_test.go @@ -2,6 +2,8 @@ package recovery import ( "context" + "errors" + "io" "net/http" "net/http/httptest" "testing" @@ -11,17 +13,54 @@ import ( ) func TestRecoverHandler(t *testing.T) { - fn := func(w http.ResponseWriter, r *http.Request) { - panic("I love panicking!") + tests := []struct { + desc string + panicErr error + headersSent bool + }{ + { + desc: "headers sent and custom panic error", + panicErr: errors.New("foo"), + headersSent: true, + }, + { + desc: "headers sent and error abort handler", + panicErr: http.ErrAbortHandler, + headersSent: true, + }, + { + desc: "custom panic error", + panicErr: errors.New("foo"), + }, + { + desc: "error abort handler", + panicErr: http.ErrAbortHandler, + }, } - recovery, err := New(context.Background(), http.HandlerFunc(fn)) - require.NoError(t, err) + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + t.Parallel() - server := httptest.NewServer(recovery) - defer server.Close() + fn := func(rw http.ResponseWriter, req *http.Request) { + if test.headersSent { + rw.WriteHeader(http.StatusTeapot) + } + panic(test.panicErr) + } + recovery, err := New(context.Background(), http.HandlerFunc(fn)) + require.NoError(t, err) - resp, err := http.Get(server.URL) - require.NoError(t, err) + server := httptest.NewServer(recovery) + t.Cleanup(server.Close) - assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + res, err := http.Get(server.URL) + if test.headersSent { + require.Nil(t, res) + assert.ErrorIs(t, err, io.EOF) + } else { + require.NoError(t, err) + assert.Equal(t, http.StatusInternalServerError, res.StatusCode) + } + }) + } }