From f8a78b3b25f601509d0a83dd48fcd0e4b8192d24 Mon Sep 17 00:00:00 2001 From: Kevin Pollet Date: Thu, 26 Sep 2024 11:00:05 +0200 Subject: [PATCH] Introduce a fast proxy mode to improve HTTP/1.1 performances with backends Co-authored-by: Romain Co-authored-by: Julien Salleyron --- .golangci.yml | 2 +- cmd/traefik/traefik.go | 15 +- .../reference/static-configuration/cli-ref.md | 6 + .../reference/static-configuration/env-ref.md | 6 + .../reference/static-configuration/file.toml | 2 + .../reference/static-configuration/file.yaml | 2 + docs/content/user-guides/fastproxy.md | 41 ++ docs/mkdocs.yml | 1 + go.mod | 8 +- go.sum | 7 +- integration/fixtures/simple_fastproxy.toml | 35 + integration/simple_test.go | 26 + pkg/config/static/experimental.go | 7 + pkg/proxy/fast/builder.go | 129 ++++ pkg/proxy/fast/connpool.go | 163 ++++ pkg/proxy/fast/connpool_test.go | 184 +++++ pkg/proxy/fast/dialer.go | 195 +++++ pkg/proxy/fast/proxy.go | 553 ++++++++++++++ pkg/proxy/fast/proxy_test.go | 311 ++++++++ pkg/proxy/fast/proxy_websocket_test.go | 693 ++++++++++++++++++ pkg/proxy/fast/upgrade.go | 104 +++ .../service => proxy/httputil}/bufferpool.go | 26 +- pkg/proxy/httputil/builder.go | 54 ++ pkg/proxy/httputil/builder_test.go | 56 ++ .../httputil/observability.go} | 18 +- .../httputil/observability_test.go} | 2 +- .../service => proxy/httputil}/proxy.go | 12 +- .../httputil}/proxy_websocket_test.go | 199 +++-- pkg/proxy/smart_builder.go | 61 ++ pkg/proxy/smart_builder_test.go | 113 +++ pkg/server/router/router_test.go | 71 +- pkg/server/routerfactory_test.go | 33 +- pkg/server/service/managerfactory.go | 16 +- pkg/server/service/proxy_test.go | 37 - pkg/server/service/service.go | 59 +- pkg/server/service/service_test.go | 60 +- pkg/server/service/smart_roundtripper.go | 9 + .../service/{roundtripper.go => transport.go} | 205 ++++-- ...roundtripper_test.go => transport_test.go} | 30 +- 39 files changed, 3173 insertions(+), 378 deletions(-) create mode 100644 docs/content/user-guides/fastproxy.md create mode 100644 integration/fixtures/simple_fastproxy.toml create mode 100644 pkg/proxy/fast/builder.go create mode 100644 pkg/proxy/fast/connpool.go create mode 100644 pkg/proxy/fast/connpool_test.go create mode 100644 pkg/proxy/fast/dialer.go create mode 100644 pkg/proxy/fast/proxy.go create mode 100644 pkg/proxy/fast/proxy_test.go create mode 100644 pkg/proxy/fast/proxy_websocket_test.go create mode 100644 pkg/proxy/fast/upgrade.go rename pkg/{server/service => proxy/httputil}/bufferpool.go (57%) create mode 100644 pkg/proxy/httputil/builder.go create mode 100644 pkg/proxy/httputil/builder_test.go rename pkg/{server/service/observability_roundtripper.go => proxy/httputil/observability.go} (98%) rename pkg/{server/service/observability_roundtripper_test.go => proxy/httputil/observability_test.go} (99%) rename pkg/{server/service => proxy/httputil}/proxy.go (90%) rename pkg/{server/service => proxy/httputil}/proxy_websocket_test.go (86%) create mode 100644 pkg/proxy/smart_builder.go create mode 100644 pkg/proxy/smart_builder_test.go delete mode 100644 pkg/server/service/proxy_test.go rename pkg/server/service/{roundtripper.go => transport.go} (65%) rename pkg/server/service/{roundtripper_test.go => transport_test.go} (96%) diff --git a/.golangci.yml b/.golangci.yml index c3766eeed..f790c5150 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -229,7 +229,7 @@ issues: text: 'struct-tag: unknown option ''inline'' in JSON tag' linters: - revive - - path: pkg/server/service/bufferpool.go + - path: pkg/proxy/httputil/bufferpool.go text: 'SA6002: argument should be pointer-like to avoid allocations' - path: pkg/server/middleware/middlewares.go text: "Function 'buildConstructor' has too many statements" diff --git a/cmd/traefik/traefik.go b/cmd/traefik/traefik.go index 6b41ef2e2..78edf38cb 100644 --- a/cmd/traefik/traefik.go +++ b/cmd/traefik/traefik.go @@ -37,6 +37,8 @@ import ( "github.com/traefik/traefik/v3/pkg/provider/aggregator" "github.com/traefik/traefik/v3/pkg/provider/tailscale" "github.com/traefik/traefik/v3/pkg/provider/traefik" + "github.com/traefik/traefik/v3/pkg/proxy" + "github.com/traefik/traefik/v3/pkg/proxy/httputil" "github.com/traefik/traefik/v3/pkg/safe" "github.com/traefik/traefik/v3/pkg/server" "github.com/traefik/traefik/v3/pkg/server/middleware" @@ -281,10 +283,16 @@ func setupServer(staticConfiguration *static.Configuration) (*server.Server, err log.Info().Msg("Successfully obtained SPIFFE SVID.") } - roundTripperManager := service.NewRoundTripperManager(spiffeX509Source) + transportManager := service.NewTransportManager(spiffeX509Source) + + var proxyBuilder service.ProxyBuilder = httputil.NewProxyBuilder(transportManager, semConvMetricRegistry) + if staticConfiguration.Experimental != nil && staticConfiguration.Experimental.FastProxy != nil { + proxyBuilder = proxy.NewSmartBuilder(transportManager, proxyBuilder, *staticConfiguration.Experimental.FastProxy) + } + dialerManager := tcp.NewDialerManager(spiffeX509Source) acmeHTTPHandler := getHTTPChallengeHandler(acmeProviders, httpChallengeProvider) - managerFactory := service.NewManagerFactory(*staticConfiguration, routinesPool, observabilityMgr, roundTripperManager, acmeHTTPHandler) + managerFactory := service.NewManagerFactory(*staticConfiguration, routinesPool, observabilityMgr, transportManager, proxyBuilder, acmeHTTPHandler) // Router factory @@ -318,7 +326,8 @@ func setupServer(staticConfiguration *static.Configuration) (*server.Server, err // Server Transports watcher.AddListener(func(conf dynamic.Configuration) { - roundTripperManager.Update(conf.HTTP.ServersTransports) + transportManager.Update(conf.HTTP.ServersTransports) + proxyBuilder.Update(conf.HTTP.ServersTransports) dialerManager.Update(conf.TCP.ServersTransports) }) diff --git a/docs/content/reference/static-configuration/cli-ref.md b/docs/content/reference/static-configuration/cli-ref.md index fe1c408e8..3d6e02908 100644 --- a/docs/content/reference/static-configuration/cli-ref.md +++ b/docs/content/reference/static-configuration/cli-ref.md @@ -228,6 +228,12 @@ WriteTimeout is the maximum duration before timing out writes of the response. I `--entrypoints..udp.timeout`: Timeout defines how long to wait on an idle session before releasing the related resources. (Default: ```3```) +`--experimental.fastproxy`: +Enable the FastProxy implementation. (Default: ```false```) + +`--experimental.fastproxy.debug`: +Enable debug mode for the FastProxy implementation. (Default: ```false```) + `--experimental.kubernetesgateway`: (Deprecated) Allow the Kubernetes gateway api provider usage. (Default: ```false```) diff --git a/docs/content/reference/static-configuration/env-ref.md b/docs/content/reference/static-configuration/env-ref.md index ef6656ddf..058190eaf 100644 --- a/docs/content/reference/static-configuration/env-ref.md +++ b/docs/content/reference/static-configuration/env-ref.md @@ -228,6 +228,12 @@ WriteTimeout is the maximum duration before timing out writes of the response. I `TRAEFIK_ENTRYPOINTS__UDP_TIMEOUT`: Timeout defines how long to wait on an idle session before releasing the related resources. (Default: ```3```) +`TRAEFIK_EXPERIMENTAL_FASTPROXY`: +Enable the FastProxy implementation. (Default: ```false```) + +`TRAEFIK_EXPERIMENTAL_FASTPROXY_DEBUG`: +Enable debug mode for the FastProxy implementation. (Default: ```false```) + `TRAEFIK_EXPERIMENTAL_KUBERNETESGATEWAY`: (Deprecated) Allow the Kubernetes gateway api provider usage. (Default: ```false```) diff --git a/docs/content/reference/static-configuration/file.toml b/docs/content/reference/static-configuration/file.toml index 392088ac4..dd13bbc0a 100644 --- a/docs/content/reference/static-configuration/file.toml +++ b/docs/content/reference/static-configuration/file.toml @@ -509,6 +509,8 @@ [experimental.localPlugins.LocalDescriptor1.settings] envs = ["foobar", "foobar"] mounts = ["foobar", "foobar"] + [experimental.fastProxy] + debug = true [core] defaultRuleSyntax = "foobar" diff --git a/docs/content/reference/static-configuration/file.yaml b/docs/content/reference/static-configuration/file.yaml index 6c79ac9c5..89ee83476 100644 --- a/docs/content/reference/static-configuration/file.yaml +++ b/docs/content/reference/static-configuration/file.yaml @@ -572,6 +572,8 @@ experimental: mounts: - foobar - foobar + fastProxy: + debug: true kubernetesGateway: true core: defaultRuleSyntax: foobar diff --git a/docs/content/user-guides/fastproxy.md b/docs/content/user-guides/fastproxy.md new file mode 100644 index 000000000..3294f9dc4 --- /dev/null +++ b/docs/content/user-guides/fastproxy.md @@ -0,0 +1,41 @@ +--- +title: "Traefik FastProxy Experimental Configuration" +description: "This section of the Traefik Proxy documentation explains how to use the new FastProxy option." +--- + +# Traefik FastProxy Experimental Configuration + +## Overview + +This guide provides instructions on how to configure and use the new experimental `fastProxy` static configuration option in Traefik. +The `fastProxy` option introduces a high-performance reverse proxy designed to enhance the performance of routing. + +!!! info "Limitations" + + Please note that the new fast proxy implementation does not work with HTTP/2. + This means that when a H2C or HTTPS request with [HTTP2 enabled](../routing/services/index.md#disablehttp2) is sent to a backend, the fallback proxy is the regular one. + + Additionnaly, observability features like tracing and OTEL semconv metrics are not supported for the moment. + +!!! warning "Experimental" + + The `fastProxy` option is currently experimental and subject to change in future releases. + Use with caution in production environments. + +### Enabling FastProxy + +The fastProxy option is a static configuration parameter. +To enable it, you need to configure it in your Traefik static configuration + +```yaml tab="File (YAML)" +experimental: + fastProxy: {} +``` + +```toml tab="File (TOML)" +[experimental.fastProxy] +``` + +```bash tab="CLI" +--experimental.fastProxy +``` diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 2d2dc3858..3e86558cf 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -163,6 +163,7 @@ nav: - 'Overview': 'observability/tracing/overview.md' - 'OpenTelemetry': 'observability/tracing/opentelemetry.md' - 'User Guides': + - 'FastProxy': 'user-guides/fastproxy.md' - 'Kubernetes and Let''s Encrypt': 'user-guides/crd-acme/index.md' - 'gRPC Examples': 'user-guides/grpc.md' - 'Docker': diff --git a/go.mod b/go.mod index 58dfb27d8..0d9e6fe95 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/BurntSushi/toml v1.4.0 github.com/Masterminds/sprig/v3 v3.2.3 github.com/abbot/go-http-auth v0.0.0-00010101000000-000000000000 // No tag on the repo. - github.com/andybalholm/brotli v1.0.6 + github.com/andybalholm/brotli v1.1.0 github.com/aws/aws-sdk-go v1.44.327 github.com/cenkalti/backoff/v4 v4.3.0 github.com/containous/alice v0.0.0-20181107144136-d83ebdd94cbd // No tag on the repo. @@ -102,6 +102,11 @@ require ( sigs.k8s.io/yaml v1.4.0 ) +require ( + github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 + github.com/valyala/fasthttp v1.55.0 +) + require ( cloud.google.com/go/compute/metadata v0.3.0 // indirect dario.cat/mergo v1.0.0 // indirect @@ -315,6 +320,7 @@ require ( github.com/tklauser/numcpus v0.6.1 // indirect github.com/transip/gotransip/v6 v6.23.0 // indirect github.com/ultradns/ultradns-go-sdk v1.6.1-20231103022937-8589b6a // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/vinyldns/go-vinyldns v0.9.16 // indirect github.com/vultr/govultr/v3 v3.9.0 // indirect github.com/x448/float16 v0.8.4 // indirect diff --git a/go.sum b/go.sum index fa0a4fb64..4f18c4d30 100644 --- a/go.sum +++ b/go.sum @@ -98,8 +98,8 @@ github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRF github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/aliyun/alibaba-cloud-sdk-go v1.62.712 h1:lM7JnA9dEdDFH9XOgRNQMDTQnOjlLkDTNA7c0aWTQ30= github.com/aliyun/alibaba-cloud-sdk-go v1.62.712/go.mod h1:SOSDHfe1kX91v3W5QiBsWSLqeLxImobbMX1mxrFHsVQ= -github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI= -github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= +github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= +github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= @@ -1041,7 +1041,10 @@ github.com/unrolled/secure v1.0.9 h1:BWRuEb1vDrBFFDdbCnKkof3gZ35I/bnHGyt0LB0TNyQ github.com/unrolled/secure v1.0.9/go.mod h1:fO+mEan+FLB0CdEnHf6Q4ZZVNqG+5fuLFnP8p0BXDPI= github.com/urfave/negroni v1.0.0 h1:kIimOitoypq34K7TG7DUaJ9kq/N4Ofuwi1sjz0KipXc= github.com/urfave/negroni v1.0.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.55.0 h1:Zkefzgt6a7+bVKHnu/YaYSOPfNYNisSVBo/unVCf8k8= +github.com/valyala/fasthttp v1.55.0/go.mod h1:NkY9JtkrpPKmgwV3HTaS2HWaJss9RSIsRVfcxxoHiOM= github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8= github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= github.com/vinyldns/go-vinyldns v0.9.16 h1:GZJStDkcCk1F1AcRc64LuuMh+ENL8pHA0CVd4ulRMcQ= diff --git a/integration/fixtures/simple_fastproxy.toml b/integration/fixtures/simple_fastproxy.toml new file mode 100644 index 000000000..b278c8537 --- /dev/null +++ b/integration/fixtures/simple_fastproxy.toml @@ -0,0 +1,35 @@ +[global] + checkNewVersion = false + sendAnonymousUsage = false + +[log] + level = "DEBUG" + noColor = true + +[entryPoints] + [entryPoints.web] + address = ":8000" + +[api] + insecure = true + +[providers.file] + filename = "{{ .SelfFilename }}" + +[experimental] + [experimental.fastProxy] + debug = true + +## dynamic configuration ## + +[http.routers] + [http.routers.router1] + entrypoints = ["web"] + service = "service1" + rule = "PathPrefix(`/`)" + +[http.services] + [http.services.service1] + [http.services.service1.loadBalancer] + [[http.services.service1.loadBalancer.servers]] + url = "{{ .Server }}" diff --git a/integration/simple_test.go b/integration/simple_test.go index d03acaa64..92c0c5ae5 100644 --- a/integration/simple_test.go +++ b/integration/simple_test.go @@ -65,6 +65,32 @@ func (s *SimpleSuite) TestSimpleDefaultConfig() { require.NoError(s.T(), err) } +func (s *SimpleSuite) TestSimpleFastProxy() { + var callCount int + srv1 := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + assert.Contains(s.T(), req.Header, "X-Traefik-Fast-Proxy") + callCount++ + })) + defer srv1.Close() + + file := s.adaptFile("fixtures/simple_fastproxy.toml", struct { + Server string + }{ + Server: srv1.URL, + }) + + s.traefikCmd(withConfigFile(file), "--log.level=DEBUG") + + // wait for traefik + err := try.GetRequest("http://127.0.0.1:8080/api/rawdata", 10*time.Second, try.BodyContains("127.0.0.1")) + require.NoError(s.T(), err) + + err = try.GetRequest("http://127.0.0.1:8000/", time.Second) + require.NoError(s.T(), err) + + assert.GreaterOrEqual(s.T(), 1, callCount) +} + func (s *SimpleSuite) TestWithWebConfig() { s.cmdTraefik(withConfigFile("fixtures/simple_web.toml")) diff --git a/pkg/config/static/experimental.go b/pkg/config/static/experimental.go index ecd18424d..47b2989f8 100644 --- a/pkg/config/static/experimental.go +++ b/pkg/config/static/experimental.go @@ -7,6 +7,13 @@ type Experimental struct { Plugins map[string]plugins.Descriptor `description:"Plugins configuration." json:"plugins,omitempty" toml:"plugins,omitempty" yaml:"plugins,omitempty" export:"true"` LocalPlugins map[string]plugins.LocalDescriptor `description:"Local plugins configuration." json:"localPlugins,omitempty" toml:"localPlugins,omitempty" yaml:"localPlugins,omitempty" export:"true"` + FastProxy *FastProxyConfig `description:"Enable the FastProxy implementation." json:"fastProxy,omitempty" toml:"fastProxy,omitempty" yaml:"fastProxy,omitempty" label:"allowEmpty" file:"allowEmpty" export:"true"` + // Deprecated: KubernetesGateway provider is not an experimental feature starting with v3.1. Please remove its usage from the static configuration. KubernetesGateway bool `description:"(Deprecated) Allow the Kubernetes gateway api provider usage." json:"kubernetesGateway,omitempty" toml:"kubernetesGateway,omitempty" yaml:"kubernetesGateway,omitempty" export:"true"` } + +// FastProxyConfig holds the FastProxy configuration. +type FastProxyConfig struct { + Debug bool `description:"Enable debug mode for the FastProxy implementation." json:"debug,omitempty" toml:"debug,omitempty" yaml:"debug,omitempty" export:"true"` +} diff --git a/pkg/proxy/fast/builder.go b/pkg/proxy/fast/builder.go new file mode 100644 index 000000000..f330d6756 --- /dev/null +++ b/pkg/proxy/fast/builder.go @@ -0,0 +1,129 @@ +package fast + +import ( + "crypto/tls" + "fmt" + "net" + "net/http" + "net/url" + "reflect" + "time" + + "github.com/traefik/traefik/v3/pkg/config/dynamic" + "github.com/traefik/traefik/v3/pkg/config/static" +) + +// TransportManager manages transport used for backend communications. +type TransportManager interface { + Get(name string) (*dynamic.ServersTransport, error) + GetTLSConfig(name string) (*tls.Config, error) +} + +// ProxyBuilder handles the connection pools for the FastProxy proxies. +type ProxyBuilder struct { + debug bool + transportManager TransportManager + + // lock isn't needed because ProxyBuilder is not called concurrently. + pools map[string]map[string]*connPool + proxy func(*http.Request) (*url.URL, error) + + // not goroutine safe. + configs map[string]*dynamic.ServersTransport +} + +// NewProxyBuilder creates a new ProxyBuilder. +func NewProxyBuilder(transportManager TransportManager, config static.FastProxyConfig) *ProxyBuilder { + return &ProxyBuilder{ + debug: config.Debug, + transportManager: transportManager, + pools: make(map[string]map[string]*connPool), + proxy: http.ProxyFromEnvironment, + configs: make(map[string]*dynamic.ServersTransport), + } +} + +// Update updates all the round-tripper corresponding to the given configs. +// This method must not be used concurrently. +func (r *ProxyBuilder) Update(newConfigs map[string]*dynamic.ServersTransport) { + for configName := range r.configs { + if _, ok := newConfigs[configName]; !ok { + for _, c := range r.pools[configName] { + c.Close() + } + delete(r.pools, configName) + } + } + + for newConfigName, newConfig := range newConfigs { + if !reflect.DeepEqual(newConfig, r.configs[newConfigName]) { + for _, c := range r.pools[newConfigName] { + c.Close() + } + delete(r.pools, newConfigName) + } + } + + r.configs = newConfigs +} + +// Build builds a new ReverseProxy with the given configuration. +func (r *ProxyBuilder) Build(cfgName string, targetURL *url.URL, passHostHeader bool) (http.Handler, error) { + proxyURL, err := r.proxy(&http.Request{URL: targetURL}) + if err != nil { + return nil, fmt.Errorf("getting proxy: %w", err) + } + + cfg, err := r.transportManager.Get(cfgName) + if err != nil { + return nil, fmt.Errorf("getting ServersTransport: %w", err) + } + + var responseHeaderTimeout time.Duration + if cfg.ForwardingTimeouts != nil { + responseHeaderTimeout = time.Duration(cfg.ForwardingTimeouts.ResponseHeaderTimeout) + } + + tlsConfig, err := r.transportManager.GetTLSConfig(cfgName) + if err != nil { + return nil, fmt.Errorf("getting TLS config: %w", err) + } + + pool := r.getPool(cfgName, cfg, tlsConfig, targetURL, proxyURL) + return NewReverseProxy(targetURL, proxyURL, r.debug, passHostHeader, responseHeaderTimeout, pool) +} + +func (r *ProxyBuilder) getPool(cfgName string, config *dynamic.ServersTransport, tlsConfig *tls.Config, targetURL *url.URL, proxyURL *url.URL) *connPool { + pool, ok := r.pools[cfgName] + if !ok { + pool = make(map[string]*connPool) + r.pools[cfgName] = pool + } + + if connPool, ok := pool[targetURL.String()]; ok { + return connPool + } + + idleConnTimeout := 90 * time.Second + dialTimeout := 30 * time.Second + if config.ForwardingTimeouts != nil { + idleConnTimeout = time.Duration(config.ForwardingTimeouts.IdleConnTimeout) + dialTimeout = time.Duration(config.ForwardingTimeouts.DialTimeout) + } + + proxyDialer := newDialer(dialerConfig{ + DialKeepAlive: 0, + DialTimeout: dialTimeout, + HTTP: true, + TLS: targetURL.Scheme == "https", + ProxyURL: proxyURL, + }, tlsConfig) + + connPool := newConnPool(config.MaxIdleConnsPerHost, idleConnTimeout, func() (net.Conn, error) { + return proxyDialer.Dial("tcp", addrFromURL(targetURL)) + }) + + r.pools[cfgName][targetURL.String()] = connPool + + return connPool +} diff --git a/pkg/proxy/fast/connpool.go b/pkg/proxy/fast/connpool.go new file mode 100644 index 000000000..e0d2c4e7f --- /dev/null +++ b/pkg/proxy/fast/connpool.go @@ -0,0 +1,163 @@ +package fast + +import ( + "fmt" + "net" + "time" + + "github.com/rs/zerolog/log" +) + +// conn is an enriched net.Conn. +type conn struct { + net.Conn + + idleAt time.Time // the last time it was marked as idle. + idleTimeout time.Duration +} + +func (c *conn) isExpired() bool { + expTime := c.idleAt.Add(c.idleTimeout) + return c.idleTimeout > 0 && time.Now().After(expTime) +} + +// connPool is a net.Conn pool implementation using channels. +type connPool struct { + dialer func() (net.Conn, error) + idleConns chan *conn + idleConnTimeout time.Duration + ticker *time.Ticker + doneCh chan struct{} +} + +// newConnPool creates a new connPool. +func newConnPool(maxIdleConn int, idleConnTimeout time.Duration, dialer func() (net.Conn, error)) *connPool { + c := &connPool{ + dialer: dialer, + idleConns: make(chan *conn, maxIdleConn), + idleConnTimeout: idleConnTimeout, + doneCh: make(chan struct{}), + } + + if idleConnTimeout > 0 { + c.ticker = time.NewTicker(c.idleConnTimeout / 2) + go func() { + for { + select { + case <-c.ticker.C: + c.cleanIdleConns() + case <-c.doneCh: + return + } + } + }() + } + + return c +} + +// Close closes stop the cleanIdleConn goroutine. +func (c *connPool) Close() { + if c.idleConnTimeout > 0 { + close(c.doneCh) + c.ticker.Stop() + } +} + +// AcquireConn returns an idle net.Conn from the pool. +func (c *connPool) AcquireConn() (*conn, error) { + for { + co, err := c.acquireConn() + if err != nil { + return nil, err + } + + if !co.isExpired() { + return co, nil + } + + // As the acquired conn is expired we can close it + // without putting it again into the pool. + if err := co.Close(); err != nil { + log.Debug(). + Err(err). + Msg("Unexpected error while releasing the connection") + } + } +} + +// ReleaseConn releases the given net.Conn to the pool. +func (c *connPool) ReleaseConn(co *conn) { + co.idleAt = time.Now() + c.releaseConn(co) +} + +// cleanIdleConns is a routine cleaning the expired connections at a regular basis. +func (c *connPool) cleanIdleConns() { + for { + select { + case co := <-c.idleConns: + if !co.isExpired() { + c.releaseConn(co) + return + } + + if err := co.Close(); err != nil { + log.Debug(). + Err(err). + Msg("Unexpected error while releasing the connection") + } + + default: + return + } + } +} + +func (c *connPool) acquireConn() (*conn, error) { + select { + case co := <-c.idleConns: + return co, nil + + default: + errCh := make(chan error, 1) + go c.askForNewConn(errCh) + + select { + case co := <-c.idleConns: + return co, nil + + case err := <-errCh: + return nil, err + } + } +} + +func (c *connPool) releaseConn(co *conn) { + select { + case c.idleConns <- co: + + // Hitting the default case means that we have reached the maximum number of idle + // connections, so we can close it. + default: + if err := co.Close(); err != nil { + log.Debug(). + Err(err). + Msg("Unexpected error while releasing the connection") + } + } +} + +func (c *connPool) askForNewConn(errCh chan<- error) { + co, err := c.dialer() + if err != nil { + errCh <- fmt.Errorf("create conn: %w", err) + return + } + + c.releaseConn(&conn{ + Conn: co, + idleAt: time.Now(), + idleTimeout: c.idleConnTimeout, + }) +} diff --git a/pkg/proxy/fast/connpool_test.go b/pkg/proxy/fast/connpool_test.go new file mode 100644 index 000000000..6ab4e9740 --- /dev/null +++ b/pkg/proxy/fast/connpool_test.go @@ -0,0 +1,184 @@ +package fast + +import ( + "net" + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConnPool_ConnReuse(t *testing.T) { + testCases := []struct { + desc string + poolFn func(pool *connPool) + expected int + }{ + { + desc: "One connection", + poolFn: func(pool *connPool) { + c1, _ := pool.AcquireConn() + pool.ReleaseConn(c1) + }, + expected: 1, + }, + { + desc: "Two connections with release", + poolFn: func(pool *connPool) { + c1, _ := pool.AcquireConn() + pool.ReleaseConn(c1) + + c2, _ := pool.AcquireConn() + pool.ReleaseConn(c2) + }, + expected: 1, + }, + { + desc: "Two concurrent connections", + poolFn: func(pool *connPool) { + c1, _ := pool.AcquireConn() + c2, _ := pool.AcquireConn() + + pool.ReleaseConn(c1) + pool.ReleaseConn(c2) + }, + expected: 2, + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + var connAlloc int + dialer := func() (net.Conn, error) { + connAlloc++ + return &net.TCPConn{}, nil + } + + pool := newConnPool(2, 0, dialer) + test.poolFn(pool) + + assert.Equal(t, test.expected, connAlloc) + }) + } +} + +func TestConnPool_MaxIdleConn(t *testing.T) { + testCases := []struct { + desc string + poolFn func(pool *connPool) + maxIdleConn int + expected int + }{ + { + desc: "One connection", + poolFn: func(pool *connPool) { + c1, _ := pool.AcquireConn() + pool.ReleaseConn(c1) + }, + maxIdleConn: 1, + expected: 1, + }, + { + desc: "Multiple connections with defered release", + poolFn: func(pool *connPool) { + for range 7 { + c, _ := pool.AcquireConn() + defer pool.ReleaseConn(c) + } + }, + maxIdleConn: 5, + expected: 5, + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + var keepOpenedConn int + dialer := func() (net.Conn, error) { + keepOpenedConn++ + return &mockConn{closeFn: func() error { + keepOpenedConn-- + return nil + }}, nil + } + + pool := newConnPool(test.maxIdleConn, 0, dialer) + test.poolFn(pool) + + assert.Equal(t, test.expected, keepOpenedConn) + }) + } +} + +func TestGC(t *testing.T) { + var isDestroyed bool + pools := map[string]*connPool{} + dialer := func() (net.Conn, error) { + c := &mockConn{closeFn: func() error { + return nil + }} + return c, nil + } + + pools["test"] = newConnPool(10, 1*time.Second, dialer) + runtime.SetFinalizer(pools["test"], func(p *connPool) { + isDestroyed = true + }) + c, err := pools["test"].AcquireConn() + require.NoError(t, err) + + pools["test"].ReleaseConn(c) + + pools["test"].Close() + + delete(pools, "test") + + runtime.GC() + + require.True(t, isDestroyed) +} + +type mockConn struct { + closeFn func() error +} + +func (m *mockConn) Read(_ []byte) (n int, err error) { + panic("implement me") +} + +func (m *mockConn) Write(_ []byte) (n int, err error) { + panic("implement me") +} + +func (m *mockConn) Close() error { + if m.closeFn != nil { + return m.closeFn() + } + return nil +} + +func (m *mockConn) LocalAddr() net.Addr { + panic("implement me") +} + +func (m *mockConn) RemoteAddr() net.Addr { + panic("implement me") +} + +func (m *mockConn) SetDeadline(_ time.Time) error { + panic("implement me") +} + +func (m *mockConn) SetReadDeadline(_ time.Time) error { + panic("implement me") +} + +func (m *mockConn) SetWriteDeadline(_ time.Time) error { + panic("implement me") +} diff --git a/pkg/proxy/fast/dialer.go b/pkg/proxy/fast/dialer.go new file mode 100644 index 000000000..266001533 --- /dev/null +++ b/pkg/proxy/fast/dialer.go @@ -0,0 +1,195 @@ +package fast + +import ( + "bufio" + "context" + "crypto/tls" + "encoding/base64" + "errors" + "net" + "net/http" + "net/url" + "strings" + "time" + + "golang.org/x/net/proxy" +) + +const ( + schemeHTTP = "http" + schemeHTTPS = "https" + schemeSocks5 = "socks5" +) + +type dialer interface { + Dial(network, addr string) (c net.Conn, err error) +} + +type dialerFunc func(network, addr string) (net.Conn, error) + +func (d dialerFunc) Dial(network, addr string) (net.Conn, error) { + return d(network, addr) +} + +type dialerConfig struct { + DialKeepAlive time.Duration + DialTimeout time.Duration + ProxyURL *url.URL + HTTP bool + TLS bool +} + +func newDialer(cfg dialerConfig, tlsConfig *tls.Config) dialer { + if cfg.ProxyURL == nil { + return buildDialer(cfg, tlsConfig, cfg.TLS) + } + + proxyDialer := buildDialer(cfg, tlsConfig, cfg.ProxyURL.Scheme == "https") + proxyAddr := addrFromURL(cfg.ProxyURL) + + switch { + case cfg.ProxyURL.Scheme == schemeSocks5: + var auth *proxy.Auth + if u := cfg.ProxyURL.User; u != nil { + auth = &proxy.Auth{User: u.Username()} + auth.Password, _ = u.Password() + } + + // SOCKS5 implementation do not return errors. + socksDialer, _ := proxy.SOCKS5("tcp", proxyAddr, auth, proxyDialer) + return dialerFunc(func(network, targetAddr string) (net.Conn, error) { + co, err := socksDialer.Dial("tcp", targetAddr) + if err != nil { + return nil, err + } + + if cfg.TLS { + c := &tls.Config{} + if tlsConfig != nil { + c = tlsConfig.Clone() + } + + if c.ServerName == "" { + host, _, _ := net.SplitHostPort(targetAddr) + c.ServerName = host + } + + return tls.Client(co, c), nil + } + + return co, nil + }) + case cfg.HTTP && !cfg.TLS: + // Nothing to do the Proxy-Authorization header will be added by the ReverseProxy. + + default: + hdr := make(http.Header) + if u := cfg.ProxyURL.User; u != nil { + username := u.Username() + password, _ := u.Password() + auth := username + ":" + password + hdr.Set("Proxy-Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth))) + } + + return dialerFunc(func(network, targetAddr string) (net.Conn, error) { + conn, err := proxyDialer.Dial("tcp", proxyAddr) + if err != nil { + return nil, err + } + + connectReq := &http.Request{ + Method: http.MethodConnect, + URL: &url.URL{Opaque: targetAddr}, + Host: targetAddr, + Header: hdr, + } + + connectCtx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) + defer cancel() + + didReadResponse := make(chan struct{}) // closed after CONNECT write+read is done or fails + var resp *http.Response + + // Write the CONNECT request & read the response. + go func() { + defer close(didReadResponse) + err = connectReq.Write(conn) + if err != nil { + return + } + // Okay to use and discard buffered reader here, because + // TLS server will not speak until spoken to. + br := bufio.NewReader(conn) + resp, err = http.ReadResponse(br, connectReq) + }() + + select { + case <-connectCtx.Done(): + conn.Close() + <-didReadResponse + return nil, connectCtx.Err() + case <-didReadResponse: + // resp or err now set + } + if err != nil { + conn.Close() + return nil, err + } + + if resp.StatusCode != http.StatusOK { + _, statusText, ok := strings.Cut(resp.Status, " ") + conn.Close() + if !ok { + return nil, errors.New("unknown status code") + } + + return nil, errors.New(statusText) + } + + c := &tls.Config{} + if tlsConfig != nil { + c = tlsConfig.Clone() + } + if c.ServerName == "" { + host, _, _ := net.SplitHostPort(targetAddr) + c.ServerName = host + } + + return tls.Client(conn, c), nil + }) + } + return dialerFunc(func(network, addr string) (net.Conn, error) { + return proxyDialer.Dial("tcp", proxyAddr) + }) +} + +func buildDialer(cfg dialerConfig, tlsConfig *tls.Config, isTLS bool) dialer { + dialer := &net.Dialer{ + Timeout: cfg.DialTimeout, + KeepAlive: cfg.DialKeepAlive, + } + + if !isTLS { + return dialer + } + + return &tls.Dialer{ + NetDialer: dialer, + Config: tlsConfig, + } +} + +func addrFromURL(u *url.URL) string { + addr := u.Host + + if u.Port() == "" { + switch u.Scheme { + case schemeHTTP: + return addr + ":80" + case schemeHTTPS: + return addr + ":443" + } + } + + return addr +} diff --git a/pkg/proxy/fast/proxy.go b/pkg/proxy/fast/proxy.go new file mode 100644 index 000000000..e61a32ad8 --- /dev/null +++ b/pkg/proxy/fast/proxy.go @@ -0,0 +1,553 @@ +package fast + +import ( + "bufio" + "bytes" + "encoding/base64" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/httptrace" + "net/http/httputil" + "net/url" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/rs/zerolog/log" + proxyhttputil "github.com/traefik/traefik/v3/pkg/proxy/httputil" + "github.com/valyala/fasthttp" + "golang.org/x/net/http/httpguts" +) + +const ( + bufferSize = 32 * 1024 + bufioSize = 64 * 1024 +) + +var hopHeaders = []string{ + "Connection", + "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", // canonicalized version of "TE" + "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522 + "Transfer-Encoding", + "Upgrade", +} + +type pool[T any] struct { + pool sync.Pool +} + +func (p *pool[T]) Get() T { + if tmp := p.pool.Get(); tmp != nil { + return tmp.(T) + } + + var res T + return res +} + +func (p *pool[T]) Put(x T) { + p.pool.Put(x) +} + +type buffConn struct { + *bufio.Reader + net.Conn +} + +func (b buffConn) Read(p []byte) (int, error) { + return b.Reader.Read(p) +} + +type writeDetector struct { + net.Conn + + written bool +} + +func (w *writeDetector) Write(p []byte) (int, error) { + n, err := w.Conn.Write(p) + if n > 0 { + w.written = true + } + + return n, err +} + +type writeFlusher struct { + io.Writer +} + +func (w *writeFlusher) Write(b []byte) (int, error) { + n, err := w.Writer.Write(b) + if f, ok := w.Writer.(http.Flusher); ok { + f.Flush() + } + + return n, err +} + +type timeoutError struct { + error +} + +func (t timeoutError) Timeout() bool { + return true +} + +func (t timeoutError) Temporary() bool { + return false +} + +// ReverseProxy is the FastProxy reverse proxy implementation. +type ReverseProxy struct { + debug bool + + connPool *connPool + + bufferPool pool[[]byte] + readerPool pool[*bufio.Reader] + writerPool pool[*bufio.Writer] + limitReaderPool pool[*io.LimitedReader] + + proxyAuth string + + targetURL *url.URL + passHostHeader bool + responseHeaderTimeout time.Duration +} + +// NewReverseProxy creates a new ReverseProxy. +func NewReverseProxy(targetURL *url.URL, proxyURL *url.URL, debug, passHostHeader bool, responseHeaderTimeout time.Duration, connPool *connPool) (*ReverseProxy, error) { + var proxyAuth string + if proxyURL != nil && proxyURL.User != nil && targetURL.Scheme == "http" { + username := proxyURL.User.Username() + password, _ := proxyURL.User.Password() + proxyAuth = "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password)) + } + + return &ReverseProxy{ + debug: debug, + passHostHeader: passHostHeader, + targetURL: targetURL, + proxyAuth: proxyAuth, + connPool: connPool, + responseHeaderTimeout: responseHeaderTimeout, + }, nil +} + +func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + if req.Body != nil { + defer req.Body.Close() + } + + outReq := fasthttp.AcquireRequest() + defer fasthttp.ReleaseRequest(outReq) + + // This is not required as the headers are already normalized by net/http. + outReq.Header.DisableNormalizing() + + for k, v := range req.Header { + for _, s := range v { + outReq.Header.Add(k, s) + } + } + + removeConnectionHeaders(&outReq.Header) + + for _, header := range hopHeaders { + outReq.Header.Del(header) + } + + if p.proxyAuth != "" { + outReq.Header.Set("Proxy-Authorization", p.proxyAuth) + } + + if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") { + outReq.Header.Set("Te", "trailers") + } + + if p.debug { + outReq.Header.Set("X-Traefik-Fast-Proxy", "enabled") + } + + reqUpType := upgradeType(req.Header) + if !isGraphic(reqUpType) { + proxyhttputil.ErrorHandler(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType)) + return + } + + if reqUpType != "" { + outReq.Header.Set("Connection", "Upgrade") + outReq.Header.Set("Upgrade", reqUpType) + if reqUpType == "websocket" { + cleanWebSocketHeaders(&outReq.Header) + } + } + + u2 := new(url.URL) + *u2 = *req.URL + u2.Scheme = p.targetURL.Scheme + u2.Host = p.targetURL.Host + + u := req.URL + if req.RequestURI != "" { + parsedURL, err := url.ParseRequestURI(req.RequestURI) + if err == nil { + u = parsedURL + } + } + + u2.Path = u.Path + u2.RawPath = u.RawPath + u2.RawQuery = strings.ReplaceAll(u.RawQuery, ";", "&") + + outReq.SetHost(u2.Host) + outReq.Header.SetHost(u2.Host) + + if p.passHostHeader { + outReq.Header.SetHost(req.Host) + } + + outReq.SetRequestURI(u2.RequestURI()) + + outReq.SetBodyStream(req.Body, int(req.ContentLength)) + + outReq.Header.SetMethod(req.Method) + + if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + // If we aren't the first proxy retain prior + // X-Forwarded-For information as a comma+space + // separated list and fold multiple headers into one. + prior, ok := req.Header["X-Forwarded-For"] + if len(prior) > 0 { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + + omit := ok && prior == nil // Go Issue 38079: nil now means don't populate the header + if !omit { + outReq.Header.Set("X-Forwarded-For", clientIP) + } + } + + if err := p.roundTrip(rw, req, outReq, reqUpType); err != nil { + proxyhttputil.ErrorHandler(rw, req, err) + } +} + +// Note that unlike the net/http RoundTrip: +// - we are not supporting "100 Continue" response to forward them as-is to the client. +// - we are not asking for compressed response automatically. That is because this will add an extra cost when the +// client is asking for an uncompressed response, as we will have to un-compress it, and nowadays most clients are +// already asking for compressed response (allowing "passthrough" compression). +func (p *ReverseProxy) roundTrip(rw http.ResponseWriter, req *http.Request, outReq *fasthttp.Request, reqUpType string) error { + ctx := req.Context() + trace := httptrace.ContextClientTrace(ctx) + + var co *conn + for { + select { + case <-ctx.Done(): + return ctx.Err() + + default: + } + + var err error + co, err = p.connPool.AcquireConn() + if err != nil { + return fmt.Errorf("acquire connection: %w", err) + } + + wd := &writeDetector{Conn: co} + + err = p.writeRequest(wd, outReq) + if wd.written && trace != nil && trace.WroteRequest != nil { + // WroteRequest hook is used by the tracing middleware to detect if the request has been written. + trace.WroteRequest(httptrace.WroteRequestInfo{}) + } + if err == nil { + break + } + + log.Ctx(ctx).Debug().Err(err).Msg("Error while writing request") + + co.Close() + + if wd.written && !isReplayable(req) { + return err + } + } + + br := p.readerPool.Get() + if br == nil { + br = bufio.NewReaderSize(co, bufioSize) + } + defer p.readerPool.Put(br) + + br.Reset(co) + + res := fasthttp.AcquireResponse() + defer fasthttp.ReleaseResponse(res) + + res.Header.SetNoDefaultContentType(true) + + for { + var timer *time.Timer + errTimeout := atomic.Pointer[timeoutError]{} + if p.responseHeaderTimeout > 0 { + timer = time.AfterFunc(p.responseHeaderTimeout, func() { + errTimeout.Store(&timeoutError{errors.New("timeout awaiting response headers")}) + co.Close() + }) + } + + res.Header.SetNoDefaultContentType(true) + if err := res.Header.Read(br); err != nil { + if p.responseHeaderTimeout > 0 { + if errT := errTimeout.Load(); errT != nil { + return errT + } + } + co.Close() + return err + } + + if timer != nil { + timer.Stop() + } + + fixPragmaCacheControl(&res.Header) + + resCode := res.StatusCode() + is1xx := 100 <= resCode && resCode <= 199 + // treat 101 as a terminal status, see issue 26161 + is1xxNonTerminal := is1xx && resCode != http.StatusSwitchingProtocols + if is1xxNonTerminal { + removeConnectionHeaders(&res.Header) + h := rw.Header() + + for _, header := range hopHeaders { + res.Header.Del(header) + } + + res.Header.VisitAll(func(key, value []byte) { + rw.Header().Add(string(key), string(value)) + }) + + rw.WriteHeader(res.StatusCode()) + // Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses + for k := range h { + delete(h, k) + } + + res.Reset() + res.Header.Reset() + res.Header.SetNoDefaultContentType(true) + + continue + } + break + } + + announcedTrailers := res.Header.Peek("Trailer") + + // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) + if res.StatusCode() == http.StatusSwitchingProtocols { + // As the connection has been hijacked, it cannot be added back to the pool. + handleUpgradeResponse(rw, req, reqUpType, res, buffConn{Conn: co, Reader: br}) + return nil + } + + removeConnectionHeaders(&res.Header) + + for _, header := range hopHeaders { + res.Header.Del(header) + } + + if len(announcedTrailers) > 0 { + res.Header.Add("Trailer", string(announcedTrailers)) + } + + res.Header.VisitAll(func(key, value []byte) { + rw.Header().Add(string(key), string(value)) + }) + + rw.WriteHeader(res.StatusCode()) + + // Chunked response, Content-Length is set to -1 by FastProxy when "Transfer-Encoding: chunked" header is received. + if res.Header.ContentLength() == -1 { + cbr := httputil.NewChunkedReader(br) + + b := p.bufferPool.Get() + if b == nil { + b = make([]byte, bufferSize) + } + defer p.bufferPool.Put(b) + + if _, err := io.CopyBuffer(&writeFlusher{rw}, cbr, b); err != nil { + co.Close() + return err + } + + res.Header.Reset() + res.Header.SetNoDefaultContentType(true) + if err := res.Header.ReadTrailer(br); err != nil { + co.Close() + return err + } + + if res.Header.Len() > 0 { + var announcedTrailersKey []string + if len(announcedTrailers) > 0 { + announcedTrailersKey = strings.Split(string(announcedTrailers), ",") + } + + res.Header.VisitAll(func(key, value []byte) { + for _, s := range announcedTrailersKey { + if strings.EqualFold(s, strings.TrimSpace(string(key))) { + rw.Header().Add(string(key), string(value)) + return + } + } + + rw.Header().Add(http.TrailerPrefix+string(key), string(value)) + }) + } + + p.connPool.ReleaseConn(co) + + return nil + } + + brl := p.limitReaderPool.Get() + if brl == nil { + brl = &io.LimitedReader{} + } + defer p.limitReaderPool.Put(brl) + + brl.R = br + brl.N = int64(res.Header.ContentLength()) + + b := p.bufferPool.Get() + if b == nil { + b = make([]byte, bufferSize) + } + defer p.bufferPool.Put(b) + + if _, err := io.CopyBuffer(rw, brl, b); err != nil { + co.Close() + return err + } + + p.connPool.ReleaseConn(co) + + return nil +} + +func (p *ReverseProxy) writeRequest(co net.Conn, outReq *fasthttp.Request) error { + bw := p.writerPool.Get() + if bw == nil { + bw = bufio.NewWriterSize(co, bufioSize) + } + defer p.writerPool.Put(bw) + + bw.Reset(co) + + if err := outReq.Write(bw); err != nil { + return err + } + + return bw.Flush() +} + +// isReplayable returns whether the request is replayable. +func isReplayable(req *http.Request) bool { + if req.Body == nil || req.Body == http.NoBody { + switch req.Method { + case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace: + return true + } + + // The Idempotency-Key, while non-standard, is widely used to + // mean a POST or other request is idempotent. See + // https://golang.org/issue/19943#issuecomment-421092421 + if _, ok := req.Header["Idempotency-Key"]; ok { + return true + } + + if _, ok := req.Header["X-Idempotency-Key"]; ok { + return true + } + } + + return false +} + +// isGraphic returns whether s is ASCII and printable according to +// https://tools.ietf.org/html/rfc20#section-4.2. +func isGraphic(s string) bool { + for i := range len(s) { + if s[i] < ' ' || s[i] > '~' { + return false + } + } + + return true +} + +type fasthttpHeader interface { + Peek(key string) []byte + Set(key string, value string) + SetBytesV(key string, value []byte) + DelBytes(key []byte) + Del(key string) +} + +// removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h. +// See RFC 7230, section 6.1. +func removeConnectionHeaders(h fasthttpHeader) { + f := h.Peek(fasthttp.HeaderConnection) + for _, sf := range bytes.Split(f, []byte{','}) { + if sf = bytes.TrimSpace(sf); len(sf) > 0 { + h.DelBytes(sf) + } + } +} + +// RFC 7234, section 5.4: Should treat Pragma: no-cache like Cache-Control: no-cache. +func fixPragmaCacheControl(header fasthttpHeader) { + if pragma := header.Peek("Pragma"); bytes.Equal(pragma, []byte("no-cache")) { + if len(header.Peek("Cache-Control")) == 0 { + header.Set("Cache-Control", "no-cache") + } + } +} + +// cleanWebSocketHeaders Even if the websocket RFC says that headers should be case-insensitive, +// some servers need Sec-WebSocket-Key, Sec-WebSocket-Extensions, Sec-WebSocket-Accept, +// Sec-WebSocket-Protocol and Sec-WebSocket-Version to be case-sensitive. +// https://tools.ietf.org/html/rfc6455#page-20 +func cleanWebSocketHeaders(headers fasthttpHeader) { + headers.SetBytesV("Sec-WebSocket-Key", headers.Peek("Sec-Websocket-Key")) + headers.Del("Sec-Websocket-Key") + + headers.SetBytesV("Sec-WebSocket-Extensions", headers.Peek("Sec-Websocket-Extensions")) + headers.Del("Sec-Websocket-Extensions") + + headers.SetBytesV("Sec-WebSocket-Accept", headers.Peek("Sec-Websocket-Accept")) + headers.Del("Sec-Websocket-Accept") + + headers.SetBytesV("Sec-WebSocket-Protocol", headers.Peek("Sec-Websocket-Protocol")) + headers.Del("Sec-Websocket-Protocol") + + headers.SetBytesV("Sec-WebSocket-Version", headers.Peek("Sec-Websocket-Version")) + headers.DelBytes([]byte("Sec-Websocket-Version")) +} diff --git a/pkg/proxy/fast/proxy_test.go b/pkg/proxy/fast/proxy_test.go new file mode 100644 index 000000000..ee75ae0db --- /dev/null +++ b/pkg/proxy/fast/proxy_test.go @@ -0,0 +1,311 @@ +package fast + +import ( + "crypto/tls" + "crypto/x509" + "encoding/base64" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "net/http/httputil" + "net/url" + "testing" + "time" + + "github.com/armon/go-socks5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/traefik/traefik/v3/pkg/config/dynamic" + "github.com/traefik/traefik/v3/pkg/config/static" + "github.com/traefik/traefik/v3/pkg/testhelpers" + "github.com/traefik/traefik/v3/pkg/tls/generate" +) + +const ( + proxyHTTP = "http" + proxyHTTPS = "https" + proxySocks5 = "socks" +) + +type authCreds struct { + user string + password string +} + +func TestProxyFromEnvironment(t *testing.T) { + testCases := []struct { + desc string + proxyType string + tls bool + auth *authCreds + }{ + { + desc: "Proxy HTTP with HTTP Backend", + proxyType: proxyHTTP, + }, + { + desc: "Proxy HTTP with HTTP backend and proxy auth", + proxyType: proxyHTTP, + tls: false, + auth: &authCreds{ + user: "user", + password: "password", + }, + }, + { + desc: "Proxy HTTP with HTTPS backend", + proxyType: proxyHTTP, + tls: true, + }, + { + desc: "Proxy HTTP with HTTPS backend and proxy auth", + proxyType: proxyHTTP, + tls: true, + auth: &authCreds{ + user: "user", + password: "password", + }, + }, + { + desc: "Proxy HTTPS with HTTP backend", + proxyType: proxyHTTPS, + }, + { + desc: "Proxy HTTPS with HTTP backend and proxy auth", + proxyType: proxyHTTPS, + tls: false, + auth: &authCreds{ + user: "user", + password: "password", + }, + }, + { + desc: "Proxy HTTPS with HTTPS backend", + proxyType: proxyHTTPS, + tls: true, + }, + { + desc: "Proxy HTTPS with HTTPS backend and proxy auth", + proxyType: proxyHTTPS, + tls: true, + auth: &authCreds{ + user: "user", + password: "password", + }, + }, + { + desc: "Proxy Socks5 with HTTP backend", + proxyType: proxySocks5, + }, + { + desc: "Proxy Socks5 with HTTP backend and proxy auth", + proxyType: proxySocks5, + auth: &authCreds{ + user: "user", + password: "password", + }, + }, + { + desc: "Proxy Socks5 with HTTPS backend", + proxyType: proxySocks5, + tls: true, + }, + { + desc: "Proxy Socks5 with HTTPS backend and proxy auth", + proxyType: proxySocks5, + tls: true, + auth: &authCreds{ + user: "user", + password: "password", + }, + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + backendURL, backendCert := newBackendServer(t, test.tls, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + _, _ = rw.Write([]byte("backend")) + })) + + var proxyCalled bool + proxyHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + proxyCalled = true + + if test.auth != nil { + proxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(test.auth.user+":"+test.auth.password)) + require.Equal(t, proxyAuth, req.Header.Get("Proxy-Authorization")) + } + + if req.Method != http.MethodConnect { + proxy := httputil.NewSingleHostReverseProxy(testhelpers.MustParseURL("http://" + req.Host)) + proxy.ServeHTTP(rw, req) + return + } + + // CONNECT method + conn, err := net.Dial("tcp", req.Host) + require.NoError(t, err) + + hj, ok := rw.(http.Hijacker) + require.True(t, ok) + + rw.WriteHeader(http.StatusOK) + connHj, _, err := hj.Hijack() + require.NoError(t, err) + + go func() { _, _ = io.Copy(connHj, conn) }() + _, _ = io.Copy(conn, connHj) + }) + + var proxyURL string + var proxyCert *x509.Certificate + + switch test.proxyType { + case proxySocks5: + ln, err := net.Listen("tcp", ":0") + require.NoError(t, err) + + proxyURL = fmt.Sprintf("socks5://%s", ln.Addr()) + + go func() { + conn, err := ln.Accept() + require.NoError(t, err) + + proxyCalled = true + + conf := &socks5.Config{} + if test.auth != nil { + conf.Credentials = socks5.StaticCredentials{test.auth.user: test.auth.password} + } + + server, err := socks5.New(conf) + require.NoError(t, err) + + // We are not checking the error, because ServeConn is blocked until the client or the backend + // connection is closed which, in some cases, raises a connection reset by peer error. + _ = server.ServeConn(conn) + + err = ln.Close() + require.NoError(t, err) + }() + + case proxyHTTP: + proxyServer := httptest.NewServer(proxyHandler) + t.Cleanup(proxyServer.Close) + + proxyURL = proxyServer.URL + + case proxyHTTPS: + proxyServer := httptest.NewServer(proxyHandler) + t.Cleanup(proxyServer.Close) + + proxyURL = proxyServer.URL + proxyCert = proxyServer.Certificate() + } + + certPool := x509.NewCertPool() + if proxyCert != nil { + certPool.AddCert(proxyCert) + } + if backendCert != nil { + cert, err := x509.ParseCertificate(backendCert.Certificate[0]) + require.NoError(t, err) + + certPool.AddCert(cert) + } + + builder := NewProxyBuilder(&transportManagerMock{tlsConfig: &tls.Config{RootCAs: certPool}}, static.FastProxyConfig{}) + builder.proxy = func(req *http.Request) (*url.URL, error) { + u, err := url.Parse(proxyURL) + if err != nil { + return nil, err + } + + if test.auth != nil { + u.User = url.UserPassword(test.auth.user, test.auth.password) + } + + return u, nil + } + + reverseProxy, err := builder.Build("foo", testhelpers.MustParseURL(backendURL), false) + require.NoError(t, err) + + reverseProxyServer := httptest.NewServer(reverseProxy) + t.Cleanup(reverseProxyServer.Close) + + client := http.Client{Timeout: 5 * time.Second} + + resp, err := client.Get(reverseProxyServer.URL) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, "backend", string(body)) + assert.True(t, proxyCalled) + }) + } +} + +func newCertificate(t *testing.T, domain string) *tls.Certificate { + t.Helper() + + certPEM, keyPEM, err := generate.KeyPair(domain, time.Time{}) + require.NoError(t, err) + + certificate, err := tls.X509KeyPair(certPEM, keyPEM) + require.NoError(t, err) + + return &certificate +} + +func newBackendServer(t *testing.T, isTLS bool, handler http.Handler) (string, *tls.Certificate) { + t.Helper() + + var ln net.Listener + var err error + var cert *tls.Certificate + + scheme := "http" + domain := "backend.localhost" + if isTLS { + scheme = "https" + + cert = newCertificate(t, domain) + + ln, err = tls.Listen("tcp", ":0", &tls.Config{Certificates: []tls.Certificate{*cert}}) + require.NoError(t, err) + } else { + ln, err = net.Listen("tcp", ":0") + require.NoError(t, err) + } + + srv := &http.Server{Handler: handler} + go func() { _ = srv.Serve(ln) }() + + t.Cleanup(func() { _ = srv.Close() }) + + _, port, err := net.SplitHostPort(ln.Addr().String()) + require.NoError(t, err) + + backendURL := fmt.Sprintf("%s://%s:%s", scheme, domain, port) + + return backendURL, cert +} + +type transportManagerMock struct { + tlsConfig *tls.Config +} + +func (r *transportManagerMock) GetTLSConfig(_ string) (*tls.Config, error) { + return r.tlsConfig, nil +} + +func (r *transportManagerMock) Get(_ string) (*dynamic.ServersTransport, error) { + return &dynamic.ServersTransport{}, nil +} diff --git a/pkg/proxy/fast/proxy_websocket_test.go b/pkg/proxy/fast/proxy_websocket_test.go new file mode 100644 index 000000000..ff2e273e9 --- /dev/null +++ b/pkg/proxy/fast/proxy_websocket_test.go @@ -0,0 +1,693 @@ +package fast + +import ( + "bufio" + "crypto/tls" + "errors" + "fmt" + "net" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + gorillawebsocket "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/traefik/traefik/v3/pkg/testhelpers" + "golang.org/x/net/websocket" +) + +func TestWebSocketTCPClose(t *testing.T) { + errChan := make(chan error, 1) + upgrader := gorillawebsocket.Upgrader{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer c.Close() + + for { + _, _, err := c.ReadMessage() + if err != nil { + errChan <- err + break + } + } + })) + defer srv.Close() + + proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil)) + + proxyAddr := proxy.Listener.Addr().String() + _, conn, err := newWebsocketRequest( + withServer(proxyAddr), + withPath("/ws"), + ).open() + require.NoError(t, err) + + conn.Close() + + serverErr := <-errChan + + var wsErr *gorillawebsocket.CloseError + require.ErrorAs(t, serverErr, &wsErr) + assert.Equal(t, 1006, wsErr.Code) +} + +func TestWebSocketPingPong(t *testing.T) { + upgrader := gorillawebsocket.Upgrader{ + HandshakeTimeout: 10 * time.Second, + CheckOrigin: func(*http.Request) bool { + return true + }, + } + + mux := http.NewServeMux() + mux.HandleFunc("/ws", func(writer http.ResponseWriter, request *http.Request) { + ws, err := upgrader.Upgrade(writer, request, nil) + require.NoError(t, err) + + ws.SetPingHandler(func(appData string) error { + err = ws.WriteMessage(gorillawebsocket.PongMessage, []byte(appData+"Pong")) + require.NoError(t, err) + + return nil + }) + + _, _, _ = ws.ReadMessage() + }) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + mux.ServeHTTP(w, req) + })) + defer srv.Close() + + proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil)) + serverAddr := proxy.Listener.Addr().String() + + headers := http.Header{} + webSocketURL := "ws://" + serverAddr + "/ws" + headers.Add("Origin", webSocketURL) + + conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers) + require.NoError(t, err, "Error during Dial with response: %+v", resp) + defer conn.Close() + + goodErr := fmt.Errorf("signal: %s", "Good data") + badErr := fmt.Errorf("signal: %s", "Bad data") + conn.SetPongHandler(func(data string) error { + if data == "PingPong" { + return goodErr + } + + return badErr + }) + + err = conn.WriteControl(gorillawebsocket.PingMessage, []byte("Ping"), time.Now().Add(time.Second)) + require.NoError(t, err) + + _, _, err = conn.ReadMessage() + + if !errors.Is(err, goodErr) { + require.NoError(t, err) + } +} + +func TestWebSocketEcho(t *testing.T) { + mux := http.NewServeMux() + mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { + msg := make([]byte, 4) + + n, err := conn.Read(msg) + require.NoError(t, err) + + _, err = conn.Write(msg[:n]) + require.NoError(t, err) + + err = conn.Close() + require.NoError(t, err) + })) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + mux.ServeHTTP(w, req) + })) + defer srv.Close() + + proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil)) + serverAddr := proxy.Listener.Addr().String() + + headers := http.Header{} + webSocketURL := "ws://" + serverAddr + "/ws" + headers.Add("Origin", webSocketURL) + + conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers) + require.NoError(t, err, "Error during Dial with response: %+v", resp) + + err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK")) + require.NoError(t, err) + + _, msg, err := conn.ReadMessage() + require.NoError(t, err) + + assert.Equal(t, "OK", string(msg)) + + err = conn.Close() + require.NoError(t, err) +} + +func TestWebSocketPassHost(t *testing.T) { + testCases := []struct { + desc string + passHost bool + expected string + }{ + { + desc: "PassHost false", + passHost: false, + }, + { + desc: "PassHost true", + passHost: true, + expected: "example.com", + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + mux := http.NewServeMux() + mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { + req := conn.Request() + + if test.passHost { + require.Equal(t, test.expected, req.Host) + } else { + require.NotEqual(t, test.expected, req.Host) + } + + msg := make([]byte, 4) + + n, err := conn.Read(msg) + require.NoError(t, err) + + _, err = conn.Write(msg[:n]) + require.NoError(t, err) + + err = conn.Close() + require.NoError(t, err) + })) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + mux.ServeHTTP(w, req) + })) + defer srv.Close() + + proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil)) + + serverAddr := proxy.Listener.Addr().String() + + headers := http.Header{} + webSocketURL := "ws://" + serverAddr + "/ws" + headers.Add("Origin", webSocketURL) + headers.Add("Host", "example.com") + + conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers) + require.NoError(t, err, "Error during Dial with response: %+v", resp) + + err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK")) + require.NoError(t, err) + + _, msg, err := conn.ReadMessage() + require.NoError(t, err) + + assert.Equal(t, "OK", string(msg)) + + err = conn.Close() + require.NoError(t, err) + }) + } +} + +func TestWebSocketServerWithoutCheckOrigin(t *testing.T) { + upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(r *http.Request) bool { + return true + }} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer c.Close() + for { + mt, message, err := c.ReadMessage() + if err != nil { + break + } + err = c.WriteMessage(mt, message) + if err != nil { + break + } + } + })) + defer srv.Close() + + proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil)) + defer proxy.Close() + + proxyAddr := proxy.Listener.Addr().String() + resp, err := newWebsocketRequest( + withServer(proxyAddr), + withPath("/ws"), + withData("ok"), + withOrigin("http://127.0.0.2"), + ).send() + + require.NoError(t, err) + assert.Equal(t, "ok", resp) +} + +func TestWebSocketRequestWithOrigin(t *testing.T) { + upgrader := gorillawebsocket.Upgrader{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer c.Close() + for { + mt, message, err := c.ReadMessage() + if err != nil { + break + } + err = c.WriteMessage(mt, message) + if err != nil { + break + } + } + })) + defer srv.Close() + + proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil)) + defer proxy.Close() + + proxyAddr := proxy.Listener.Addr().String() + _, err := newWebsocketRequest( + withServer(proxyAddr), + withPath("/ws"), + withData("echo"), + withOrigin("http://127.0.0.2"), + ).send() + require.EqualError(t, err, "bad status") + + resp, err := newWebsocketRequest( + withServer(proxyAddr), + withPath("/ws"), + withData("ok"), + ).send() + + require.NoError(t, err) + assert.Equal(t, "ok", resp) +} + +func TestWebSocketRequestWithQueryParams(t *testing.T) { + upgrader := gorillawebsocket.Upgrader{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + + assert.Equal(t, "test", r.URL.Query().Get("query")) + for { + mt, message, err := conn.ReadMessage() + if err != nil { + break + } + + err = conn.WriteMessage(mt, message) + if err != nil { + break + } + } + })) + defer srv.Close() + + proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil)) + defer proxy.Close() + + proxyAddr := proxy.Listener.Addr().String() + + resp, err := newWebsocketRequest( + withServer(proxyAddr), + withPath("/ws?query=test"), + withData("ok"), + ).send() + + require.NoError(t, err) + assert.Equal(t, "ok", resp) +} + +func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) { + mux := http.NewServeMux() + mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { + _ = conn.Close() + })) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + mux.ServeHTTP(w, req) + })) + defer srv.Close() + + u := parseURI(t, srv.URL) + + f, err := NewReverseProxy(u, nil, true, false, 0, newConnPool(1, 0, func() (net.Conn, error) { + return net.Dial("tcp", u.Host) + })) + require.NoError(t, err) + + proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + req.URL = parseURI(t, srv.URL) + w.Header().Set("HEADER-KEY", "HEADER-VALUE") + f.ServeHTTP(w, req) + })) + defer proxy.Close() + + serverAddr := proxy.Listener.Addr().String() + + headers := http.Header{} + webSocketURL := "ws://" + serverAddr + "/ws" + headers.Add("Origin", webSocketURL) + conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers) + require.NoError(t, err, "Error during Dial with response: %+v", err, resp) + defer conn.Close() + + assert.Equal(t, "HEADER-VALUE", resp.Header.Get("HEADER-KEY")) +} + +func TestWebSocketRequestWithEncodedChar(t *testing.T) { + upgrader := gorillawebsocket.Upgrader{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + assert.Equal(t, "/%3A%2F%2F", r.URL.EscapedPath()) + for { + mt, message, err := conn.ReadMessage() + if err != nil { + break + } + err = conn.WriteMessage(mt, message) + if err != nil { + break + } + } + })) + defer srv.Close() + + proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil)) + defer proxy.Close() + + proxyAddr := proxy.Listener.Addr().String() + + resp, err := newWebsocketRequest( + withServer(proxyAddr), + withPath("/%3A%2F%2F"), + withData("ok"), + ).send() + + require.NoError(t, err) + assert.Equal(t, "ok", resp) +} + +func TestWebSocketUpgradeFailed(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/ws", func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusBadRequest) + }) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + mux.ServeHTTP(w, req) + })) + defer srv.Close() + + u := parseURI(t, srv.URL) + f, err := NewReverseProxy(u, nil, true, false, 0, newConnPool(1, 0, func() (net.Conn, error) { + return net.Dial("tcp", u.Host) + })) + require.NoError(t, err) + + proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + path := req.URL.Path // keep the original path + + if path != "/ws" { + w.WriteHeader(http.StatusOK) + return + } + + // Set new backend URL + req.URL = parseURI(t, srv.URL) + req.URL.Path = path + f.ServeHTTP(w, req) + })) + defer proxy.Close() + + proxyAddr := proxy.Listener.Addr().String() + conn, err := net.DialTimeout("tcp", proxyAddr, dialTimeout) + + require.NoError(t, err) + defer conn.Close() + + req, err := http.NewRequest(http.MethodGet, "ws://127.0.0.1/ws", nil) + require.NoError(t, err) + + req.Header.Add("upgrade", "websocket") + req.Header.Add("Connection", "upgrade") + + err = req.Write(conn) + require.NoError(t, err) + + // First request works with 400 + br := bufio.NewReader(conn) + resp, err := http.ReadResponse(br, req) + require.NoError(t, err) + + assert.Equal(t, 400, resp.StatusCode) +} + +func TestForwardsWebsocketTraffic(t *testing.T) { + mux := http.NewServeMux() + mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { + _, err := conn.Write([]byte("ok")) + require.NoError(t, err) + + err = conn.Close() + require.NoError(t, err) + })) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + mux.ServeHTTP(w, req) + })) + defer srv.Close() + + proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil)) + defer proxy.Close() + + proxyAddr := proxy.Listener.Addr().String() + resp, err := newWebsocketRequest( + withServer(proxyAddr), + withPath("/ws"), + withData("echo"), + ).send() + + require.NoError(t, err) + assert.Equal(t, "ok", resp) +} + +func createTLSWebsocketServer() *httptest.Server { + upgrader := gorillawebsocket.Upgrader{} + srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + for { + mt, message, err := conn.ReadMessage() + if err != nil { + break + } + + err = conn.WriteMessage(mt, message) + if err != nil { + break + } + } + })) + return srv +} + +func TestWebSocketTransferTLSConfig(t *testing.T) { + srv := createTLSWebsocketServer() + defer srv.Close() + + proxyWithoutTLSConfig := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil)) + defer proxyWithoutTLSConfig.Close() + + proxyAddr := proxyWithoutTLSConfig.Listener.Addr().String() + + _, err := newWebsocketRequest( + withServer(proxyAddr), + withPath("/ws"), + withData("ok"), + ).send() + + require.EqualError(t, err, "bad status") + + pool := createConnectionPool(srv.URL, &tls.Config{InsecureSkipVerify: true}) + + proxyWithTLSConfig := createProxyWithForwarder(t, srv.URL, pool) + defer proxyWithTLSConfig.Close() + + proxyAddr = proxyWithTLSConfig.Listener.Addr().String() + + resp, err := newWebsocketRequest( + withServer(proxyAddr), + withPath("/ws"), + withData("ok"), + ).send() + + require.NoError(t, err) + assert.Equal(t, "ok", resp) +} + +const dialTimeout = time.Second + +type websocketRequestOpt func(w *websocketRequest) + +func withServer(server string) websocketRequestOpt { + return func(w *websocketRequest) { + w.ServerAddr = server + } +} + +func withPath(path string) websocketRequestOpt { + return func(w *websocketRequest) { + w.Path = path + } +} + +func withData(data string) websocketRequestOpt { + return func(w *websocketRequest) { + w.Data = data + } +} + +func withOrigin(origin string) websocketRequestOpt { + return func(w *websocketRequest) { + w.Origin = origin + } +} + +func newWebsocketRequest(opts ...websocketRequestOpt) *websocketRequest { + wsrequest := &websocketRequest{} + for _, opt := range opts { + opt(wsrequest) + } + + if wsrequest.Origin == "" { + wsrequest.Origin = "http://" + wsrequest.ServerAddr + } + + if wsrequest.Config == nil { + wsrequest.Config, _ = websocket.NewConfig(fmt.Sprintf("ws://%s%s", wsrequest.ServerAddr, wsrequest.Path), wsrequest.Origin) + } + + return wsrequest +} + +type websocketRequest struct { + ServerAddr string + Path string + Data string + Origin string + Config *websocket.Config +} + +func (w *websocketRequest) send() (string, error) { + conn, _, err := w.open() + if err != nil { + return "", err + } + defer conn.Close() + + if _, err := conn.Write([]byte(w.Data)); err != nil { + return "", err + } + + msg := make([]byte, 512) + + var n int + n, err = conn.Read(msg) + if err != nil { + return "", err + } + + received := string(msg[:n]) + return received, nil +} + +func (w *websocketRequest) open() (*websocket.Conn, net.Conn, error) { + client, err := net.DialTimeout("tcp", w.ServerAddr, dialTimeout) + if err != nil { + return nil, nil, err + } + + conn, err := websocket.NewClient(w.Config, client) + if err != nil { + return nil, nil, err + } + + return conn, client, err +} + +func parseURI(t *testing.T, uri string) *url.URL { + t.Helper() + + out, err := url.ParseRequestURI(uri) + require.NoError(t, err) + + return out +} + +func createConnectionPool(target string, tlsConfig *tls.Config) *connPool { + u := testhelpers.MustParseURL(target) + return newConnPool(200, 0, func() (net.Conn, error) { + if tlsConfig != nil { + return tls.Dial("tcp", u.Host, tlsConfig) + } + + return net.Dial("tcp", u.Host) + }) +} + +func createProxyWithForwarder(t *testing.T, uri string, pool *connPool) *httptest.Server { + t.Helper() + + u := parseURI(t, uri) + proxy, err := NewReverseProxy(u, nil, false, true, 0, pool) + require.NoError(t, err) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + path := req.URL.Path // keep the original path + // Set new backend URL + req.URL = u + req.URL.Path = path + + proxy.ServeHTTP(w, req) + })) + t.Cleanup(srv.Close) + + return srv +} diff --git a/pkg/proxy/fast/upgrade.go b/pkg/proxy/fast/upgrade.go new file mode 100644 index 000000000..a42fc97d8 --- /dev/null +++ b/pkg/proxy/fast/upgrade.go @@ -0,0 +1,104 @@ +package fast + +import ( + "bytes" + "fmt" + "io" + "net" + "net/http" + "strings" + + "github.com/traefik/traefik/v3/pkg/proxy/httputil" + "github.com/valyala/fasthttp" + "golang.org/x/net/http/httpguts" +) + +// switchProtocolCopier exists so goroutines proxying data back and +// forth have nice names in stacks. +type switchProtocolCopier struct { + user, backend io.ReadWriter +} + +func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { + _, err := io.Copy(c.user, c.backend) + errc <- err +} + +func (c switchProtocolCopier) copyToBackend(errc chan<- error) { + _, err := io.Copy(c.backend, c.user) + errc <- err +} + +func handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, reqUpType string, res *fasthttp.Response, backConn net.Conn) { + defer backConn.Close() + + resUpType := upgradeTypeFastHTTP(&res.Header) + + if !strings.EqualFold(reqUpType, resUpType) { + httputil.ErrorHandler(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType)) + return + } + + hj, ok := rw.(http.Hijacker) + if !ok { + httputil.ErrorHandler(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) + return + } + backConnCloseCh := make(chan bool) + go func() { + // Ensure that the cancellation of a request closes the backend. + // See issue https://golang.org/issue/35559. + select { + case <-req.Context().Done(): + case <-backConnCloseCh: + } + _ = backConn.Close() + }() + + defer close(backConnCloseCh) + + conn, brw, err := hj.Hijack() + if err != nil { + httputil.ErrorHandler(rw, req, fmt.Errorf("hijack failed on protocol switch: %w", err)) + return + } + defer conn.Close() + + for k, values := range rw.Header() { + for _, v := range values { + res.Header.Add(k, v) + } + } + + if err := res.Header.Write(brw.Writer); err != nil { + httputil.ErrorHandler(rw, req, fmt.Errorf("response write: %w", err)) + return + } + + if err := brw.Flush(); err != nil { + httputil.ErrorHandler(rw, req, fmt.Errorf("response flush: %w", err)) + return + } + + errc := make(chan error, 1) + spc := switchProtocolCopier{user: conn, backend: backConn} + go spc.copyToBackend(errc) + go spc.copyFromBackend(errc) + <-errc +} + +func upgradeType(h http.Header) string { + if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") { + return "" + } + + return h.Get("Upgrade") +} + +func upgradeTypeFastHTTP(h fasthttpHeader) string { + if !bytes.Contains(h.Peek("Connection"), []byte("Upgrade")) { + return "" + } + + return string(h.Peek("Upgrade")) +} diff --git a/pkg/server/service/bufferpool.go b/pkg/proxy/httputil/bufferpool.go similarity index 57% rename from pkg/server/service/bufferpool.go rename to pkg/proxy/httputil/bufferpool.go index 948611502..7fccd6bc4 100644 --- a/pkg/server/service/bufferpool.go +++ b/pkg/proxy/httputil/bufferpool.go @@ -1,23 +1,25 @@ -package service +package httputil import "sync" -const bufferPoolSize = 32 * 1024 - -func newBufferPool() *bufferPool { - return &bufferPool{ - pool: sync.Pool{ - New: func() interface{} { - return make([]byte, bufferPoolSize) - }, - }, - } -} +const bufferSize = 32 * 1024 type bufferPool struct { pool sync.Pool } +func newBufferPool() *bufferPool { + b := &bufferPool{ + pool: sync.Pool{}, + } + + b.pool.New = func() interface{} { + return make([]byte, bufferSize) + } + + return b +} + func (b *bufferPool) Get() []byte { return b.pool.Get().([]byte) } diff --git a/pkg/proxy/httputil/builder.go b/pkg/proxy/httputil/builder.go new file mode 100644 index 000000000..ee88bd0c3 --- /dev/null +++ b/pkg/proxy/httputil/builder.go @@ -0,0 +1,54 @@ +package httputil + +import ( + "crypto/tls" + "fmt" + "net/http" + "net/url" + "time" + + "github.com/traefik/traefik/v3/pkg/config/dynamic" + "github.com/traefik/traefik/v3/pkg/metrics" +) + +// TransportManager manages transport used for backend communications. +type TransportManager interface { + Get(name string) (*dynamic.ServersTransport, error) + GetRoundTripper(name string) (http.RoundTripper, error) + GetTLSConfig(name string) (*tls.Config, error) +} + +// ProxyBuilder handles the http.RoundTripper for httputil reverse proxies. +type ProxyBuilder struct { + bufferPool *bufferPool + transportManager TransportManager + semConvMetricsRegistry *metrics.SemConvMetricsRegistry +} + +// NewProxyBuilder creates a new ProxyBuilder. +func NewProxyBuilder(transportManager TransportManager, semConvMetricsRegistry *metrics.SemConvMetricsRegistry) *ProxyBuilder { + return &ProxyBuilder{ + bufferPool: newBufferPool(), + transportManager: transportManager, + semConvMetricsRegistry: semConvMetricsRegistry, + } +} + +// Update does nothing. +func (r *ProxyBuilder) Update(_ map[string]*dynamic.ServersTransport) {} + +// Build builds a new httputil.ReverseProxy with the given configuration. +func (r *ProxyBuilder) Build(cfgName string, targetURL *url.URL, shouldObserve, passHostHeader bool, flushInterval time.Duration) (http.Handler, error) { + roundTripper, err := r.transportManager.GetRoundTripper(cfgName) + if err != nil { + return nil, fmt.Errorf("getting RoundTripper: %w", err) + } + + if shouldObserve { + // Wrapping the roundTripper with the Tracing roundTripper, + // to handle the reverseProxy client span creation. + roundTripper = newObservabilityRoundTripper(r.semConvMetricsRegistry, roundTripper) + } + + return buildSingleHostProxy(targetURL, passHostHeader, flushInterval, roundTripper, r.bufferPool), nil +} diff --git a/pkg/proxy/httputil/builder_test.go b/pkg/proxy/httputil/builder_test.go new file mode 100644 index 000000000..8033635d2 --- /dev/null +++ b/pkg/proxy/httputil/builder_test.go @@ -0,0 +1,56 @@ +package httputil + +import ( + "crypto/tls" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/traefik/traefik/v3/pkg/config/dynamic" + "github.com/traefik/traefik/v3/pkg/testhelpers" +) + +func TestEscapedPath(t *testing.T) { + var gotEscapedPath string + srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + gotEscapedPath = req.URL.EscapedPath() + })) + + transportManager := &transportManagerMock{ + roundTrippers: map[string]http.RoundTripper{"default": &http.Transport{}}, + } + + p, err := NewProxyBuilder(transportManager, nil).Build("default", testhelpers.MustParseURL(srv.URL), false, true, 0) + require.NoError(t, err) + + proxy := httptest.NewServer(http.HandlerFunc(p.ServeHTTP)) + + _, err = http.Get(proxy.URL + "/%3A%2F%2F") + require.NoError(t, err) + + assert.Equal(t, "/%3A%2F%2F", gotEscapedPath) +} + +type transportManagerMock struct { + roundTrippers map[string]http.RoundTripper +} + +func (t *transportManagerMock) GetRoundTripper(name string) (http.RoundTripper, error) { + roundTripper, ok := t.roundTrippers[name] + if !ok { + return nil, errors.New("no transport for " + name) + } + + return roundTripper, nil +} + +func (t *transportManagerMock) GetTLSConfig(_ string) (*tls.Config, error) { + panic("implement me") +} + +func (t *transportManagerMock) Get(_ string) (*dynamic.ServersTransport, error) { + panic("implement me") +} diff --git a/pkg/server/service/observability_roundtripper.go b/pkg/proxy/httputil/observability.go similarity index 98% rename from pkg/server/service/observability_roundtripper.go rename to pkg/proxy/httputil/observability.go index c10e028d5..9240f5f7e 100644 --- a/pkg/server/service/observability_roundtripper.go +++ b/pkg/proxy/httputil/observability.go @@ -1,4 +1,4 @@ -package service +package httputil import ( "context" @@ -23,6 +23,13 @@ type wrapper struct { rt http.RoundTripper } +func newObservabilityRoundTripper(semConvMetricRegistry *metrics.SemConvMetricsRegistry, rt http.RoundTripper) http.RoundTripper { + return &wrapper{ + semConvMetricRegistry: semConvMetricRegistry, + rt: rt, + } +} + func (t *wrapper) RoundTrip(req *http.Request) (*http.Response, error) { start := time.Now() var span trace.Span @@ -42,7 +49,7 @@ func (t *wrapper) RoundTrip(req *http.Request) (*http.Response, error) { var headers http.Header response, err := t.rt.RoundTrip(req) if err != nil { - statusCode = computeStatusCode(err) + statusCode = ComputeStatusCode(err) } if response != nil { statusCode = response.StatusCode @@ -96,10 +103,3 @@ func (t *wrapper) RoundTrip(req *http.Request) (*http.Response, error) { return response, err } - -func newObservabilityRoundTripper(semConvMetricRegistry *metrics.SemConvMetricsRegistry, rt http.RoundTripper) http.RoundTripper { - return &wrapper{ - semConvMetricRegistry: semConvMetricRegistry, - rt: rt, - } -} diff --git a/pkg/server/service/observability_roundtripper_test.go b/pkg/proxy/httputil/observability_test.go similarity index 99% rename from pkg/server/service/observability_roundtripper_test.go rename to pkg/proxy/httputil/observability_test.go index 837839543..ea0721e07 100644 --- a/pkg/server/service/observability_roundtripper_test.go +++ b/pkg/proxy/httputil/observability_test.go @@ -1,4 +1,4 @@ -package service +package httputil import ( "context" diff --git a/pkg/server/service/proxy.go b/pkg/proxy/httputil/proxy.go similarity index 90% rename from pkg/server/service/proxy.go rename to pkg/proxy/httputil/proxy.go index 160ee03a4..e0401ea1f 100644 --- a/pkg/server/service/proxy.go +++ b/pkg/proxy/httputil/proxy.go @@ -1,4 +1,4 @@ -package service +package httputil import ( "context" @@ -27,7 +27,7 @@ func buildSingleHostProxy(target *url.URL, passHostHeader bool, flushInterval ti Transport: roundTripper, FlushInterval: flushInterval, BufferPool: bufferPool, - ErrorHandler: errorHandler, + ErrorHandler: ErrorHandler, } } @@ -93,8 +93,9 @@ func isWebSocketUpgrade(req *http.Request) bool { strings.EqualFold(req.Header.Get("Upgrade"), "websocket") } -func errorHandler(w http.ResponseWriter, req *http.Request, err error) { - statusCode := computeStatusCode(err) +// ErrorHandler is the http.Handler called when something goes wrong when forwarding the request. +func ErrorHandler(w http.ResponseWriter, req *http.Request, err error) { + statusCode := ComputeStatusCode(err) logger := log.Ctx(req.Context()) logger.Debug().Err(err).Msgf("%d %s", statusCode, statusText(statusCode)) @@ -105,7 +106,8 @@ func errorHandler(w http.ResponseWriter, req *http.Request, err error) { } } -func computeStatusCode(err error) int { +// ComputeStatusCode computes the HTTP status code according to the given error. +func ComputeStatusCode(err error) int { switch { case errors.Is(err, io.EOF): return http.StatusBadGateway diff --git a/pkg/server/service/proxy_websocket_test.go b/pkg/proxy/httputil/proxy_websocket_test.go similarity index 86% rename from pkg/server/service/proxy_websocket_test.go rename to pkg/proxy/httputil/proxy_websocket_test.go index ea66ef57e..dc6fa8c82 100644 --- a/pkg/server/service/proxy_websocket_test.go +++ b/pkg/proxy/httputil/proxy_websocket_test.go @@ -1,4 +1,4 @@ -package service +package httputil import ( "bufio" @@ -8,13 +8,13 @@ import ( "net" "net/http" "net/http/httptest" - "net/url" "testing" "time" gorillawebsocket "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/traefik/traefik/v3/pkg/testhelpers" "golang.org/x/net/websocket" ) @@ -27,6 +27,7 @@ func TestWebSocketTCPClose(t *testing.T) { return } defer c.Close() + for { _, _, err := c.ReadMessage() if err != nil { @@ -71,6 +72,7 @@ func TestWebSocketPingPong(t *testing.T) { ws.SetPingHandler(func(appData string) error { err = ws.WriteMessage(gorillawebsocket.PongMessage, []byte(appData+"Pong")) require.NoError(t, err) + return nil }) @@ -97,6 +99,7 @@ func TestWebSocketPingPong(t *testing.T) { if data == "PingPong" { return goodErr } + return badErr }) @@ -104,7 +107,6 @@ func TestWebSocketPingPong(t *testing.T) { require.NoError(t, err) _, _, err = conn.ReadMessage() - if !errors.Is(err, goodErr) { require.NoError(t, err) } @@ -114,12 +116,10 @@ func TestWebSocketEcho(t *testing.T) { mux := http.NewServeMux() mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { msg := make([]byte, 4) - _, err := conn.Read(msg) + n, err := conn.Read(msg) require.NoError(t, err) - fmt.Println(string(msg)) - - _, err = conn.Write(msg) + _, err = conn.Write(msg[:n]) require.NoError(t, err) err = conn.Close() @@ -142,7 +142,10 @@ func TestWebSocketEcho(t *testing.T) { err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK")) require.NoError(t, err) - fmt.Println(conn.ReadMessage()) + _, msg, err := conn.ReadMessage() + require.NoError(t, err) + + assert.Equal(t, "OK", string(msg)) err = conn.Close() require.NoError(t, err) @@ -178,11 +181,10 @@ func TestWebSocketPassHost(t *testing.T) { } msg := make([]byte, 4) - _, err := conn.Read(msg) + n, err := conn.Read(msg) require.NoError(t, err) - fmt.Println(string(msg)) - _, err = conn.Write(msg) + _, err = conn.Write(msg[:n]) require.NoError(t, err) err = conn.Close() @@ -207,7 +209,10 @@ func TestWebSocketPassHost(t *testing.T) { err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK")) require.NoError(t, err) - fmt.Println(conn.ReadMessage()) + _, msg, err := conn.ReadMessage() + require.NoError(t, err) + + assert.Equal(t, "OK", string(msg)) err = conn.Close() require.NoError(t, err) @@ -216,27 +221,8 @@ func TestWebSocketPassHost(t *testing.T) { } func TestWebSocketServerWithoutCheckOrigin(t *testing.T) { - upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(r *http.Request) bool { - return true - }} - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer c.Close() - for { - mt, message, err := c.ReadMessage() - if err != nil { - break - } - err = c.WriteMessage(mt, message) - if err != nil { - break - } - } - })) - defer srv.Close() + upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + srv := createServer(t, upgrader, func(*http.Request) {}) proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport) defer proxy.Close() @@ -254,25 +240,7 @@ func TestWebSocketServerWithoutCheckOrigin(t *testing.T) { } func TestWebSocketRequestWithOrigin(t *testing.T) { - upgrader := gorillawebsocket.Upgrader{} - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer c.Close() - for { - mt, message, err := c.ReadMessage() - if err != nil { - break - } - err = c.WriteMessage(mt, message) - if err != nil { - break - } - } - })) - defer srv.Close() + srv := createServer(t, gorillawebsocket.Upgrader{}, func(*http.Request) {}) proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport) defer proxy.Close() @@ -297,26 +265,9 @@ func TestWebSocketRequestWithOrigin(t *testing.T) { } func TestWebSocketRequestWithQueryParams(t *testing.T) { - upgrader := gorillawebsocket.Upgrader{} - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer conn.Close() + srv := createServer(t, gorillawebsocket.Upgrader{}, func(r *http.Request) { assert.Equal(t, "test", r.URL.Query().Get("query")) - for { - mt, message, err := conn.ReadMessage() - if err != nil { - break - } - err = conn.WriteMessage(mt, message) - if err != nil { - break - } - } - })) - defer srv.Close() + }) proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport) defer proxy.Close() @@ -341,11 +292,19 @@ func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) { srv := httptest.NewServer(mux) defer srv.Close() - f := buildSingleHostProxy(parseURI(t, srv.URL), true, 0, http.DefaultTransport, nil) + transportManager := &transportManagerMock{ + roundTrippers: map[string]http.RoundTripper{ + "default@internal": &http.Transport{}, + }, + } + + p, err := NewProxyBuilder(transportManager, nil).Build("default@internal", testhelpers.MustParseURL(srv.URL), false, true, 0) + require.NoError(t, err) + proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - req.URL = parseURI(t, srv.URL) + req.URL = testhelpers.MustParseURL(srv.URL) w.Header().Set("HEADER-KEY", "HEADER-VALUE") - f.ServeHTTP(w, req) + p.ServeHTTP(w, req) })) defer proxy.Close() @@ -363,26 +322,9 @@ func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) { } func TestWebSocketRequestWithEncodedChar(t *testing.T) { - upgrader := gorillawebsocket.Upgrader{} - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer conn.Close() + srv := createServer(t, gorillawebsocket.Upgrader{}, func(r *http.Request) { assert.Equal(t, "/%3A%2F%2F", r.URL.EscapedPath()) - for { - mt, message, err := conn.ReadMessage() - if err != nil { - break - } - err = conn.WriteMessage(mt, message) - if err != nil { - break - } - } - })) - defer srv.Close() + }) proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport) defer proxy.Close() @@ -407,15 +349,23 @@ func TestWebSocketUpgradeFailed(t *testing.T) { srv := httptest.NewServer(mux) defer srv.Close() - f := buildSingleHostProxy(parseURI(t, srv.URL), true, 0, http.DefaultTransport, nil) + transportManager := &transportManagerMock{ + roundTrippers: map[string]http.RoundTripper{ + "default@internal": &http.Transport{}, + }, + } + + p, err := NewProxyBuilder(transportManager, nil).Build("default@internal", testhelpers.MustParseURL(srv.URL), false, true, 0) + require.NoError(t, err) + proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { path := req.URL.Path // keep the original path if path == "/ws" { // Set new backend URL - req.URL = parseURI(t, srv.URL) + req.URL = testhelpers.MustParseURL(srv.URL) req.URL.Path = path - f.ServeHTTP(w, req) + p.ServeHTTP(w, req) } else { w.WriteHeader(http.StatusOK) } @@ -629,27 +579,60 @@ func (w *websocketRequest) open() (*websocket.Conn, net.Conn, error) { return conn, client, err } -func parseURI(t *testing.T, uri string) *url.URL { - t.Helper() - - out, err := url.ParseRequestURI(uri) - require.NoError(t, err) - return out -} - func createProxyWithForwarder(t *testing.T, uri string, transport http.RoundTripper) *httptest.Server { t.Helper() - u := parseURI(t, uri) - proxy := buildSingleHostProxy(u, true, 0, transport, nil) + u := testhelpers.MustParseURL(uri) + + transportManager := &transportManagerMock{ + roundTrippers: map[string]http.RoundTripper{"fwd": transport}, + } + + p, err := NewProxyBuilder(transportManager, nil).Build("fwd", u, false, true, 0) + require.NoError(t, err) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - path := req.URL.Path // keep the original path + // keep the original path + path := req.URL.Path + // Set new backend URL req.URL = u req.URL.Path = path - proxy.ServeHTTP(w, req) + p.ServeHTTP(w, req) })) t.Cleanup(srv.Close) + + return srv +} + +func createServer(t *testing.T, upgrader gorillawebsocket.Upgrader, check func(*http.Request)) *httptest.Server { + t.Helper() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Logf("Error during upgrade: %v", err) + return + } + defer conn.Close() + + check(r) + for { + mt, message, err := conn.ReadMessage() + if err != nil { + t.Logf("Error during read: %v", err) + break + } + + err = conn.WriteMessage(mt, message) + if err != nil { + t.Logf("Error during write: %v", err) + break + } + } + })) + t.Cleanup(srv.Close) + return srv } diff --git a/pkg/proxy/smart_builder.go b/pkg/proxy/smart_builder.go new file mode 100644 index 000000000..1abadcca7 --- /dev/null +++ b/pkg/proxy/smart_builder.go @@ -0,0 +1,61 @@ +package proxy + +import ( + "crypto/tls" + "fmt" + "net/http" + "net/url" + "time" + + "github.com/traefik/traefik/v3/pkg/config/dynamic" + "github.com/traefik/traefik/v3/pkg/config/static" + "github.com/traefik/traefik/v3/pkg/proxy/fast" + "github.com/traefik/traefik/v3/pkg/proxy/httputil" + "github.com/traefik/traefik/v3/pkg/server/service" +) + +// TransportManager manages transport used for backend communications. +type TransportManager interface { + Get(name string) (*dynamic.ServersTransport, error) + GetRoundTripper(name string) (http.RoundTripper, error) + GetTLSConfig(name string) (*tls.Config, error) +} + +// SmartBuilder is a proxy builder which returns a fast proxy or httputil proxy corresponding +// to the ServersTransport configuration. +type SmartBuilder struct { + fastProxyBuilder *fast.ProxyBuilder + proxyBuilder service.ProxyBuilder + + transportManager httputil.TransportManager +} + +// NewSmartBuilder creates and returns a new SmartBuilder instance. +func NewSmartBuilder(transportManager TransportManager, proxyBuilder service.ProxyBuilder, fastProxyConfig static.FastProxyConfig) *SmartBuilder { + return &SmartBuilder{ + fastProxyBuilder: fast.NewProxyBuilder(transportManager, fastProxyConfig), + proxyBuilder: proxyBuilder, + transportManager: transportManager, + } +} + +// Update is the handler called when the dynamic configuration is updated. +func (b *SmartBuilder) Update(newConfigs map[string]*dynamic.ServersTransport) { + b.fastProxyBuilder.Update(newConfigs) +} + +// Build builds an HTTP proxy for the given URL using the ServersTransport with the given name. +func (b *SmartBuilder) Build(configName string, targetURL *url.URL, shouldObserve, passHostHeader bool, flushInterval time.Duration) (http.Handler, error) { + serversTransport, err := b.transportManager.Get(configName) + if err != nil { + return nil, fmt.Errorf("getting ServersTransport: %w", err) + } + + // The fast proxy implementation cannot handle HTTP/2 requests for now. + // For the https scheme we cannot guess if the backend communication will use HTTP2, + // thus we check if HTTP/2 is disabled to use the fast proxy implementation when this is possible. + if targetURL.Scheme == "h2c" || (targetURL.Scheme == "https" && !serversTransport.DisableHTTP2) { + return b.proxyBuilder.Build(configName, targetURL, shouldObserve, passHostHeader, flushInterval) + } + return b.fastProxyBuilder.Build(configName, targetURL, passHostHeader) +} diff --git a/pkg/proxy/smart_builder_test.go b/pkg/proxy/smart_builder_test.go new file mode 100644 index 000000000..d1c29ddd8 --- /dev/null +++ b/pkg/proxy/smart_builder_test.go @@ -0,0 +1,113 @@ +package proxy + +import ( + "encoding/pem" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/traefik/traefik/v3/pkg/config/dynamic" + "github.com/traefik/traefik/v3/pkg/config/static" + "github.com/traefik/traefik/v3/pkg/proxy/httputil" + "github.com/traefik/traefik/v3/pkg/server/service" + "github.com/traefik/traefik/v3/pkg/testhelpers" + "github.com/traefik/traefik/v3/pkg/types" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" +) + +func TestSmartBuilder_Build(t *testing.T) { + tests := []struct { + desc string + serversTransport dynamic.ServersTransport + fastProxyConfig static.FastProxyConfig + https bool + h2c bool + wantFastProxy bool + }{ + { + desc: "fastproxy", + fastProxyConfig: static.FastProxyConfig{Debug: true}, + wantFastProxy: true, + }, + { + desc: "fastproxy with https and without DisableHTTP2", + https: true, + fastProxyConfig: static.FastProxyConfig{Debug: true}, + wantFastProxy: false, + }, + { + desc: "fastproxy with https and DisableHTTP2", + https: true, + serversTransport: dynamic.ServersTransport{DisableHTTP2: true}, + fastProxyConfig: static.FastProxyConfig{Debug: true}, + wantFastProxy: true, + }, + { + desc: "fastproxy with h2c", + h2c: true, + fastProxyConfig: static.FastProxyConfig{Debug: true}, + wantFastProxy: false, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + var callCount int + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + if test.wantFastProxy { + assert.Contains(t, r.Header, "X-Traefik-Fast-Proxy") + } else { + assert.NotContains(t, r.Header, "X-Traefik-Fast-Proxy") + } + }) + + var server *httptest.Server + + if test.https { + server = httptest.NewUnstartedServer(handler) + server.EnableHTTP2 = false + server.StartTLS() + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: server.TLS.Certificates[0].Certificate[0]}) + test.serversTransport.RootCAs = []types.FileOrContent{ + types.FileOrContent(certPEM), + } + } else { + server = httptest.NewServer(h2c.NewHandler(handler, &http2.Server{})) + } + t.Cleanup(func() { + server.Close() + }) + + targetURL := testhelpers.MustParseURL(server.URL) + if test.h2c { + targetURL.Scheme = "h2c" + } + + serversTransports := map[string]*dynamic.ServersTransport{ + "test": &test.serversTransport, + } + + transportManager := service.NewTransportManager(nil) + transportManager.Update(serversTransports) + + httpProxyBuilder := httputil.NewProxyBuilder(transportManager, nil) + proxyBuilder := NewSmartBuilder(transportManager, httpProxyBuilder, test.fastProxyConfig) + + proxyHandler, err := proxyBuilder.Build("test", targetURL, false, false, time.Second) + require.NoError(t, err) + + rw := httptest.NewRecorder() + proxyHandler.ServeHTTP(rw, httptest.NewRequest(http.MethodGet, "/", http.NoBody)) + + assert.Equal(t, 1, callCount) + }) + } +} diff --git a/pkg/server/router/router_test.go b/pkg/server/router/router_test.go index 9b1cab388..b545bd2da 100644 --- a/pkg/server/router/router_test.go +++ b/pkg/server/router/router_test.go @@ -2,10 +2,12 @@ package router import ( "context" + "crypto/tls" "io" "math" "net/http" "net/http/httptest" + "net/url" "strings" "testing" "time" @@ -18,7 +20,7 @@ import ( "github.com/traefik/traefik/v3/pkg/server/middleware" "github.com/traefik/traefik/v3/pkg/server/service" "github.com/traefik/traefik/v3/pkg/testhelpers" - "github.com/traefik/traefik/v3/pkg/tls" + traefiktls "github.com/traefik/traefik/v3/pkg/tls" ) func TestRouterManager_Get(t *testing.T) { @@ -309,11 +311,12 @@ func TestRouterManager_Get(t *testing.T) { }, }) - roundTripperManager := service.NewRoundTripperManager(nil) - roundTripperManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}}) - serviceManager := service.NewManager(rtConf.Services, nil, nil, roundTripperManager) + transportManager := service.NewTransportManager(nil) + transportManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}}) + + serviceManager := service.NewManager(rtConf.Services, nil, nil, transportManager, proxyBuilderMock{}) middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager, nil) - tlsManager := tls.NewManager() + tlsManager := traefiktls.NewManager() routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, nil, tlsManager) @@ -340,7 +343,7 @@ func TestRuntimeConfiguration(t *testing.T) { serviceConfig map[string]*dynamic.Service routerConfig map[string]*dynamic.Router middlewareConfig map[string]*dynamic.Middleware - tlsOptions map[string]tls.Options + tlsOptions map[string]traefiktls.Options expectedError int }{ { @@ -597,7 +600,7 @@ func TestRuntimeConfiguration(t *testing.T) { TLS: &dynamic.RouterTLSConfig{}, }, }, - tlsOptions: map[string]tls.Options{}, + tlsOptions: map[string]traefiktls.Options{}, expectedError: 1, }, { @@ -624,9 +627,9 @@ func TestRuntimeConfiguration(t *testing.T) { }, }, }, - tlsOptions: map[string]tls.Options{ + tlsOptions: map[string]traefiktls.Options{ "broken-tlsOption": { - ClientAuth: tls.ClientAuth{ + ClientAuth: traefiktls.ClientAuth{ ClientAuthType: "foobar", }, }, @@ -655,9 +658,9 @@ func TestRuntimeConfiguration(t *testing.T) { TLS: &dynamic.RouterTLSConfig{}, }, }, - tlsOptions: map[string]tls.Options{ + tlsOptions: map[string]traefiktls.Options{ "default": { - ClientAuth: tls.ClientAuth{ + ClientAuth: traefiktls.ClientAuth{ ClientAuthType: "foobar", }, }, @@ -682,11 +685,12 @@ func TestRuntimeConfiguration(t *testing.T) { }, }) - roundTripperManager := service.NewRoundTripperManager(nil) - roundTripperManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}}) - serviceManager := service.NewManager(rtConf.Services, nil, nil, roundTripperManager) + transportManager := service.NewTransportManager(nil) + transportManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}}) + + serviceManager := service.NewManager(rtConf.Services, nil, nil, transportManager, proxyBuilderMock{}) middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager, nil) - tlsManager := tls.NewManager() + tlsManager := traefiktls.NewManager() tlsManager.UpdateConfigs(context.Background(), nil, test.tlsOptions, nil) routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, nil, tlsManager) @@ -759,11 +763,12 @@ func TestProviderOnMiddlewares(t *testing.T) { }, }) - roundTripperManager := service.NewRoundTripperManager(nil) - roundTripperManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}}) - serviceManager := service.NewManager(rtConf.Services, nil, nil, roundTripperManager) + transportManager := service.NewTransportManager(nil) + transportManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}}) + + serviceManager := service.NewManager(rtConf.Services, nil, nil, transportManager, nil) middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager, nil) - tlsManager := tls.NewManager() + tlsManager := traefiktls.NewManager() routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, nil, tlsManager) @@ -775,14 +780,22 @@ func TestProviderOnMiddlewares(t *testing.T) { assert.Equal(t, []string{"m1@docker", "m2@docker", "m1@file"}, rtConf.Middlewares["chain@docker"].Chain.Middlewares) } -type staticRoundTripperGetter struct { +type staticTransportManager struct { res *http.Response } -func (s staticRoundTripperGetter) Get(name string) (http.RoundTripper, error) { +func (s staticTransportManager) GetRoundTripper(_ string) (http.RoundTripper, error) { return &staticTransport{res: s.res}, nil } +func (s staticTransportManager) GetTLSConfig(_ string) (*tls.Config, error) { + panic("implement me") +} + +func (s staticTransportManager) Get(_ string) (*dynamic.ServersTransport, error) { + panic("implement me") +} + type staticTransport struct { res *http.Response } @@ -829,9 +842,9 @@ func BenchmarkRouterServe(b *testing.B) { }, }) - serviceManager := service.NewManager(rtConf.Services, nil, nil, staticRoundTripperGetter{res}) + serviceManager := service.NewManager(rtConf.Services, nil, nil, staticTransportManager{res}, nil) middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager, nil) - tlsManager := tls.NewManager() + tlsManager := traefiktls.NewManager() routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, nil, tlsManager) @@ -871,7 +884,7 @@ func BenchmarkService(b *testing.B) { }, }) - serviceManager := service.NewManager(rtConf.Services, nil, nil, staticRoundTripperGetter{res}) + serviceManager := service.NewManager(rtConf.Services, nil, nil, staticTransportManager{res}, nil) w := httptest.NewRecorder() req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil) @@ -881,3 +894,13 @@ func BenchmarkService(b *testing.B) { handler.ServeHTTP(w, req) } } + +type proxyBuilderMock struct{} + +func (p proxyBuilderMock) Build(_ string, _ *url.URL, _, _ bool, _ time.Duration) (http.Handler, error) { + return http.HandlerFunc(func(responseWriter http.ResponseWriter, req *http.Request) {}), nil +} + +func (p proxyBuilderMock) Update(_ map[string]*dynamic.ServersTransport) { + panic("implement me") +} diff --git a/pkg/server/routerfactory_test.go b/pkg/server/routerfactory_test.go index 836885781..6c0861d38 100644 --- a/pkg/server/routerfactory_test.go +++ b/pkg/server/routerfactory_test.go @@ -3,7 +3,9 @@ package server import ( "net/http" "net/http/httptest" + "net/url" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/traefik/traefik/v3/pkg/config/dynamic" @@ -48,9 +50,10 @@ func TestReuseService(t *testing.T) { ), ) - roundTripperManager := service.NewRoundTripperManager(nil) - roundTripperManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}}) - managerFactory := service.NewManagerFactory(staticConfig, nil, nil, roundTripperManager, nil) + transportManager := service.NewTransportManager(nil) + transportManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}}) + + managerFactory := service.NewManagerFactory(staticConfig, nil, nil, transportManager, proxyBuilderMock{}, nil) tlsManager := tls.NewManager() dialerManager := tcp.NewDialerManager(nil) @@ -184,9 +187,10 @@ func TestServerResponseEmptyBackend(t *testing.T) { }, } - roundTripperManager := service.NewRoundTripperManager(nil) - roundTripperManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}}) - managerFactory := service.NewManagerFactory(staticConfig, nil, nil, roundTripperManager, nil) + transportManager := service.NewTransportManager(nil) + transportManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}}) + + managerFactory := service.NewManagerFactory(staticConfig, nil, nil, transportManager, proxyBuilderMock{}, nil) tlsManager := tls.NewManager() dialerManager := tcp.NewDialerManager(nil) @@ -228,9 +232,10 @@ func TestInternalServices(t *testing.T) { ), ) - roundTripperManager := service.NewRoundTripperManager(nil) - roundTripperManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}}) - managerFactory := service.NewManagerFactory(staticConfig, nil, nil, roundTripperManager, nil) + transportManager := service.NewTransportManager(nil) + transportManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}}) + + managerFactory := service.NewManagerFactory(staticConfig, nil, nil, transportManager, nil, nil) tlsManager := tls.NewManager() dialerManager := tcp.NewDialerManager(nil) @@ -246,3 +251,13 @@ func TestInternalServices(t *testing.T) { assert.Equal(t, http.StatusOK, responseRecorderOk.Result().StatusCode, "status code") } + +type proxyBuilderMock struct{} + +func (p proxyBuilderMock) Build(_ string, _ *url.URL, _, _ bool, _ time.Duration) (http.Handler, error) { + return http.HandlerFunc(func(responseWriter http.ResponseWriter, req *http.Request) {}), nil +} + +func (p proxyBuilderMock) Update(_ map[string]*dynamic.ServersTransport) { + panic("implement me") +} diff --git a/pkg/server/service/managerfactory.go b/pkg/server/service/managerfactory.go index abff2d5a6..43e2189b3 100644 --- a/pkg/server/service/managerfactory.go +++ b/pkg/server/service/managerfactory.go @@ -17,7 +17,8 @@ import ( type ManagerFactory struct { observabilityMgr *middleware.ObservabilityMgr - roundTripperManager *RoundTripperManager + transportManager *TransportManager + proxyBuilder ProxyBuilder api func(configuration *runtime.Configuration) http.Handler restHandler http.Handler @@ -30,12 +31,13 @@ type ManagerFactory struct { } // NewManagerFactory creates a new ManagerFactory. -func NewManagerFactory(staticConfiguration static.Configuration, routinesPool *safe.Pool, observabilityMgr *middleware.ObservabilityMgr, roundTripperManager *RoundTripperManager, acmeHTTPHandler http.Handler) *ManagerFactory { +func NewManagerFactory(staticConfiguration static.Configuration, routinesPool *safe.Pool, observabilityMgr *middleware.ObservabilityMgr, transportManager *TransportManager, proxyBuilder ProxyBuilder, acmeHTTPHandler http.Handler) *ManagerFactory { factory := &ManagerFactory{ - observabilityMgr: observabilityMgr, - routinesPool: routinesPool, - roundTripperManager: roundTripperManager, - acmeHTTPHandler: acmeHTTPHandler, + observabilityMgr: observabilityMgr, + routinesPool: routinesPool, + transportManager: transportManager, + proxyBuilder: proxyBuilder, + acmeHTTPHandler: acmeHTTPHandler, } if staticConfiguration.API != nil { @@ -73,7 +75,7 @@ func NewManagerFactory(staticConfiguration static.Configuration, routinesPool *s // Build creates a service manager. func (f *ManagerFactory) Build(configuration *runtime.Configuration) *InternalHandlers { - svcManager := NewManager(configuration.Services, f.observabilityMgr, f.routinesPool, f.roundTripperManager) + svcManager := NewManager(configuration.Services, f.observabilityMgr, f.routinesPool, f.transportManager, f.proxyBuilder) var apiHandler http.Handler if f.api != nil { diff --git a/pkg/server/service/proxy_test.go b/pkg/server/service/proxy_test.go deleted file mode 100644 index a8dfd758c..000000000 --- a/pkg/server/service/proxy_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package service - -import ( - "io" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/traefik/traefik/v3/pkg/testhelpers" -) - -type staticTransport struct { - res *http.Response -} - -func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) { - return t.res, nil -} - -func BenchmarkProxy(b *testing.B) { - res := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader("")), - } - - w := httptest.NewRecorder() - req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil) - - pool := newBufferPool() - handler := buildSingleHostProxy(req.URL, false, 0, &staticTransport{res}, pool) - - b.ReportAllocs() - for range b.N { - handler.ServeHTTP(w, req) - } -} diff --git a/pkg/server/service/service.go b/pkg/server/service/service.go index 46682e25a..e37ecdf09 100644 --- a/pkg/server/service/service.go +++ b/pkg/server/service/service.go @@ -9,7 +9,6 @@ import ( "hash/fnv" "math/rand" "net/http" - "net/http/httputil" "net/url" "reflect" "strings" @@ -25,6 +24,7 @@ import ( "github.com/traefik/traefik/v3/pkg/middlewares/capture" metricsMiddle "github.com/traefik/traefik/v3/pkg/middlewares/metrics" "github.com/traefik/traefik/v3/pkg/middlewares/observability" + "github.com/traefik/traefik/v3/pkg/proxy/httputil" "github.com/traefik/traefik/v3/pkg/safe" "github.com/traefik/traefik/v3/pkg/server/cookie" "github.com/traefik/traefik/v3/pkg/server/middleware" @@ -40,17 +40,18 @@ const ( defaultMaxBodySize int64 = -1 ) -// RoundTripperGetter is a roundtripper getter interface. -type RoundTripperGetter interface { - Get(name string) (http.RoundTripper, error) +// ProxyBuilder builds reverse proxy handlers. +type ProxyBuilder interface { + Build(cfgName string, targetURL *url.URL, shouldObserve, passHostHeader bool, flushInterval time.Duration) (http.Handler, error) + Update(configs map[string]*dynamic.ServersTransport) } // Manager The service manager. type Manager struct { - routinePool *safe.Pool - observabilityMgr *middleware.ObservabilityMgr - bufferPool httputil.BufferPool - roundTripperManager RoundTripperGetter + routinePool *safe.Pool + observabilityMgr *middleware.ObservabilityMgr + transportManager httputil.TransportManager + proxyBuilder ProxyBuilder services map[string]http.Handler configs map[string]*runtime.ServiceInfo @@ -59,16 +60,16 @@ type Manager struct { } // NewManager creates a new Manager. -func NewManager(configs map[string]*runtime.ServiceInfo, observabilityMgr *middleware.ObservabilityMgr, routinePool *safe.Pool, roundTripperManager RoundTripperGetter) *Manager { +func NewManager(configs map[string]*runtime.ServiceInfo, observabilityMgr *middleware.ObservabilityMgr, routinePool *safe.Pool, transportManager httputil.TransportManager, proxyBuilder ProxyBuilder) *Manager { return &Manager{ - routinePool: routinePool, - observabilityMgr: observabilityMgr, - bufferPool: newBufferPool(), - roundTripperManager: roundTripperManager, - services: make(map[string]http.Handler), - configs: configs, - healthCheckers: make(map[string]*healthcheck.ServiceHealthChecker), - rand: rand.New(rand.NewSource(time.Now().UnixNano())), + routinePool: routinePool, + observabilityMgr: observabilityMgr, + transportManager: transportManager, + proxyBuilder: proxyBuilder, + services: make(map[string]http.Handler), + configs: configs, + healthCheckers: make(map[string]*healthcheck.ServiceHealthChecker), + rand: rand.New(rand.NewSource(time.Now().UnixNano())), } } @@ -298,9 +299,9 @@ func (m *Manager) getLoadBalancerServiceHandler(ctx context.Context, serviceName logger.Debug().Msg("Creating load-balancer") // TODO: should we keep this config value as Go is now handling stream response correctly? - flushInterval := dynamic.DefaultFlushInterval + flushInterval := time.Duration(dynamic.DefaultFlushInterval) if service.ResponseForwarding != nil { - flushInterval = service.ResponseForwarding.FlushInterval + flushInterval = time.Duration(service.ResponseForwarding.FlushInterval) } if len(service.ServersTransport) > 0 { @@ -317,11 +318,6 @@ func (m *Manager) getLoadBalancerServiceHandler(ctx context.Context, serviceName passHostHeader = *service.PassHostHeader } - roundTripper, err := m.roundTripperManager.Get(service.ServersTransport) - if err != nil { - return nil, err - } - lb := wrr.New(service.Sticky, service.HealthCheck != nil) healthCheckTargets := make(map[string]*url.URL) @@ -341,14 +337,12 @@ func (m *Manager) getLoadBalancerServiceHandler(ctx context.Context, serviceName qualifiedSvcName := provider.GetQualifiedName(ctx, serviceName) - if m.observabilityMgr.ShouldAddTracing(qualifiedSvcName) || m.observabilityMgr.ShouldAddMetrics(qualifiedSvcName) { - // Wrapping the roundTripper with the Tracing roundTripper, - // to handle the reverseProxy client span creation. - roundTripper = newObservabilityRoundTripper(m.observabilityMgr.SemConvMetricsRegistry(), roundTripper) + shouldObserve := m.observabilityMgr.ShouldAddTracing(qualifiedSvcName) || m.observabilityMgr.ShouldAddMetrics(qualifiedSvcName) + proxy, err := m.proxyBuilder.Build(service.ServersTransport, target, shouldObserve, passHostHeader, flushInterval) + if err != nil { + return nil, fmt.Errorf("error building proxy for server URL %s: %w", server.URL, err) } - proxy := buildSingleHostProxy(target, passHostHeader, time.Duration(flushInterval), roundTripper, m.bufferPool) - // Prevents from enabling observability for internal resources. if m.observabilityMgr.ShouldAddAccessLogs(qualifiedSvcName) { @@ -393,6 +387,11 @@ func (m *Manager) getLoadBalancerServiceHandler(ctx context.Context, serviceName } if service.HealthCheck != nil { + roundTripper, err := m.transportManager.GetRoundTripper(service.ServersTransport) + if err != nil { + return nil, fmt.Errorf("getting RoundTripper: %w", err) + } + m.healthCheckers[serviceName] = healthcheck.NewServiceHealthChecker( ctx, m.observabilityMgr.MetricsRegistry(), diff --git a/pkg/server/service/service_test.go b/pkg/server/service/service_test.go index 786a1aded..98404ccf9 100644 --- a/pkg/server/service/service_test.go +++ b/pkg/server/service/service_test.go @@ -2,6 +2,7 @@ package service import ( "context" + "crypto/tls" "io" "net/http" "net/http/httptest" @@ -14,13 +15,14 @@ import ( "github.com/stretchr/testify/require" "github.com/traefik/traefik/v3/pkg/config/dynamic" "github.com/traefik/traefik/v3/pkg/config/runtime" + "github.com/traefik/traefik/v3/pkg/proxy/httputil" "github.com/traefik/traefik/v3/pkg/server/provider" "github.com/traefik/traefik/v3/pkg/testhelpers" ) func TestGetLoadBalancer(t *testing.T) { sm := Manager{ - roundTripperManager: newRtMock(), + transportManager: &transportManagerMock{}, } testCases := []struct { @@ -40,14 +42,14 @@ func TestGetLoadBalancer(t *testing.T) { }, }, }, - fwd: &MockForwarder{}, + fwd: &forwarderMock{}, expectError: true, }, { desc: "Succeeds when there are no servers", serviceName: "test", service: &dynamic.ServersLoadBalancer{}, - fwd: &MockForwarder{}, + fwd: &forwarderMock{}, expectError: false, }, { @@ -56,7 +58,7 @@ func TestGetLoadBalancer(t *testing.T) { service: &dynamic.ServersLoadBalancer{ Sticky: &dynamic.Sticky{Cookie: &dynamic.Cookie{}}, }, - fwd: &MockForwarder{}, + fwd: &forwarderMock{}, expectError: false, }, } @@ -79,11 +81,8 @@ func TestGetLoadBalancer(t *testing.T) { } func TestGetLoadBalancerServiceHandler(t *testing.T) { - sm := NewManager(nil, nil, nil, &RoundTripperManager{ - roundTrippers: map[string]http.RoundTripper{ - "default@internal": http.DefaultTransport, - }, - }) + pb := httputil.NewProxyBuilder(&transportManagerMock{}, nil) + sm := NewManager(nil, nil, nil, transportManagerMock{}, pb) server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-From", "first") @@ -139,7 +138,7 @@ func TestGetLoadBalancerServiceHandler(t *testing.T) { desc: "Load balances between the two servers", serviceName: "test", service: &dynamic.ServersLoadBalancer{ - PassHostHeader: Bool(true), + PassHostHeader: boolPtr(true), Servers: []dynamic.Server{ { URL: server1.URL, @@ -254,7 +253,7 @@ func TestGetLoadBalancerServiceHandler(t *testing.T) { desc: "PassHost doesn't pass the host instead of the IP", serviceName: "test", service: &dynamic.ServersLoadBalancer{ - PassHostHeader: Bool(false), + PassHostHeader: boolPtr(false), Sticky: &dynamic.Sticky{Cookie: &dynamic.Cookie{}}, Servers: []dynamic.Server{ { @@ -359,11 +358,8 @@ func TestGetLoadBalancerServiceHandler(t *testing.T) { // This test is an adapted version of net/http/httputil.Test1xxResponses test. func Test1xxResponses(t *testing.T) { - sm := NewManager(nil, nil, nil, &RoundTripperManager{ - roundTrippers: map[string]http.RoundTripper{ - "default@internal": http.DefaultTransport, - }, - }) + pb := httputil.NewProxyBuilder(&transportManagerMock{}, nil) + sm := NewManager(nil, nil, nil, &transportManagerMock{}, pb) backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { h := w.Header() @@ -499,11 +495,7 @@ func TestManager_Build(t *testing.T) { t.Run(test.desc, func(t *testing.T) { t.Parallel() - manager := NewManager(test.configs, nil, nil, &RoundTripperManager{ - roundTrippers: map[string]http.RoundTripper{ - "default@internal": http.DefaultTransport, - }, - }) + manager := NewManager(test.configs, nil, nil, &transportManagerMock{}, nil) ctx := context.Background() if len(test.providerName) > 0 { @@ -526,30 +518,30 @@ func TestMultipleTypeOnBuildHTTP(t *testing.T) { }, } - manager := NewManager(services, nil, nil, &RoundTripperManager{ - roundTrippers: map[string]http.RoundTripper{ - "default@internal": http.DefaultTransport, - }, - }) + manager := NewManager(services, nil, nil, &transportManagerMock{}, nil) _, err := manager.BuildHTTP(context.Background(), "test@file") assert.Error(t, err, "cannot create service: multi-types service not supported, consider declaring two different pieces of service instead") } -func Bool(v bool) *bool { return &v } +func boolPtr(v bool) *bool { return &v } -type MockForwarder struct{} +type forwarderMock struct{} -func (MockForwarder) ServeHTTP(http.ResponseWriter, *http.Request) { +func (forwarderMock) ServeHTTP(http.ResponseWriter, *http.Request) { panic("not available") } -type rtMock struct{} +type transportManagerMock struct{} -func newRtMock() RoundTripperGetter { - return &rtMock{} +func (t transportManagerMock) GetRoundTripper(_ string) (http.RoundTripper, error) { + return &http.Transport{}, nil } -func (r *rtMock) Get(_ string) (http.RoundTripper, error) { - return http.DefaultTransport, nil +func (t transportManagerMock) GetTLSConfig(_ string) (*tls.Config, error) { + return nil, nil +} + +func (t transportManagerMock) Get(_ string) (*dynamic.ServersTransport, error) { + return &dynamic.ServersTransport{}, nil } diff --git a/pkg/server/service/smart_roundtripper.go b/pkg/server/service/smart_roundtripper.go index 93b8c62d6..a1d7b3f1a 100644 --- a/pkg/server/service/smart_roundtripper.go +++ b/pkg/server/service/smart_roundtripper.go @@ -11,6 +11,15 @@ import ( "golang.org/x/net/http2" ) +type h2cTransportWrapper struct { + *http2.Transport +} + +func (t *h2cTransportWrapper) RoundTrip(req *http.Request) (*http.Response, error) { + req.URL.Scheme = "http" + return t.Transport.RoundTrip(req) +} + func newSmartRoundTripper(transport *http.Transport, forwardingTimeouts *dynamic.ForwardingTimeouts) (*smartRoundTripper, error) { transportHTTP1 := transport.Clone() diff --git a/pkg/server/service/roundtripper.go b/pkg/server/service/transport.go similarity index 65% rename from pkg/server/service/roundtripper.go rename to pkg/server/service/transport.go index 06b6473c0..ecce54010 100644 --- a/pkg/server/service/roundtripper.go +++ b/pkg/server/service/transport.go @@ -22,52 +22,45 @@ import ( "github.com/traefik/traefik/v3/pkg/config/dynamic" traefiktls "github.com/traefik/traefik/v3/pkg/tls" "github.com/traefik/traefik/v3/pkg/types" - "golang.org/x/net/http2" ) -type h2cTransportWrapper struct { - *http2.Transport -} - -func (t *h2cTransportWrapper) RoundTrip(req *http.Request) (*http.Response, error) { - req.URL.Scheme = "http" - return t.Transport.RoundTrip(req) -} - // SpiffeX509Source allows to retrieve a x509 SVID and bundle. type SpiffeX509Source interface { x509svid.Source x509bundle.Source } -// NewRoundTripperManager creates a new RoundTripperManager. -func NewRoundTripperManager(spiffeX509Source SpiffeX509Source) *RoundTripperManager { - return &RoundTripperManager{ - roundTrippers: make(map[string]http.RoundTripper), - configs: make(map[string]*dynamic.ServersTransport), - spiffeX509Source: spiffeX509Source, - } -} - -// RoundTripperManager handles roundtripper for the reverse proxy. -type RoundTripperManager struct { +// TransportManager handles transports for backend communication. +type TransportManager struct { rtLock sync.RWMutex roundTrippers map[string]http.RoundTripper configs map[string]*dynamic.ServersTransport + tlsConfigs map[string]*tls.Config spiffeX509Source SpiffeX509Source } -// Update updates the roundtrippers configurations. -func (r *RoundTripperManager) Update(newConfigs map[string]*dynamic.ServersTransport) { - r.rtLock.Lock() - defer r.rtLock.Unlock() +// NewTransportManager creates a new TransportManager. +func NewTransportManager(spiffeX509Source SpiffeX509Source) *TransportManager { + return &TransportManager{ + roundTrippers: make(map[string]http.RoundTripper), + configs: make(map[string]*dynamic.ServersTransport), + tlsConfigs: make(map[string]*tls.Config), + spiffeX509Source: spiffeX509Source, + } +} - for configName, config := range r.configs { +// Update updates the transport configurations. +func (t *TransportManager) Update(newConfigs map[string]*dynamic.ServersTransport) { + t.rtLock.Lock() + defer t.rtLock.Unlock() + + for configName, config := range t.configs { newConfig, ok := newConfigs[configName] if !ok { - delete(r.configs, configName) - delete(r.roundTrippers, configName) + delete(t.configs, configName) + delete(t.roundTrippers, configName) + delete(t.tlsConfigs, configName) continue } @@ -76,50 +69,133 @@ func (r *RoundTripperManager) Update(newConfigs map[string]*dynamic.ServersTrans } var err error - r.roundTrippers[configName], err = r.createRoundTripper(newConfig) + + var tlsConfig *tls.Config + if tlsConfig, err = t.createTLSConfig(newConfig); err != nil { + log.Error().Err(err).Msgf("Could not configure HTTP Transport %s TLS configuration, fallback on default TLS config", configName) + } + t.tlsConfigs[configName] = tlsConfig + + t.roundTrippers[configName], err = t.createRoundTripper(newConfig, tlsConfig) if err != nil { log.Error().Err(err).Msgf("Could not configure HTTP Transport %s, fallback on default transport", configName) - r.roundTrippers[configName] = http.DefaultTransport + t.roundTrippers[configName] = http.DefaultTransport } } for newConfigName, newConfig := range newConfigs { - if _, ok := r.configs[newConfigName]; ok { + if _, ok := t.configs[newConfigName]; ok { continue } var err error - r.roundTrippers[newConfigName], err = r.createRoundTripper(newConfig) + + var tlsConfig *tls.Config + if tlsConfig, err = t.createTLSConfig(newConfig); err != nil { + log.Error().Err(err).Msgf("Could not configure HTTP Transport %s TLS configuration, fallback on default TLS config", newConfigName) + } + t.tlsConfigs[newConfigName] = tlsConfig + + t.roundTrippers[newConfigName], err = t.createRoundTripper(newConfig, tlsConfig) if err != nil { log.Error().Err(err).Msgf("Could not configure HTTP Transport %s, fallback on default transport", newConfigName) - r.roundTrippers[newConfigName] = http.DefaultTransport + t.roundTrippers[newConfigName] = http.DefaultTransport } } - r.configs = newConfigs + t.configs = newConfigs } -// Get gets a roundtripper by name. -func (r *RoundTripperManager) Get(name string) (http.RoundTripper, error) { +// GetRoundTripper gets a roundtripper corresponding to the given transport name. +func (t *TransportManager) GetRoundTripper(name string) (http.RoundTripper, error) { if len(name) == 0 { name = "default@internal" } - r.rtLock.RLock() - defer r.rtLock.RUnlock() + t.rtLock.RLock() + defer t.rtLock.RUnlock() - if rt, ok := r.roundTrippers[name]; ok { + if rt, ok := t.roundTrippers[name]; ok { return rt, nil } return nil, fmt.Errorf("servers transport not found %s", name) } +// Get gets transport by name. +func (t *TransportManager) Get(name string) (*dynamic.ServersTransport, error) { + if len(name) == 0 { + name = "default@internal" + } + + t.rtLock.RLock() + defer t.rtLock.RUnlock() + + if rt, ok := t.configs[name]; ok { + return rt, nil + } + + return nil, fmt.Errorf("servers transport not found %s", name) +} + +// GetTLSConfig gets a TLS config corresponding to the given transport name. +func (t *TransportManager) GetTLSConfig(name string) (*tls.Config, error) { + if len(name) == 0 { + name = "default@internal" + } + + t.rtLock.RLock() + defer t.rtLock.RUnlock() + + if rt, ok := t.tlsConfigs[name]; ok { + return rt, nil + } + + return nil, fmt.Errorf("tls config not found %s", name) +} + +func (t *TransportManager) createTLSConfig(cfg *dynamic.ServersTransport) (*tls.Config, error) { + var config *tls.Config + if cfg.Spiffe != nil { + if t.spiffeX509Source == nil { + return nil, errors.New("SPIFFE is enabled for this transport, but not configured") + } + + spiffeAuthorizer, err := buildSpiffeAuthorizer(cfg.Spiffe) + if err != nil { + return nil, fmt.Errorf("unable to build SPIFFE authorizer: %w", err) + } + + config = tlsconfig.MTLSClientConfig(t.spiffeX509Source, t.spiffeX509Source, spiffeAuthorizer) + } + + if cfg.InsecureSkipVerify || len(cfg.RootCAs) > 0 || len(cfg.ServerName) > 0 || len(cfg.Certificates) > 0 || cfg.PeerCertURI != "" { + if config != nil { + return nil, errors.New("TLS and SPIFFE configuration cannot be defined at the same time") + } + + config = &tls.Config{ + ServerName: cfg.ServerName, + InsecureSkipVerify: cfg.InsecureSkipVerify, + RootCAs: createRootCACertPool(cfg.RootCAs), + Certificates: cfg.Certificates.GetCertificates(), + } + + if cfg.PeerCertURI != "" { + config.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error { + return traefiktls.VerifyPeerCertificate(cfg.PeerCertURI, config, rawCerts) + } + } + } + + return config, nil +} + // createRoundTripper creates an http.RoundTripper configured with the Transport configuration settings. // For the settings that can't be configured in Traefik it uses the default http.Transport settings. // An exception to this is the MaxIdleConns setting as we only provide the option MaxIdleConnsPerHost in Traefik at this point in time. // Setting this value to the default of 100 could lead to confusing behavior and backwards compatibility issues. -func (r *RoundTripperManager) createRoundTripper(cfg *dynamic.ServersTransport) (http.RoundTripper, error) { +func (t *TransportManager) createRoundTripper(cfg *dynamic.ServersTransport, tlsConfig *tls.Config) (http.RoundTripper, error) { if cfg == nil { return nil, errors.New("no transport configuration given") } @@ -142,6 +218,7 @@ func (r *RoundTripperManager) createRoundTripper(cfg *dynamic.ServersTransport) ExpectContinueTimeout: 1 * time.Second, ReadBufferSize: 64 * 1024, WriteBufferSize: 64 * 1024, + TLSClientConfig: tlsConfig, } if cfg.ForwardingTimeouts != nil { @@ -149,41 +226,9 @@ func (r *RoundTripperManager) createRoundTripper(cfg *dynamic.ServersTransport) transport.IdleConnTimeout = time.Duration(cfg.ForwardingTimeouts.IdleConnTimeout) } - if cfg.Spiffe != nil { - if r.spiffeX509Source == nil { - return nil, errors.New("SPIFFE is enabled for this transport, but not configured") - } - - spiffeAuthorizer, err := buildSpiffeAuthorizer(cfg.Spiffe) - if err != nil { - return nil, fmt.Errorf("unable to build SPIFFE authorizer: %w", err) - } - - transport.TLSClientConfig = tlsconfig.MTLSClientConfig(r.spiffeX509Source, r.spiffeX509Source, spiffeAuthorizer) - } - - if cfg.InsecureSkipVerify || len(cfg.RootCAs) > 0 || len(cfg.ServerName) > 0 || len(cfg.Certificates) > 0 || cfg.PeerCertURI != "" { - if transport.TLSClientConfig != nil { - return nil, errors.New("TLS and SPIFFE configuration cannot be defined at the same time") - } - - transport.TLSClientConfig = &tls.Config{ - ServerName: cfg.ServerName, - InsecureSkipVerify: cfg.InsecureSkipVerify, - RootCAs: createRootCACertPool(cfg.RootCAs), - Certificates: cfg.Certificates.GetCertificates(), - } - - if cfg.PeerCertURI != "" { - transport.TLSClientConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error { - return traefiktls.VerifyPeerCertificate(cfg.PeerCertURI, transport.TLSClientConfig, rawCerts) - } - } - } - // Return directly HTTP/1.1 transport when HTTP/2 is disabled if cfg.DisableHTTP2 { - return &KerberosRoundTripper{ + return &kerberosRoundTripper{ OriginalRoundTripper: transport, new: func() http.RoundTripper { return transport.Clone() @@ -195,7 +240,7 @@ func (r *RoundTripperManager) createRoundTripper(cfg *dynamic.ServersTransport) if err != nil { return nil, err } - return &KerberosRoundTripper{ + return &kerberosRoundTripper{ OriginalRoundTripper: rt, new: func() http.RoundTripper { return rt.Clone() @@ -203,11 +248,6 @@ func (r *RoundTripperManager) createRoundTripper(cfg *dynamic.ServersTransport) }, nil } -type KerberosRoundTripper struct { - new func() http.RoundTripper - OriginalRoundTripper http.RoundTripper -} - type stickyRoundTripper struct { RoundTripper http.RoundTripper } @@ -220,7 +260,12 @@ func AddTransportOnContext(ctx context.Context) context.Context { return context.WithValue(ctx, transportKey, &stickyRoundTripper{}) } -func (k *KerberosRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) { +type kerberosRoundTripper struct { + new func() http.RoundTripper + OriginalRoundTripper http.RoundTripper +} + +func (k *kerberosRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) { value, ok := request.Context().Value(transportKey).(*stickyRoundTripper) if !ok { return k.OriginalRoundTripper.RoundTrip(request) diff --git a/pkg/server/service/roundtripper_test.go b/pkg/server/service/transport_test.go similarity index 96% rename from pkg/server/service/roundtripper_test.go rename to pkg/server/service/transport_test.go index 0b1e80719..952c0ec7e 100644 --- a/pkg/server/service/roundtripper_test.go +++ b/pkg/server/service/transport_test.go @@ -141,7 +141,7 @@ func TestKeepConnectionWhenSameConfiguration(t *testing.T) { srv.TLS = &tls.Config{Certificates: []tls.Certificate{cert}} srv.StartTLS() - rtManager := NewRoundTripperManager(nil) + transportManager := NewTransportManager(nil) dynamicConf := map[string]*dynamic.ServersTransport{ "test": { @@ -151,9 +151,9 @@ func TestKeepConnectionWhenSameConfiguration(t *testing.T) { } for range 10 { - rtManager.Update(dynamicConf) + transportManager.Update(dynamicConf) - tr, err := rtManager.Get("test") + tr, err := transportManager.GetRoundTripper("test") require.NoError(t, err) client := http.Client{Transport: tr} @@ -173,9 +173,9 @@ func TestKeepConnectionWhenSameConfiguration(t *testing.T) { }, } - rtManager.Update(dynamicConf) + transportManager.Update(dynamicConf) - tr, err := rtManager.Get("test") + tr, err := transportManager.GetRoundTripper("test") require.NoError(t, err) client := http.Client{Transport: tr} @@ -209,7 +209,7 @@ func TestMTLS(t *testing.T) { } srv.StartTLS() - rtManager := NewRoundTripperManager(nil) + transportManager := NewTransportManager(nil) dynamicConf := map[string]*dynamic.ServersTransport{ "test": { @@ -227,9 +227,9 @@ func TestMTLS(t *testing.T) { }, } - rtManager.Update(dynamicConf) + transportManager.Update(dynamicConf) - tr, err := rtManager.Get("test") + tr, err := transportManager.GetRoundTripper("test") require.NoError(t, err) client := http.Client{Transport: tr} @@ -348,7 +348,7 @@ func TestSpiffeMTLS(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - rtManager := NewRoundTripperManager(test.clientSource) + transportManager := NewTransportManager(test.clientSource) dynamicConf := map[string]*dynamic.ServersTransport{ "test": { @@ -356,9 +356,9 @@ func TestSpiffeMTLS(t *testing.T) { }, } - rtManager.Update(dynamicConf) + transportManager.Update(dynamicConf) - tr, err := rtManager.Get("test") + tr, err := transportManager.GetRoundTripper("test") require.NoError(t, err) client := http.Client{Transport: tr} @@ -415,7 +415,7 @@ func TestDisableHTTP2(t *testing.T) { srv.EnableHTTP2 = test.serverHTTP2 srv.StartTLS() - rtManager := NewRoundTripperManager(nil) + transportManager := NewTransportManager(nil) dynamicConf := map[string]*dynamic.ServersTransport{ "test": { @@ -424,9 +424,9 @@ func TestDisableHTTP2(t *testing.T) { }, } - rtManager.Update(dynamicConf) + transportManager.Update(dynamicConf) - tr, err := rtManager.Get("test") + tr, err := transportManager.GetRoundTripper("test") require.NoError(t, err) client := http.Client{Transport: tr} @@ -593,7 +593,7 @@ func TestKerberosRoundTripper(t *testing.T) { origCount := 0 dedicatedCount := 0 - rt := KerberosRoundTripper{ + rt := kerberosRoundTripper{ new: func() http.RoundTripper { return roundTripperFn(func(req *http.Request) (*http.Response, error) { dedicatedCount++