From 1c80f12bc289a99d0f4118103f0cabbd946baa00 Mon Sep 17 00:00:00 2001 From: davefu113 <142489013+davefu113@users.noreply.github.com> Date: Mon, 18 Nov 2024 16:56:04 +0800 Subject: [PATCH] Apply keepalive config to h2c entrypoints --- pkg/server/server_entrypoint_tcp.go | 10 ++--- pkg/server/server_entrypoint_tcp_test.go | 52 ++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 5 deletions(-) diff --git a/pkg/server/server_entrypoint_tcp.go b/pkg/server/server_entrypoint_tcp.go index 5cd6cbb49..5d170685f 100644 --- a/pkg/server/server_entrypoint_tcp.go +++ b/pkg/server/server_entrypoint_tcp.go @@ -578,17 +578,17 @@ func createHTTPServer(ctx context.Context, ln net.Listener, configuration *stati handler = http.AllowQuerySemicolons(handler) } + debugConnection := os.Getenv(debugConnectionEnv) != "" + if debugConnection || (configuration.Transport != nil && (configuration.Transport.KeepAliveMaxTime > 0 || configuration.Transport.KeepAliveMaxRequests > 0)) { + handler = newKeepAliveMiddleware(handler, configuration.Transport.KeepAliveMaxRequests, configuration.Transport.KeepAliveMaxTime) + } + if withH2c { handler = h2c.NewHandler(handler, &http2.Server{ MaxConcurrentStreams: uint32(configuration.HTTP2.MaxConcurrentStreams), }) } - debugConnection := os.Getenv(debugConnectionEnv) != "" - if debugConnection || (configuration.Transport != nil && (configuration.Transport.KeepAliveMaxTime > 0 || configuration.Transport.KeepAliveMaxRequests > 0)) { - handler = newKeepAliveMiddleware(handler, configuration.Transport.KeepAliveMaxRequests, configuration.Transport.KeepAliveMaxTime) - } - serverHTTP := &http.Server{ Handler: handler, ErrorLog: httpServerLogger, diff --git a/pkg/server/server_entrypoint_tcp_test.go b/pkg/server/server_entrypoint_tcp_test.go index 4dc9ee428..f3b8865c9 100644 --- a/pkg/server/server_entrypoint_tcp_test.go +++ b/pkg/server/server_entrypoint_tcp_test.go @@ -3,6 +3,7 @@ package server import ( "bufio" "context" + "crypto/tls" "errors" "io" "net" @@ -17,6 +18,7 @@ import ( "github.com/traefik/traefik/v2/pkg/config/static" tcprouter "github.com/traefik/traefik/v2/pkg/server/router/tcp" "github.com/traefik/traefik/v2/pkg/tcp" + "golang.org/x/net/http2" ) func TestShutdownHijacked(t *testing.T) { @@ -330,3 +332,53 @@ func TestKeepAliveMaxTime(t *testing.T) { err = resp.Body.Close() require.NoError(t, err) } + +func TestKeepAliveH2c(t *testing.T) { + epConfig := &static.EntryPointsTransport{} + epConfig.SetDefaults() + epConfig.KeepAliveMaxRequests = 1 + + entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{ + Address: ":0", + Transport: epConfig, + ForwardedHeaders: &static.ForwardedHeaders{}, + HTTP2: &static.HTTP2Config{}, + }, nil) + require.NoError(t, err) + + router, err := tcprouter.NewRouter() + require.NoError(t, err) + + router.SetHTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusOK) + })) + + conn, err := startEntrypoint(entryPoint, router) + require.NoError(t, err) + + http2Transport := &http2.Transport{ + AllowHTTP: true, + DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { + return conn, nil + }, + } + + client := &http.Client{Transport: http2Transport} + + resp, err := client.Get("http://" + entryPoint.listener.Addr().String()) + require.NoError(t, err) + require.False(t, resp.Close) + err = resp.Body.Close() + require.NoError(t, err) + + _, err = client.Get("http://" + entryPoint.listener.Addr().String()) + require.Error(t, err) + // Unlike HTTP/1, where we can directly check `resp.Close`, HTTP/2 uses a different + // mechanism: it sends a GOAWAY frame when the connection is closing. + // We can only check the error type. The error received should be poll.ErrClosed from + // the `internal/poll` package, but we cannot directly reference the error type due to + // package restrictions. Since this error message ("use of closed network connection") + // is distinct and specific, we rely on its consistency, assuming it is stable and unlikely + // to change. + require.Contains(t, err.Error(), "use of closed network connection") +}