Introduce a fast proxy mode to improve HTTP/1.1 performances with backends

Co-authored-by: Romain <rtribotte@users.noreply.github.com>
Co-authored-by: Julien Salleyron <julien.salleyron@gmail.com>
This commit is contained in:
Kevin Pollet 2024-09-26 11:00:05 +02:00 committed by GitHub
parent a6db1cac37
commit f8a78b3b25
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
39 changed files with 3173 additions and 378 deletions

View file

@ -229,7 +229,7 @@ issues:
text: 'struct-tag: unknown option ''inline'' in JSON tag' text: 'struct-tag: unknown option ''inline'' in JSON tag'
linters: linters:
- revive - revive
- path: pkg/server/service/bufferpool.go - path: pkg/proxy/httputil/bufferpool.go
text: 'SA6002: argument should be pointer-like to avoid allocations' text: 'SA6002: argument should be pointer-like to avoid allocations'
- path: pkg/server/middleware/middlewares.go - path: pkg/server/middleware/middlewares.go
text: "Function 'buildConstructor' has too many statements" text: "Function 'buildConstructor' has too many statements"

View file

@ -37,6 +37,8 @@ import (
"github.com/traefik/traefik/v3/pkg/provider/aggregator" "github.com/traefik/traefik/v3/pkg/provider/aggregator"
"github.com/traefik/traefik/v3/pkg/provider/tailscale" "github.com/traefik/traefik/v3/pkg/provider/tailscale"
"github.com/traefik/traefik/v3/pkg/provider/traefik" "github.com/traefik/traefik/v3/pkg/provider/traefik"
"github.com/traefik/traefik/v3/pkg/proxy"
"github.com/traefik/traefik/v3/pkg/proxy/httputil"
"github.com/traefik/traefik/v3/pkg/safe" "github.com/traefik/traefik/v3/pkg/safe"
"github.com/traefik/traefik/v3/pkg/server" "github.com/traefik/traefik/v3/pkg/server"
"github.com/traefik/traefik/v3/pkg/server/middleware" "github.com/traefik/traefik/v3/pkg/server/middleware"
@ -281,10 +283,16 @@ func setupServer(staticConfiguration *static.Configuration) (*server.Server, err
log.Info().Msg("Successfully obtained SPIFFE SVID.") log.Info().Msg("Successfully obtained SPIFFE SVID.")
} }
roundTripperManager := service.NewRoundTripperManager(spiffeX509Source) transportManager := service.NewTransportManager(spiffeX509Source)
var proxyBuilder service.ProxyBuilder = httputil.NewProxyBuilder(transportManager, semConvMetricRegistry)
if staticConfiguration.Experimental != nil && staticConfiguration.Experimental.FastProxy != nil {
proxyBuilder = proxy.NewSmartBuilder(transportManager, proxyBuilder, *staticConfiguration.Experimental.FastProxy)
}
dialerManager := tcp.NewDialerManager(spiffeX509Source) dialerManager := tcp.NewDialerManager(spiffeX509Source)
acmeHTTPHandler := getHTTPChallengeHandler(acmeProviders, httpChallengeProvider) acmeHTTPHandler := getHTTPChallengeHandler(acmeProviders, httpChallengeProvider)
managerFactory := service.NewManagerFactory(*staticConfiguration, routinesPool, observabilityMgr, roundTripperManager, acmeHTTPHandler) managerFactory := service.NewManagerFactory(*staticConfiguration, routinesPool, observabilityMgr, transportManager, proxyBuilder, acmeHTTPHandler)
// Router factory // Router factory
@ -318,7 +326,8 @@ func setupServer(staticConfiguration *static.Configuration) (*server.Server, err
// Server Transports // Server Transports
watcher.AddListener(func(conf dynamic.Configuration) { watcher.AddListener(func(conf dynamic.Configuration) {
roundTripperManager.Update(conf.HTTP.ServersTransports) transportManager.Update(conf.HTTP.ServersTransports)
proxyBuilder.Update(conf.HTTP.ServersTransports)
dialerManager.Update(conf.TCP.ServersTransports) dialerManager.Update(conf.TCP.ServersTransports)
}) })

View file

@ -228,6 +228,12 @@ WriteTimeout is the maximum duration before timing out writes of the response. I
`--entrypoints.<name>.udp.timeout`: `--entrypoints.<name>.udp.timeout`:
Timeout defines how long to wait on an idle session before releasing the related resources. (Default: ```3```) Timeout defines how long to wait on an idle session before releasing the related resources. (Default: ```3```)
`--experimental.fastproxy`:
Enable the FastProxy implementation. (Default: ```false```)
`--experimental.fastproxy.debug`:
Enable debug mode for the FastProxy implementation. (Default: ```false```)
`--experimental.kubernetesgateway`: `--experimental.kubernetesgateway`:
(Deprecated) Allow the Kubernetes gateway api provider usage. (Default: ```false```) (Deprecated) Allow the Kubernetes gateway api provider usage. (Default: ```false```)

View file

@ -228,6 +228,12 @@ WriteTimeout is the maximum duration before timing out writes of the response. I
`TRAEFIK_ENTRYPOINTS_<NAME>_UDP_TIMEOUT`: `TRAEFIK_ENTRYPOINTS_<NAME>_UDP_TIMEOUT`:
Timeout defines how long to wait on an idle session before releasing the related resources. (Default: ```3```) Timeout defines how long to wait on an idle session before releasing the related resources. (Default: ```3```)
`TRAEFIK_EXPERIMENTAL_FASTPROXY`:
Enable the FastProxy implementation. (Default: ```false```)
`TRAEFIK_EXPERIMENTAL_FASTPROXY_DEBUG`:
Enable debug mode for the FastProxy implementation. (Default: ```false```)
`TRAEFIK_EXPERIMENTAL_KUBERNETESGATEWAY`: `TRAEFIK_EXPERIMENTAL_KUBERNETESGATEWAY`:
(Deprecated) Allow the Kubernetes gateway api provider usage. (Default: ```false```) (Deprecated) Allow the Kubernetes gateway api provider usage. (Default: ```false```)

View file

@ -509,6 +509,8 @@
[experimental.localPlugins.LocalDescriptor1.settings] [experimental.localPlugins.LocalDescriptor1.settings]
envs = ["foobar", "foobar"] envs = ["foobar", "foobar"]
mounts = ["foobar", "foobar"] mounts = ["foobar", "foobar"]
[experimental.fastProxy]
debug = true
[core] [core]
defaultRuleSyntax = "foobar" defaultRuleSyntax = "foobar"

View file

@ -572,6 +572,8 @@ experimental:
mounts: mounts:
- foobar - foobar
- foobar - foobar
fastProxy:
debug: true
kubernetesGateway: true kubernetesGateway: true
core: core:
defaultRuleSyntax: foobar defaultRuleSyntax: foobar

View file

@ -0,0 +1,41 @@
---
title: "Traefik FastProxy Experimental Configuration"
description: "This section of the Traefik Proxy documentation explains how to use the new FastProxy option."
---
# Traefik FastProxy Experimental Configuration
## Overview
This guide provides instructions on how to configure and use the new experimental `fastProxy` static configuration option in Traefik.
The `fastProxy` option introduces a high-performance reverse proxy designed to enhance the performance of routing.
!!! info "Limitations"
Please note that the new fast proxy implementation does not work with HTTP/2.
This means that when a H2C or HTTPS request with [HTTP2 enabled](../routing/services/index.md#disablehttp2) is sent to a backend, the fallback proxy is the regular one.
Additionnaly, observability features like tracing and OTEL semconv metrics are not supported for the moment.
!!! warning "Experimental"
The `fastProxy` option is currently experimental and subject to change in future releases.
Use with caution in production environments.
### Enabling FastProxy
The fastProxy option is a static configuration parameter.
To enable it, you need to configure it in your Traefik static configuration
```yaml tab="File (YAML)"
experimental:
fastProxy: {}
```
```toml tab="File (TOML)"
[experimental.fastProxy]
```
```bash tab="CLI"
--experimental.fastProxy
```

View file

@ -163,6 +163,7 @@ nav:
- 'Overview': 'observability/tracing/overview.md' - 'Overview': 'observability/tracing/overview.md'
- 'OpenTelemetry': 'observability/tracing/opentelemetry.md' - 'OpenTelemetry': 'observability/tracing/opentelemetry.md'
- 'User Guides': - 'User Guides':
- 'FastProxy': 'user-guides/fastproxy.md'
- 'Kubernetes and Let''s Encrypt': 'user-guides/crd-acme/index.md' - 'Kubernetes and Let''s Encrypt': 'user-guides/crd-acme/index.md'
- 'gRPC Examples': 'user-guides/grpc.md' - 'gRPC Examples': 'user-guides/grpc.md'
- 'Docker': - 'Docker':

8
go.mod
View file

@ -6,7 +6,7 @@ require (
github.com/BurntSushi/toml v1.4.0 github.com/BurntSushi/toml v1.4.0
github.com/Masterminds/sprig/v3 v3.2.3 github.com/Masterminds/sprig/v3 v3.2.3
github.com/abbot/go-http-auth v0.0.0-00010101000000-000000000000 // No tag on the repo. github.com/abbot/go-http-auth v0.0.0-00010101000000-000000000000 // No tag on the repo.
github.com/andybalholm/brotli v1.0.6 github.com/andybalholm/brotli v1.1.0
github.com/aws/aws-sdk-go v1.44.327 github.com/aws/aws-sdk-go v1.44.327
github.com/cenkalti/backoff/v4 v4.3.0 github.com/cenkalti/backoff/v4 v4.3.0
github.com/containous/alice v0.0.0-20181107144136-d83ebdd94cbd // No tag on the repo. github.com/containous/alice v0.0.0-20181107144136-d83ebdd94cbd // No tag on the repo.
@ -102,6 +102,11 @@ require (
sigs.k8s.io/yaml v1.4.0 sigs.k8s.io/yaml v1.4.0
) )
require (
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5
github.com/valyala/fasthttp v1.55.0
)
require ( require (
cloud.google.com/go/compute/metadata v0.3.0 // indirect cloud.google.com/go/compute/metadata v0.3.0 // indirect
dario.cat/mergo v1.0.0 // indirect dario.cat/mergo v1.0.0 // indirect
@ -315,6 +320,7 @@ require (
github.com/tklauser/numcpus v0.6.1 // indirect github.com/tklauser/numcpus v0.6.1 // indirect
github.com/transip/gotransip/v6 v6.23.0 // indirect github.com/transip/gotransip/v6 v6.23.0 // indirect
github.com/ultradns/ultradns-go-sdk v1.6.1-20231103022937-8589b6a // indirect github.com/ultradns/ultradns-go-sdk v1.6.1-20231103022937-8589b6a // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/vinyldns/go-vinyldns v0.9.16 // indirect github.com/vinyldns/go-vinyldns v0.9.16 // indirect
github.com/vultr/govultr/v3 v3.9.0 // indirect github.com/vultr/govultr/v3 v3.9.0 // indirect
github.com/x448/float16 v0.8.4 // indirect github.com/x448/float16 v0.8.4 // indirect

7
go.sum
View file

@ -98,8 +98,8 @@ github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRF
github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
github.com/aliyun/alibaba-cloud-sdk-go v1.62.712 h1:lM7JnA9dEdDFH9XOgRNQMDTQnOjlLkDTNA7c0aWTQ30= github.com/aliyun/alibaba-cloud-sdk-go v1.62.712 h1:lM7JnA9dEdDFH9XOgRNQMDTQnOjlLkDTNA7c0aWTQ30=
github.com/aliyun/alibaba-cloud-sdk-go v1.62.712/go.mod h1:SOSDHfe1kX91v3W5QiBsWSLqeLxImobbMX1mxrFHsVQ= github.com/aliyun/alibaba-cloud-sdk-go v1.62.712/go.mod h1:SOSDHfe1kX91v3W5QiBsWSLqeLxImobbMX1mxrFHsVQ=
github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI= github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M=
github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY=
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o=
github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY=
@ -1041,7 +1041,10 @@ github.com/unrolled/secure v1.0.9 h1:BWRuEb1vDrBFFDdbCnKkof3gZ35I/bnHGyt0LB0TNyQ
github.com/unrolled/secure v1.0.9/go.mod h1:fO+mEan+FLB0CdEnHf6Q4ZZVNqG+5fuLFnP8p0BXDPI= github.com/unrolled/secure v1.0.9/go.mod h1:fO+mEan+FLB0CdEnHf6Q4ZZVNqG+5fuLFnP8p0BXDPI=
github.com/urfave/negroni v1.0.0 h1:kIimOitoypq34K7TG7DUaJ9kq/N4Ofuwi1sjz0KipXc= github.com/urfave/negroni v1.0.0 h1:kIimOitoypq34K7TG7DUaJ9kq/N4Ofuwi1sjz0KipXc=
github.com/urfave/negroni v1.0.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4= github.com/urfave/negroni v1.0.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.55.0 h1:Zkefzgt6a7+bVKHnu/YaYSOPfNYNisSVBo/unVCf8k8=
github.com/valyala/fasthttp v1.55.0/go.mod h1:NkY9JtkrpPKmgwV3HTaS2HWaJss9RSIsRVfcxxoHiOM=
github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8= github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8=
github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
github.com/vinyldns/go-vinyldns v0.9.16 h1:GZJStDkcCk1F1AcRc64LuuMh+ENL8pHA0CVd4ulRMcQ= github.com/vinyldns/go-vinyldns v0.9.16 h1:GZJStDkcCk1F1AcRc64LuuMh+ENL8pHA0CVd4ulRMcQ=

View file

@ -0,0 +1,35 @@
[global]
checkNewVersion = false
sendAnonymousUsage = false
[log]
level = "DEBUG"
noColor = true
[entryPoints]
[entryPoints.web]
address = ":8000"
[api]
insecure = true
[providers.file]
filename = "{{ .SelfFilename }}"
[experimental]
[experimental.fastProxy]
debug = true
## dynamic configuration ##
[http.routers]
[http.routers.router1]
entrypoints = ["web"]
service = "service1"
rule = "PathPrefix(`/`)"
[http.services]
[http.services.service1]
[http.services.service1.loadBalancer]
[[http.services.service1.loadBalancer.servers]]
url = "{{ .Server }}"

View file

@ -65,6 +65,32 @@ func (s *SimpleSuite) TestSimpleDefaultConfig() {
require.NoError(s.T(), err) require.NoError(s.T(), err)
} }
func (s *SimpleSuite) TestSimpleFastProxy() {
var callCount int
srv1 := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
assert.Contains(s.T(), req.Header, "X-Traefik-Fast-Proxy")
callCount++
}))
defer srv1.Close()
file := s.adaptFile("fixtures/simple_fastproxy.toml", struct {
Server string
}{
Server: srv1.URL,
})
s.traefikCmd(withConfigFile(file), "--log.level=DEBUG")
// wait for traefik
err := try.GetRequest("http://127.0.0.1:8080/api/rawdata", 10*time.Second, try.BodyContains("127.0.0.1"))
require.NoError(s.T(), err)
err = try.GetRequest("http://127.0.0.1:8000/", time.Second)
require.NoError(s.T(), err)
assert.GreaterOrEqual(s.T(), 1, callCount)
}
func (s *SimpleSuite) TestWithWebConfig() { func (s *SimpleSuite) TestWithWebConfig() {
s.cmdTraefik(withConfigFile("fixtures/simple_web.toml")) s.cmdTraefik(withConfigFile("fixtures/simple_web.toml"))

View file

@ -7,6 +7,13 @@ type Experimental struct {
Plugins map[string]plugins.Descriptor `description:"Plugins configuration." json:"plugins,omitempty" toml:"plugins,omitempty" yaml:"plugins,omitempty" export:"true"` Plugins map[string]plugins.Descriptor `description:"Plugins configuration." json:"plugins,omitempty" toml:"plugins,omitempty" yaml:"plugins,omitempty" export:"true"`
LocalPlugins map[string]plugins.LocalDescriptor `description:"Local plugins configuration." json:"localPlugins,omitempty" toml:"localPlugins,omitempty" yaml:"localPlugins,omitempty" export:"true"` LocalPlugins map[string]plugins.LocalDescriptor `description:"Local plugins configuration." json:"localPlugins,omitempty" toml:"localPlugins,omitempty" yaml:"localPlugins,omitempty" export:"true"`
FastProxy *FastProxyConfig `description:"Enable the FastProxy implementation." json:"fastProxy,omitempty" toml:"fastProxy,omitempty" yaml:"fastProxy,omitempty" label:"allowEmpty" file:"allowEmpty" export:"true"`
// Deprecated: KubernetesGateway provider is not an experimental feature starting with v3.1. Please remove its usage from the static configuration. // Deprecated: KubernetesGateway provider is not an experimental feature starting with v3.1. Please remove its usage from the static configuration.
KubernetesGateway bool `description:"(Deprecated) Allow the Kubernetes gateway api provider usage." json:"kubernetesGateway,omitempty" toml:"kubernetesGateway,omitempty" yaml:"kubernetesGateway,omitempty" export:"true"` KubernetesGateway bool `description:"(Deprecated) Allow the Kubernetes gateway api provider usage." json:"kubernetesGateway,omitempty" toml:"kubernetesGateway,omitempty" yaml:"kubernetesGateway,omitempty" export:"true"`
} }
// FastProxyConfig holds the FastProxy configuration.
type FastProxyConfig struct {
Debug bool `description:"Enable debug mode for the FastProxy implementation." json:"debug,omitempty" toml:"debug,omitempty" yaml:"debug,omitempty" export:"true"`
}

129
pkg/proxy/fast/builder.go Normal file
View file

@ -0,0 +1,129 @@
package fast
import (
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
"reflect"
"time"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
"github.com/traefik/traefik/v3/pkg/config/static"
)
// TransportManager manages transport used for backend communications.
type TransportManager interface {
Get(name string) (*dynamic.ServersTransport, error)
GetTLSConfig(name string) (*tls.Config, error)
}
// ProxyBuilder handles the connection pools for the FastProxy proxies.
type ProxyBuilder struct {
debug bool
transportManager TransportManager
// lock isn't needed because ProxyBuilder is not called concurrently.
pools map[string]map[string]*connPool
proxy func(*http.Request) (*url.URL, error)
// not goroutine safe.
configs map[string]*dynamic.ServersTransport
}
// NewProxyBuilder creates a new ProxyBuilder.
func NewProxyBuilder(transportManager TransportManager, config static.FastProxyConfig) *ProxyBuilder {
return &ProxyBuilder{
debug: config.Debug,
transportManager: transportManager,
pools: make(map[string]map[string]*connPool),
proxy: http.ProxyFromEnvironment,
configs: make(map[string]*dynamic.ServersTransport),
}
}
// Update updates all the round-tripper corresponding to the given configs.
// This method must not be used concurrently.
func (r *ProxyBuilder) Update(newConfigs map[string]*dynamic.ServersTransport) {
for configName := range r.configs {
if _, ok := newConfigs[configName]; !ok {
for _, c := range r.pools[configName] {
c.Close()
}
delete(r.pools, configName)
}
}
for newConfigName, newConfig := range newConfigs {
if !reflect.DeepEqual(newConfig, r.configs[newConfigName]) {
for _, c := range r.pools[newConfigName] {
c.Close()
}
delete(r.pools, newConfigName)
}
}
r.configs = newConfigs
}
// Build builds a new ReverseProxy with the given configuration.
func (r *ProxyBuilder) Build(cfgName string, targetURL *url.URL, passHostHeader bool) (http.Handler, error) {
proxyURL, err := r.proxy(&http.Request{URL: targetURL})
if err != nil {
return nil, fmt.Errorf("getting proxy: %w", err)
}
cfg, err := r.transportManager.Get(cfgName)
if err != nil {
return nil, fmt.Errorf("getting ServersTransport: %w", err)
}
var responseHeaderTimeout time.Duration
if cfg.ForwardingTimeouts != nil {
responseHeaderTimeout = time.Duration(cfg.ForwardingTimeouts.ResponseHeaderTimeout)
}
tlsConfig, err := r.transportManager.GetTLSConfig(cfgName)
if err != nil {
return nil, fmt.Errorf("getting TLS config: %w", err)
}
pool := r.getPool(cfgName, cfg, tlsConfig, targetURL, proxyURL)
return NewReverseProxy(targetURL, proxyURL, r.debug, passHostHeader, responseHeaderTimeout, pool)
}
func (r *ProxyBuilder) getPool(cfgName string, config *dynamic.ServersTransport, tlsConfig *tls.Config, targetURL *url.URL, proxyURL *url.URL) *connPool {
pool, ok := r.pools[cfgName]
if !ok {
pool = make(map[string]*connPool)
r.pools[cfgName] = pool
}
if connPool, ok := pool[targetURL.String()]; ok {
return connPool
}
idleConnTimeout := 90 * time.Second
dialTimeout := 30 * time.Second
if config.ForwardingTimeouts != nil {
idleConnTimeout = time.Duration(config.ForwardingTimeouts.IdleConnTimeout)
dialTimeout = time.Duration(config.ForwardingTimeouts.DialTimeout)
}
proxyDialer := newDialer(dialerConfig{
DialKeepAlive: 0,
DialTimeout: dialTimeout,
HTTP: true,
TLS: targetURL.Scheme == "https",
ProxyURL: proxyURL,
}, tlsConfig)
connPool := newConnPool(config.MaxIdleConnsPerHost, idleConnTimeout, func() (net.Conn, error) {
return proxyDialer.Dial("tcp", addrFromURL(targetURL))
})
r.pools[cfgName][targetURL.String()] = connPool
return connPool
}

163
pkg/proxy/fast/connpool.go Normal file
View file

@ -0,0 +1,163 @@
package fast
import (
"fmt"
"net"
"time"
"github.com/rs/zerolog/log"
)
// conn is an enriched net.Conn.
type conn struct {
net.Conn
idleAt time.Time // the last time it was marked as idle.
idleTimeout time.Duration
}
func (c *conn) isExpired() bool {
expTime := c.idleAt.Add(c.idleTimeout)
return c.idleTimeout > 0 && time.Now().After(expTime)
}
// connPool is a net.Conn pool implementation using channels.
type connPool struct {
dialer func() (net.Conn, error)
idleConns chan *conn
idleConnTimeout time.Duration
ticker *time.Ticker
doneCh chan struct{}
}
// newConnPool creates a new connPool.
func newConnPool(maxIdleConn int, idleConnTimeout time.Duration, dialer func() (net.Conn, error)) *connPool {
c := &connPool{
dialer: dialer,
idleConns: make(chan *conn, maxIdleConn),
idleConnTimeout: idleConnTimeout,
doneCh: make(chan struct{}),
}
if idleConnTimeout > 0 {
c.ticker = time.NewTicker(c.idleConnTimeout / 2)
go func() {
for {
select {
case <-c.ticker.C:
c.cleanIdleConns()
case <-c.doneCh:
return
}
}
}()
}
return c
}
// Close closes stop the cleanIdleConn goroutine.
func (c *connPool) Close() {
if c.idleConnTimeout > 0 {
close(c.doneCh)
c.ticker.Stop()
}
}
// AcquireConn returns an idle net.Conn from the pool.
func (c *connPool) AcquireConn() (*conn, error) {
for {
co, err := c.acquireConn()
if err != nil {
return nil, err
}
if !co.isExpired() {
return co, nil
}
// As the acquired conn is expired we can close it
// without putting it again into the pool.
if err := co.Close(); err != nil {
log.Debug().
Err(err).
Msg("Unexpected error while releasing the connection")
}
}
}
// ReleaseConn releases the given net.Conn to the pool.
func (c *connPool) ReleaseConn(co *conn) {
co.idleAt = time.Now()
c.releaseConn(co)
}
// cleanIdleConns is a routine cleaning the expired connections at a regular basis.
func (c *connPool) cleanIdleConns() {
for {
select {
case co := <-c.idleConns:
if !co.isExpired() {
c.releaseConn(co)
return
}
if err := co.Close(); err != nil {
log.Debug().
Err(err).
Msg("Unexpected error while releasing the connection")
}
default:
return
}
}
}
func (c *connPool) acquireConn() (*conn, error) {
select {
case co := <-c.idleConns:
return co, nil
default:
errCh := make(chan error, 1)
go c.askForNewConn(errCh)
select {
case co := <-c.idleConns:
return co, nil
case err := <-errCh:
return nil, err
}
}
}
func (c *connPool) releaseConn(co *conn) {
select {
case c.idleConns <- co:
// Hitting the default case means that we have reached the maximum number of idle
// connections, so we can close it.
default:
if err := co.Close(); err != nil {
log.Debug().
Err(err).
Msg("Unexpected error while releasing the connection")
}
}
}
func (c *connPool) askForNewConn(errCh chan<- error) {
co, err := c.dialer()
if err != nil {
errCh <- fmt.Errorf("create conn: %w", err)
return
}
c.releaseConn(&conn{
Conn: co,
idleAt: time.Now(),
idleTimeout: c.idleConnTimeout,
})
}

View file

@ -0,0 +1,184 @@
package fast
import (
"net"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestConnPool_ConnReuse(t *testing.T) {
testCases := []struct {
desc string
poolFn func(pool *connPool)
expected int
}{
{
desc: "One connection",
poolFn: func(pool *connPool) {
c1, _ := pool.AcquireConn()
pool.ReleaseConn(c1)
},
expected: 1,
},
{
desc: "Two connections with release",
poolFn: func(pool *connPool) {
c1, _ := pool.AcquireConn()
pool.ReleaseConn(c1)
c2, _ := pool.AcquireConn()
pool.ReleaseConn(c2)
},
expected: 1,
},
{
desc: "Two concurrent connections",
poolFn: func(pool *connPool) {
c1, _ := pool.AcquireConn()
c2, _ := pool.AcquireConn()
pool.ReleaseConn(c1)
pool.ReleaseConn(c2)
},
expected: 2,
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
var connAlloc int
dialer := func() (net.Conn, error) {
connAlloc++
return &net.TCPConn{}, nil
}
pool := newConnPool(2, 0, dialer)
test.poolFn(pool)
assert.Equal(t, test.expected, connAlloc)
})
}
}
func TestConnPool_MaxIdleConn(t *testing.T) {
testCases := []struct {
desc string
poolFn func(pool *connPool)
maxIdleConn int
expected int
}{
{
desc: "One connection",
poolFn: func(pool *connPool) {
c1, _ := pool.AcquireConn()
pool.ReleaseConn(c1)
},
maxIdleConn: 1,
expected: 1,
},
{
desc: "Multiple connections with defered release",
poolFn: func(pool *connPool) {
for range 7 {
c, _ := pool.AcquireConn()
defer pool.ReleaseConn(c)
}
},
maxIdleConn: 5,
expected: 5,
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
var keepOpenedConn int
dialer := func() (net.Conn, error) {
keepOpenedConn++
return &mockConn{closeFn: func() error {
keepOpenedConn--
return nil
}}, nil
}
pool := newConnPool(test.maxIdleConn, 0, dialer)
test.poolFn(pool)
assert.Equal(t, test.expected, keepOpenedConn)
})
}
}
func TestGC(t *testing.T) {
var isDestroyed bool
pools := map[string]*connPool{}
dialer := func() (net.Conn, error) {
c := &mockConn{closeFn: func() error {
return nil
}}
return c, nil
}
pools["test"] = newConnPool(10, 1*time.Second, dialer)
runtime.SetFinalizer(pools["test"], func(p *connPool) {
isDestroyed = true
})
c, err := pools["test"].AcquireConn()
require.NoError(t, err)
pools["test"].ReleaseConn(c)
pools["test"].Close()
delete(pools, "test")
runtime.GC()
require.True(t, isDestroyed)
}
type mockConn struct {
closeFn func() error
}
func (m *mockConn) Read(_ []byte) (n int, err error) {
panic("implement me")
}
func (m *mockConn) Write(_ []byte) (n int, err error) {
panic("implement me")
}
func (m *mockConn) Close() error {
if m.closeFn != nil {
return m.closeFn()
}
return nil
}
func (m *mockConn) LocalAddr() net.Addr {
panic("implement me")
}
func (m *mockConn) RemoteAddr() net.Addr {
panic("implement me")
}
func (m *mockConn) SetDeadline(_ time.Time) error {
panic("implement me")
}
func (m *mockConn) SetReadDeadline(_ time.Time) error {
panic("implement me")
}
func (m *mockConn) SetWriteDeadline(_ time.Time) error {
panic("implement me")
}

195
pkg/proxy/fast/dialer.go Normal file
View file

@ -0,0 +1,195 @@
package fast
import (
"bufio"
"context"
"crypto/tls"
"encoding/base64"
"errors"
"net"
"net/http"
"net/url"
"strings"
"time"
"golang.org/x/net/proxy"
)
const (
schemeHTTP = "http"
schemeHTTPS = "https"
schemeSocks5 = "socks5"
)
type dialer interface {
Dial(network, addr string) (c net.Conn, err error)
}
type dialerFunc func(network, addr string) (net.Conn, error)
func (d dialerFunc) Dial(network, addr string) (net.Conn, error) {
return d(network, addr)
}
type dialerConfig struct {
DialKeepAlive time.Duration
DialTimeout time.Duration
ProxyURL *url.URL
HTTP bool
TLS bool
}
func newDialer(cfg dialerConfig, tlsConfig *tls.Config) dialer {
if cfg.ProxyURL == nil {
return buildDialer(cfg, tlsConfig, cfg.TLS)
}
proxyDialer := buildDialer(cfg, tlsConfig, cfg.ProxyURL.Scheme == "https")
proxyAddr := addrFromURL(cfg.ProxyURL)
switch {
case cfg.ProxyURL.Scheme == schemeSocks5:
var auth *proxy.Auth
if u := cfg.ProxyURL.User; u != nil {
auth = &proxy.Auth{User: u.Username()}
auth.Password, _ = u.Password()
}
// SOCKS5 implementation do not return errors.
socksDialer, _ := proxy.SOCKS5("tcp", proxyAddr, auth, proxyDialer)
return dialerFunc(func(network, targetAddr string) (net.Conn, error) {
co, err := socksDialer.Dial("tcp", targetAddr)
if err != nil {
return nil, err
}
if cfg.TLS {
c := &tls.Config{}
if tlsConfig != nil {
c = tlsConfig.Clone()
}
if c.ServerName == "" {
host, _, _ := net.SplitHostPort(targetAddr)
c.ServerName = host
}
return tls.Client(co, c), nil
}
return co, nil
})
case cfg.HTTP && !cfg.TLS:
// Nothing to do the Proxy-Authorization header will be added by the ReverseProxy.
default:
hdr := make(http.Header)
if u := cfg.ProxyURL.User; u != nil {
username := u.Username()
password, _ := u.Password()
auth := username + ":" + password
hdr.Set("Proxy-Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
}
return dialerFunc(func(network, targetAddr string) (net.Conn, error) {
conn, err := proxyDialer.Dial("tcp", proxyAddr)
if err != nil {
return nil, err
}
connectReq := &http.Request{
Method: http.MethodConnect,
URL: &url.URL{Opaque: targetAddr},
Host: targetAddr,
Header: hdr,
}
connectCtx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel()
didReadResponse := make(chan struct{}) // closed after CONNECT write+read is done or fails
var resp *http.Response
// Write the CONNECT request & read the response.
go func() {
defer close(didReadResponse)
err = connectReq.Write(conn)
if err != nil {
return
}
// Okay to use and discard buffered reader here, because
// TLS server will not speak until spoken to.
br := bufio.NewReader(conn)
resp, err = http.ReadResponse(br, connectReq)
}()
select {
case <-connectCtx.Done():
conn.Close()
<-didReadResponse
return nil, connectCtx.Err()
case <-didReadResponse:
// resp or err now set
}
if err != nil {
conn.Close()
return nil, err
}
if resp.StatusCode != http.StatusOK {
_, statusText, ok := strings.Cut(resp.Status, " ")
conn.Close()
if !ok {
return nil, errors.New("unknown status code")
}
return nil, errors.New(statusText)
}
c := &tls.Config{}
if tlsConfig != nil {
c = tlsConfig.Clone()
}
if c.ServerName == "" {
host, _, _ := net.SplitHostPort(targetAddr)
c.ServerName = host
}
return tls.Client(conn, c), nil
})
}
return dialerFunc(func(network, addr string) (net.Conn, error) {
return proxyDialer.Dial("tcp", proxyAddr)
})
}
func buildDialer(cfg dialerConfig, tlsConfig *tls.Config, isTLS bool) dialer {
dialer := &net.Dialer{
Timeout: cfg.DialTimeout,
KeepAlive: cfg.DialKeepAlive,
}
if !isTLS {
return dialer
}
return &tls.Dialer{
NetDialer: dialer,
Config: tlsConfig,
}
}
func addrFromURL(u *url.URL) string {
addr := u.Host
if u.Port() == "" {
switch u.Scheme {
case schemeHTTP:
return addr + ":80"
case schemeHTTPS:
return addr + ":443"
}
}
return addr
}

553
pkg/proxy/fast/proxy.go Normal file
View file

@ -0,0 +1,553 @@
package fast
import (
"bufio"
"bytes"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptrace"
"net/http/httputil"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/rs/zerolog/log"
proxyhttputil "github.com/traefik/traefik/v3/pkg/proxy/httputil"
"github.com/valyala/fasthttp"
"golang.org/x/net/http/httpguts"
)
const (
bufferSize = 32 * 1024
bufioSize = 64 * 1024
)
var hopHeaders = []string{
"Connection",
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Te", // canonicalized version of "TE"
"Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
"Transfer-Encoding",
"Upgrade",
}
type pool[T any] struct {
pool sync.Pool
}
func (p *pool[T]) Get() T {
if tmp := p.pool.Get(); tmp != nil {
return tmp.(T)
}
var res T
return res
}
func (p *pool[T]) Put(x T) {
p.pool.Put(x)
}
type buffConn struct {
*bufio.Reader
net.Conn
}
func (b buffConn) Read(p []byte) (int, error) {
return b.Reader.Read(p)
}
type writeDetector struct {
net.Conn
written bool
}
func (w *writeDetector) Write(p []byte) (int, error) {
n, err := w.Conn.Write(p)
if n > 0 {
w.written = true
}
return n, err
}
type writeFlusher struct {
io.Writer
}
func (w *writeFlusher) Write(b []byte) (int, error) {
n, err := w.Writer.Write(b)
if f, ok := w.Writer.(http.Flusher); ok {
f.Flush()
}
return n, err
}
type timeoutError struct {
error
}
func (t timeoutError) Timeout() bool {
return true
}
func (t timeoutError) Temporary() bool {
return false
}
// ReverseProxy is the FastProxy reverse proxy implementation.
type ReverseProxy struct {
debug bool
connPool *connPool
bufferPool pool[[]byte]
readerPool pool[*bufio.Reader]
writerPool pool[*bufio.Writer]
limitReaderPool pool[*io.LimitedReader]
proxyAuth string
targetURL *url.URL
passHostHeader bool
responseHeaderTimeout time.Duration
}
// NewReverseProxy creates a new ReverseProxy.
func NewReverseProxy(targetURL *url.URL, proxyURL *url.URL, debug, passHostHeader bool, responseHeaderTimeout time.Duration, connPool *connPool) (*ReverseProxy, error) {
var proxyAuth string
if proxyURL != nil && proxyURL.User != nil && targetURL.Scheme == "http" {
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
proxyAuth = "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
}
return &ReverseProxy{
debug: debug,
passHostHeader: passHostHeader,
targetURL: targetURL,
proxyAuth: proxyAuth,
connPool: connPool,
responseHeaderTimeout: responseHeaderTimeout,
}, nil
}
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if req.Body != nil {
defer req.Body.Close()
}
outReq := fasthttp.AcquireRequest()
defer fasthttp.ReleaseRequest(outReq)
// This is not required as the headers are already normalized by net/http.
outReq.Header.DisableNormalizing()
for k, v := range req.Header {
for _, s := range v {
outReq.Header.Add(k, s)
}
}
removeConnectionHeaders(&outReq.Header)
for _, header := range hopHeaders {
outReq.Header.Del(header)
}
if p.proxyAuth != "" {
outReq.Header.Set("Proxy-Authorization", p.proxyAuth)
}
if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
outReq.Header.Set("Te", "trailers")
}
if p.debug {
outReq.Header.Set("X-Traefik-Fast-Proxy", "enabled")
}
reqUpType := upgradeType(req.Header)
if !isGraphic(reqUpType) {
proxyhttputil.ErrorHandler(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
return
}
if reqUpType != "" {
outReq.Header.Set("Connection", "Upgrade")
outReq.Header.Set("Upgrade", reqUpType)
if reqUpType == "websocket" {
cleanWebSocketHeaders(&outReq.Header)
}
}
u2 := new(url.URL)
*u2 = *req.URL
u2.Scheme = p.targetURL.Scheme
u2.Host = p.targetURL.Host
u := req.URL
if req.RequestURI != "" {
parsedURL, err := url.ParseRequestURI(req.RequestURI)
if err == nil {
u = parsedURL
}
}
u2.Path = u.Path
u2.RawPath = u.RawPath
u2.RawQuery = strings.ReplaceAll(u.RawQuery, ";", "&")
outReq.SetHost(u2.Host)
outReq.Header.SetHost(u2.Host)
if p.passHostHeader {
outReq.Header.SetHost(req.Host)
}
outReq.SetRequestURI(u2.RequestURI())
outReq.SetBodyStream(req.Body, int(req.ContentLength))
outReq.Header.SetMethod(req.Method)
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
// If we aren't the first proxy retain prior
// X-Forwarded-For information as a comma+space
// separated list and fold multiple headers into one.
prior, ok := req.Header["X-Forwarded-For"]
if len(prior) > 0 {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
omit := ok && prior == nil // Go Issue 38079: nil now means don't populate the header
if !omit {
outReq.Header.Set("X-Forwarded-For", clientIP)
}
}
if err := p.roundTrip(rw, req, outReq, reqUpType); err != nil {
proxyhttputil.ErrorHandler(rw, req, err)
}
}
// Note that unlike the net/http RoundTrip:
// - we are not supporting "100 Continue" response to forward them as-is to the client.
// - we are not asking for compressed response automatically. That is because this will add an extra cost when the
// client is asking for an uncompressed response, as we will have to un-compress it, and nowadays most clients are
// already asking for compressed response (allowing "passthrough" compression).
func (p *ReverseProxy) roundTrip(rw http.ResponseWriter, req *http.Request, outReq *fasthttp.Request, reqUpType string) error {
ctx := req.Context()
trace := httptrace.ContextClientTrace(ctx)
var co *conn
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
var err error
co, err = p.connPool.AcquireConn()
if err != nil {
return fmt.Errorf("acquire connection: %w", err)
}
wd := &writeDetector{Conn: co}
err = p.writeRequest(wd, outReq)
if wd.written && trace != nil && trace.WroteRequest != nil {
// WroteRequest hook is used by the tracing middleware to detect if the request has been written.
trace.WroteRequest(httptrace.WroteRequestInfo{})
}
if err == nil {
break
}
log.Ctx(ctx).Debug().Err(err).Msg("Error while writing request")
co.Close()
if wd.written && !isReplayable(req) {
return err
}
}
br := p.readerPool.Get()
if br == nil {
br = bufio.NewReaderSize(co, bufioSize)
}
defer p.readerPool.Put(br)
br.Reset(co)
res := fasthttp.AcquireResponse()
defer fasthttp.ReleaseResponse(res)
res.Header.SetNoDefaultContentType(true)
for {
var timer *time.Timer
errTimeout := atomic.Pointer[timeoutError]{}
if p.responseHeaderTimeout > 0 {
timer = time.AfterFunc(p.responseHeaderTimeout, func() {
errTimeout.Store(&timeoutError{errors.New("timeout awaiting response headers")})
co.Close()
})
}
res.Header.SetNoDefaultContentType(true)
if err := res.Header.Read(br); err != nil {
if p.responseHeaderTimeout > 0 {
if errT := errTimeout.Load(); errT != nil {
return errT
}
}
co.Close()
return err
}
if timer != nil {
timer.Stop()
}
fixPragmaCacheControl(&res.Header)
resCode := res.StatusCode()
is1xx := 100 <= resCode && resCode <= 199
// treat 101 as a terminal status, see issue 26161
is1xxNonTerminal := is1xx && resCode != http.StatusSwitchingProtocols
if is1xxNonTerminal {
removeConnectionHeaders(&res.Header)
h := rw.Header()
for _, header := range hopHeaders {
res.Header.Del(header)
}
res.Header.VisitAll(func(key, value []byte) {
rw.Header().Add(string(key), string(value))
})
rw.WriteHeader(res.StatusCode())
// Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
for k := range h {
delete(h, k)
}
res.Reset()
res.Header.Reset()
res.Header.SetNoDefaultContentType(true)
continue
}
break
}
announcedTrailers := res.Header.Peek("Trailer")
// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
if res.StatusCode() == http.StatusSwitchingProtocols {
// As the connection has been hijacked, it cannot be added back to the pool.
handleUpgradeResponse(rw, req, reqUpType, res, buffConn{Conn: co, Reader: br})
return nil
}
removeConnectionHeaders(&res.Header)
for _, header := range hopHeaders {
res.Header.Del(header)
}
if len(announcedTrailers) > 0 {
res.Header.Add("Trailer", string(announcedTrailers))
}
res.Header.VisitAll(func(key, value []byte) {
rw.Header().Add(string(key), string(value))
})
rw.WriteHeader(res.StatusCode())
// Chunked response, Content-Length is set to -1 by FastProxy when "Transfer-Encoding: chunked" header is received.
if res.Header.ContentLength() == -1 {
cbr := httputil.NewChunkedReader(br)
b := p.bufferPool.Get()
if b == nil {
b = make([]byte, bufferSize)
}
defer p.bufferPool.Put(b)
if _, err := io.CopyBuffer(&writeFlusher{rw}, cbr, b); err != nil {
co.Close()
return err
}
res.Header.Reset()
res.Header.SetNoDefaultContentType(true)
if err := res.Header.ReadTrailer(br); err != nil {
co.Close()
return err
}
if res.Header.Len() > 0 {
var announcedTrailersKey []string
if len(announcedTrailers) > 0 {
announcedTrailersKey = strings.Split(string(announcedTrailers), ",")
}
res.Header.VisitAll(func(key, value []byte) {
for _, s := range announcedTrailersKey {
if strings.EqualFold(s, strings.TrimSpace(string(key))) {
rw.Header().Add(string(key), string(value))
return
}
}
rw.Header().Add(http.TrailerPrefix+string(key), string(value))
})
}
p.connPool.ReleaseConn(co)
return nil
}
brl := p.limitReaderPool.Get()
if brl == nil {
brl = &io.LimitedReader{}
}
defer p.limitReaderPool.Put(brl)
brl.R = br
brl.N = int64(res.Header.ContentLength())
b := p.bufferPool.Get()
if b == nil {
b = make([]byte, bufferSize)
}
defer p.bufferPool.Put(b)
if _, err := io.CopyBuffer(rw, brl, b); err != nil {
co.Close()
return err
}
p.connPool.ReleaseConn(co)
return nil
}
func (p *ReverseProxy) writeRequest(co net.Conn, outReq *fasthttp.Request) error {
bw := p.writerPool.Get()
if bw == nil {
bw = bufio.NewWriterSize(co, bufioSize)
}
defer p.writerPool.Put(bw)
bw.Reset(co)
if err := outReq.Write(bw); err != nil {
return err
}
return bw.Flush()
}
// isReplayable returns whether the request is replayable.
func isReplayable(req *http.Request) bool {
if req.Body == nil || req.Body == http.NoBody {
switch req.Method {
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
return true
}
// The Idempotency-Key, while non-standard, is widely used to
// mean a POST or other request is idempotent. See
// https://golang.org/issue/19943#issuecomment-421092421
if _, ok := req.Header["Idempotency-Key"]; ok {
return true
}
if _, ok := req.Header["X-Idempotency-Key"]; ok {
return true
}
}
return false
}
// isGraphic returns whether s is ASCII and printable according to
// https://tools.ietf.org/html/rfc20#section-4.2.
func isGraphic(s string) bool {
for i := range len(s) {
if s[i] < ' ' || s[i] > '~' {
return false
}
}
return true
}
type fasthttpHeader interface {
Peek(key string) []byte
Set(key string, value string)
SetBytesV(key string, value []byte)
DelBytes(key []byte)
Del(key string)
}
// removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h.
// See RFC 7230, section 6.1.
func removeConnectionHeaders(h fasthttpHeader) {
f := h.Peek(fasthttp.HeaderConnection)
for _, sf := range bytes.Split(f, []byte{','}) {
if sf = bytes.TrimSpace(sf); len(sf) > 0 {
h.DelBytes(sf)
}
}
}
// RFC 7234, section 5.4: Should treat Pragma: no-cache like Cache-Control: no-cache.
func fixPragmaCacheControl(header fasthttpHeader) {
if pragma := header.Peek("Pragma"); bytes.Equal(pragma, []byte("no-cache")) {
if len(header.Peek("Cache-Control")) == 0 {
header.Set("Cache-Control", "no-cache")
}
}
}
// cleanWebSocketHeaders Even if the websocket RFC says that headers should be case-insensitive,
// some servers need Sec-WebSocket-Key, Sec-WebSocket-Extensions, Sec-WebSocket-Accept,
// Sec-WebSocket-Protocol and Sec-WebSocket-Version to be case-sensitive.
// https://tools.ietf.org/html/rfc6455#page-20
func cleanWebSocketHeaders(headers fasthttpHeader) {
headers.SetBytesV("Sec-WebSocket-Key", headers.Peek("Sec-Websocket-Key"))
headers.Del("Sec-Websocket-Key")
headers.SetBytesV("Sec-WebSocket-Extensions", headers.Peek("Sec-Websocket-Extensions"))
headers.Del("Sec-Websocket-Extensions")
headers.SetBytesV("Sec-WebSocket-Accept", headers.Peek("Sec-Websocket-Accept"))
headers.Del("Sec-Websocket-Accept")
headers.SetBytesV("Sec-WebSocket-Protocol", headers.Peek("Sec-Websocket-Protocol"))
headers.Del("Sec-Websocket-Protocol")
headers.SetBytesV("Sec-WebSocket-Version", headers.Peek("Sec-Websocket-Version"))
headers.DelBytes([]byte("Sec-Websocket-Version"))
}

View file

@ -0,0 +1,311 @@
package fast
import (
"crypto/tls"
"crypto/x509"
"encoding/base64"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"testing"
"time"
"github.com/armon/go-socks5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
"github.com/traefik/traefik/v3/pkg/config/static"
"github.com/traefik/traefik/v3/pkg/testhelpers"
"github.com/traefik/traefik/v3/pkg/tls/generate"
)
const (
proxyHTTP = "http"
proxyHTTPS = "https"
proxySocks5 = "socks"
)
type authCreds struct {
user string
password string
}
func TestProxyFromEnvironment(t *testing.T) {
testCases := []struct {
desc string
proxyType string
tls bool
auth *authCreds
}{
{
desc: "Proxy HTTP with HTTP Backend",
proxyType: proxyHTTP,
},
{
desc: "Proxy HTTP with HTTP backend and proxy auth",
proxyType: proxyHTTP,
tls: false,
auth: &authCreds{
user: "user",
password: "password",
},
},
{
desc: "Proxy HTTP with HTTPS backend",
proxyType: proxyHTTP,
tls: true,
},
{
desc: "Proxy HTTP with HTTPS backend and proxy auth",
proxyType: proxyHTTP,
tls: true,
auth: &authCreds{
user: "user",
password: "password",
},
},
{
desc: "Proxy HTTPS with HTTP backend",
proxyType: proxyHTTPS,
},
{
desc: "Proxy HTTPS with HTTP backend and proxy auth",
proxyType: proxyHTTPS,
tls: false,
auth: &authCreds{
user: "user",
password: "password",
},
},
{
desc: "Proxy HTTPS with HTTPS backend",
proxyType: proxyHTTPS,
tls: true,
},
{
desc: "Proxy HTTPS with HTTPS backend and proxy auth",
proxyType: proxyHTTPS,
tls: true,
auth: &authCreds{
user: "user",
password: "password",
},
},
{
desc: "Proxy Socks5 with HTTP backend",
proxyType: proxySocks5,
},
{
desc: "Proxy Socks5 with HTTP backend and proxy auth",
proxyType: proxySocks5,
auth: &authCreds{
user: "user",
password: "password",
},
},
{
desc: "Proxy Socks5 with HTTPS backend",
proxyType: proxySocks5,
tls: true,
},
{
desc: "Proxy Socks5 with HTTPS backend and proxy auth",
proxyType: proxySocks5,
tls: true,
auth: &authCreds{
user: "user",
password: "password",
},
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
backendURL, backendCert := newBackendServer(t, test.tls, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
_, _ = rw.Write([]byte("backend"))
}))
var proxyCalled bool
proxyHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
proxyCalled = true
if test.auth != nil {
proxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(test.auth.user+":"+test.auth.password))
require.Equal(t, proxyAuth, req.Header.Get("Proxy-Authorization"))
}
if req.Method != http.MethodConnect {
proxy := httputil.NewSingleHostReverseProxy(testhelpers.MustParseURL("http://" + req.Host))
proxy.ServeHTTP(rw, req)
return
}
// CONNECT method
conn, err := net.Dial("tcp", req.Host)
require.NoError(t, err)
hj, ok := rw.(http.Hijacker)
require.True(t, ok)
rw.WriteHeader(http.StatusOK)
connHj, _, err := hj.Hijack()
require.NoError(t, err)
go func() { _, _ = io.Copy(connHj, conn) }()
_, _ = io.Copy(conn, connHj)
})
var proxyURL string
var proxyCert *x509.Certificate
switch test.proxyType {
case proxySocks5:
ln, err := net.Listen("tcp", ":0")
require.NoError(t, err)
proxyURL = fmt.Sprintf("socks5://%s", ln.Addr())
go func() {
conn, err := ln.Accept()
require.NoError(t, err)
proxyCalled = true
conf := &socks5.Config{}
if test.auth != nil {
conf.Credentials = socks5.StaticCredentials{test.auth.user: test.auth.password}
}
server, err := socks5.New(conf)
require.NoError(t, err)
// We are not checking the error, because ServeConn is blocked until the client or the backend
// connection is closed which, in some cases, raises a connection reset by peer error.
_ = server.ServeConn(conn)
err = ln.Close()
require.NoError(t, err)
}()
case proxyHTTP:
proxyServer := httptest.NewServer(proxyHandler)
t.Cleanup(proxyServer.Close)
proxyURL = proxyServer.URL
case proxyHTTPS:
proxyServer := httptest.NewServer(proxyHandler)
t.Cleanup(proxyServer.Close)
proxyURL = proxyServer.URL
proxyCert = proxyServer.Certificate()
}
certPool := x509.NewCertPool()
if proxyCert != nil {
certPool.AddCert(proxyCert)
}
if backendCert != nil {
cert, err := x509.ParseCertificate(backendCert.Certificate[0])
require.NoError(t, err)
certPool.AddCert(cert)
}
builder := NewProxyBuilder(&transportManagerMock{tlsConfig: &tls.Config{RootCAs: certPool}}, static.FastProxyConfig{})
builder.proxy = func(req *http.Request) (*url.URL, error) {
u, err := url.Parse(proxyURL)
if err != nil {
return nil, err
}
if test.auth != nil {
u.User = url.UserPassword(test.auth.user, test.auth.password)
}
return u, nil
}
reverseProxy, err := builder.Build("foo", testhelpers.MustParseURL(backendURL), false)
require.NoError(t, err)
reverseProxyServer := httptest.NewServer(reverseProxy)
t.Cleanup(reverseProxyServer.Close)
client := http.Client{Timeout: 5 * time.Second}
resp, err := client.Get(reverseProxyServer.URL)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, "backend", string(body))
assert.True(t, proxyCalled)
})
}
}
func newCertificate(t *testing.T, domain string) *tls.Certificate {
t.Helper()
certPEM, keyPEM, err := generate.KeyPair(domain, time.Time{})
require.NoError(t, err)
certificate, err := tls.X509KeyPair(certPEM, keyPEM)
require.NoError(t, err)
return &certificate
}
func newBackendServer(t *testing.T, isTLS bool, handler http.Handler) (string, *tls.Certificate) {
t.Helper()
var ln net.Listener
var err error
var cert *tls.Certificate
scheme := "http"
domain := "backend.localhost"
if isTLS {
scheme = "https"
cert = newCertificate(t, domain)
ln, err = tls.Listen("tcp", ":0", &tls.Config{Certificates: []tls.Certificate{*cert}})
require.NoError(t, err)
} else {
ln, err = net.Listen("tcp", ":0")
require.NoError(t, err)
}
srv := &http.Server{Handler: handler}
go func() { _ = srv.Serve(ln) }()
t.Cleanup(func() { _ = srv.Close() })
_, port, err := net.SplitHostPort(ln.Addr().String())
require.NoError(t, err)
backendURL := fmt.Sprintf("%s://%s:%s", scheme, domain, port)
return backendURL, cert
}
type transportManagerMock struct {
tlsConfig *tls.Config
}
func (r *transportManagerMock) GetTLSConfig(_ string) (*tls.Config, error) {
return r.tlsConfig, nil
}
func (r *transportManagerMock) Get(_ string) (*dynamic.ServersTransport, error) {
return &dynamic.ServersTransport{}, nil
}

View file

@ -0,0 +1,693 @@
package fast
import (
"bufio"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
gorillawebsocket "github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/traefik/traefik/v3/pkg/testhelpers"
"golang.org/x/net/websocket"
)
func TestWebSocketTCPClose(t *testing.T) {
errChan := make(chan error, 1)
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer c.Close()
for {
_, _, err := c.ReadMessage()
if err != nil {
errChan <- err
break
}
}
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
proxyAddr := proxy.Listener.Addr().String()
_, conn, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
).open()
require.NoError(t, err)
conn.Close()
serverErr := <-errChan
var wsErr *gorillawebsocket.CloseError
require.ErrorAs(t, serverErr, &wsErr)
assert.Equal(t, 1006, wsErr.Code)
}
func TestWebSocketPingPong(t *testing.T) {
upgrader := gorillawebsocket.Upgrader{
HandshakeTimeout: 10 * time.Second,
CheckOrigin: func(*http.Request) bool {
return true
},
}
mux := http.NewServeMux()
mux.HandleFunc("/ws", func(writer http.ResponseWriter, request *http.Request) {
ws, err := upgrader.Upgrade(writer, request, nil)
require.NoError(t, err)
ws.SetPingHandler(func(appData string) error {
err = ws.WriteMessage(gorillawebsocket.PongMessage, []byte(appData+"Pong"))
require.NoError(t, err)
return nil
})
_, _, _ = ws.ReadMessage()
})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
serverAddr := proxy.Listener.Addr().String()
headers := http.Header{}
webSocketURL := "ws://" + serverAddr + "/ws"
headers.Add("Origin", webSocketURL)
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
require.NoError(t, err, "Error during Dial with response: %+v", resp)
defer conn.Close()
goodErr := fmt.Errorf("signal: %s", "Good data")
badErr := fmt.Errorf("signal: %s", "Bad data")
conn.SetPongHandler(func(data string) error {
if data == "PingPong" {
return goodErr
}
return badErr
})
err = conn.WriteControl(gorillawebsocket.PingMessage, []byte("Ping"), time.Now().Add(time.Second))
require.NoError(t, err)
_, _, err = conn.ReadMessage()
if !errors.Is(err, goodErr) {
require.NoError(t, err)
}
}
func TestWebSocketEcho(t *testing.T) {
mux := http.NewServeMux()
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
msg := make([]byte, 4)
n, err := conn.Read(msg)
require.NoError(t, err)
_, err = conn.Write(msg[:n])
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
}))
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
serverAddr := proxy.Listener.Addr().String()
headers := http.Header{}
webSocketURL := "ws://" + serverAddr + "/ws"
headers.Add("Origin", webSocketURL)
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
require.NoError(t, err, "Error during Dial with response: %+v", resp)
err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
require.NoError(t, err)
_, msg, err := conn.ReadMessage()
require.NoError(t, err)
assert.Equal(t, "OK", string(msg))
err = conn.Close()
require.NoError(t, err)
}
func TestWebSocketPassHost(t *testing.T) {
testCases := []struct {
desc string
passHost bool
expected string
}{
{
desc: "PassHost false",
passHost: false,
},
{
desc: "PassHost true",
passHost: true,
expected: "example.com",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
mux := http.NewServeMux()
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
req := conn.Request()
if test.passHost {
require.Equal(t, test.expected, req.Host)
} else {
require.NotEqual(t, test.expected, req.Host)
}
msg := make([]byte, 4)
n, err := conn.Read(msg)
require.NoError(t, err)
_, err = conn.Write(msg[:n])
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
}))
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
serverAddr := proxy.Listener.Addr().String()
headers := http.Header{}
webSocketURL := "ws://" + serverAddr + "/ws"
headers.Add("Origin", webSocketURL)
headers.Add("Host", "example.com")
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
require.NoError(t, err, "Error during Dial with response: %+v", resp)
err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
require.NoError(t, err)
_, msg, err := conn.ReadMessage()
require.NoError(t, err)
assert.Equal(t, "OK", string(msg))
err = conn.Close()
require.NoError(t, err)
})
}
}
func TestWebSocketServerWithoutCheckOrigin(t *testing.T) {
upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(r *http.Request) bool {
return true
}}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer c.Close()
for {
mt, message, err := c.ReadMessage()
if err != nil {
break
}
err = c.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("ok"),
withOrigin("http://127.0.0.2"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
func TestWebSocketRequestWithOrigin(t *testing.T) {
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer c.Close()
for {
mt, message, err := c.ReadMessage()
if err != nil {
break
}
err = c.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
_, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("echo"),
withOrigin("http://127.0.0.2"),
).send()
require.EqualError(t, err, "bad status")
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("ok"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
func TestWebSocketRequestWithQueryParams(t *testing.T) {
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
assert.Equal(t, "test", r.URL.Query().Get("query"))
for {
mt, message, err := conn.ReadMessage()
if err != nil {
break
}
err = conn.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws?query=test"),
withData("ok"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) {
mux := http.NewServeMux()
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
_ = conn.Close()
}))
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
defer srv.Close()
u := parseURI(t, srv.URL)
f, err := NewReverseProxy(u, nil, true, false, 0, newConnPool(1, 0, func() (net.Conn, error) {
return net.Dial("tcp", u.Host)
}))
require.NoError(t, err)
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req.URL = parseURI(t, srv.URL)
w.Header().Set("HEADER-KEY", "HEADER-VALUE")
f.ServeHTTP(w, req)
}))
defer proxy.Close()
serverAddr := proxy.Listener.Addr().String()
headers := http.Header{}
webSocketURL := "ws://" + serverAddr + "/ws"
headers.Add("Origin", webSocketURL)
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
require.NoError(t, err, "Error during Dial with response: %+v", err, resp)
defer conn.Close()
assert.Equal(t, "HEADER-VALUE", resp.Header.Get("HEADER-KEY"))
}
func TestWebSocketRequestWithEncodedChar(t *testing.T) {
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
assert.Equal(t, "/%3A%2F%2F", r.URL.EscapedPath())
for {
mt, message, err := conn.ReadMessage()
if err != nil {
break
}
err = conn.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/%3A%2F%2F"),
withData("ok"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
func TestWebSocketUpgradeFailed(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/ws", func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(http.StatusBadRequest)
})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
defer srv.Close()
u := parseURI(t, srv.URL)
f, err := NewReverseProxy(u, nil, true, false, 0, newConnPool(1, 0, func() (net.Conn, error) {
return net.Dial("tcp", u.Host)
}))
require.NoError(t, err)
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
path := req.URL.Path // keep the original path
if path != "/ws" {
w.WriteHeader(http.StatusOK)
return
}
// Set new backend URL
req.URL = parseURI(t, srv.URL)
req.URL.Path = path
f.ServeHTTP(w, req)
}))
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
conn, err := net.DialTimeout("tcp", proxyAddr, dialTimeout)
require.NoError(t, err)
defer conn.Close()
req, err := http.NewRequest(http.MethodGet, "ws://127.0.0.1/ws", nil)
require.NoError(t, err)
req.Header.Add("upgrade", "websocket")
req.Header.Add("Connection", "upgrade")
err = req.Write(conn)
require.NoError(t, err)
// First request works with 400
br := bufio.NewReader(conn)
resp, err := http.ReadResponse(br, req)
require.NoError(t, err)
assert.Equal(t, 400, resp.StatusCode)
}
func TestForwardsWebsocketTraffic(t *testing.T) {
mux := http.NewServeMux()
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
_, err := conn.Write([]byte("ok"))
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
}))
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("echo"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
func createTLSWebsocketServer() *httptest.Server {
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
for {
mt, message, err := conn.ReadMessage()
if err != nil {
break
}
err = conn.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
return srv
}
func TestWebSocketTransferTLSConfig(t *testing.T) {
srv := createTLSWebsocketServer()
defer srv.Close()
proxyWithoutTLSConfig := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
defer proxyWithoutTLSConfig.Close()
proxyAddr := proxyWithoutTLSConfig.Listener.Addr().String()
_, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("ok"),
).send()
require.EqualError(t, err, "bad status")
pool := createConnectionPool(srv.URL, &tls.Config{InsecureSkipVerify: true})
proxyWithTLSConfig := createProxyWithForwarder(t, srv.URL, pool)
defer proxyWithTLSConfig.Close()
proxyAddr = proxyWithTLSConfig.Listener.Addr().String()
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("ok"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
const dialTimeout = time.Second
type websocketRequestOpt func(w *websocketRequest)
func withServer(server string) websocketRequestOpt {
return func(w *websocketRequest) {
w.ServerAddr = server
}
}
func withPath(path string) websocketRequestOpt {
return func(w *websocketRequest) {
w.Path = path
}
}
func withData(data string) websocketRequestOpt {
return func(w *websocketRequest) {
w.Data = data
}
}
func withOrigin(origin string) websocketRequestOpt {
return func(w *websocketRequest) {
w.Origin = origin
}
}
func newWebsocketRequest(opts ...websocketRequestOpt) *websocketRequest {
wsrequest := &websocketRequest{}
for _, opt := range opts {
opt(wsrequest)
}
if wsrequest.Origin == "" {
wsrequest.Origin = "http://" + wsrequest.ServerAddr
}
if wsrequest.Config == nil {
wsrequest.Config, _ = websocket.NewConfig(fmt.Sprintf("ws://%s%s", wsrequest.ServerAddr, wsrequest.Path), wsrequest.Origin)
}
return wsrequest
}
type websocketRequest struct {
ServerAddr string
Path string
Data string
Origin string
Config *websocket.Config
}
func (w *websocketRequest) send() (string, error) {
conn, _, err := w.open()
if err != nil {
return "", err
}
defer conn.Close()
if _, err := conn.Write([]byte(w.Data)); err != nil {
return "", err
}
msg := make([]byte, 512)
var n int
n, err = conn.Read(msg)
if err != nil {
return "", err
}
received := string(msg[:n])
return received, nil
}
func (w *websocketRequest) open() (*websocket.Conn, net.Conn, error) {
client, err := net.DialTimeout("tcp", w.ServerAddr, dialTimeout)
if err != nil {
return nil, nil, err
}
conn, err := websocket.NewClient(w.Config, client)
if err != nil {
return nil, nil, err
}
return conn, client, err
}
func parseURI(t *testing.T, uri string) *url.URL {
t.Helper()
out, err := url.ParseRequestURI(uri)
require.NoError(t, err)
return out
}
func createConnectionPool(target string, tlsConfig *tls.Config) *connPool {
u := testhelpers.MustParseURL(target)
return newConnPool(200, 0, func() (net.Conn, error) {
if tlsConfig != nil {
return tls.Dial("tcp", u.Host, tlsConfig)
}
return net.Dial("tcp", u.Host)
})
}
func createProxyWithForwarder(t *testing.T, uri string, pool *connPool) *httptest.Server {
t.Helper()
u := parseURI(t, uri)
proxy, err := NewReverseProxy(u, nil, false, true, 0, pool)
require.NoError(t, err)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
path := req.URL.Path // keep the original path
// Set new backend URL
req.URL = u
req.URL.Path = path
proxy.ServeHTTP(w, req)
}))
t.Cleanup(srv.Close)
return srv
}

104
pkg/proxy/fast/upgrade.go Normal file
View file

@ -0,0 +1,104 @@
package fast
import (
"bytes"
"fmt"
"io"
"net"
"net/http"
"strings"
"github.com/traefik/traefik/v3/pkg/proxy/httputil"
"github.com/valyala/fasthttp"
"golang.org/x/net/http/httpguts"
)
// switchProtocolCopier exists so goroutines proxying data back and
// forth have nice names in stacks.
type switchProtocolCopier struct {
user, backend io.ReadWriter
}
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
_, err := io.Copy(c.user, c.backend)
errc <- err
}
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
_, err := io.Copy(c.backend, c.user)
errc <- err
}
func handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, reqUpType string, res *fasthttp.Response, backConn net.Conn) {
defer backConn.Close()
resUpType := upgradeTypeFastHTTP(&res.Header)
if !strings.EqualFold(reqUpType, resUpType) {
httputil.ErrorHandler(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
return
}
hj, ok := rw.(http.Hijacker)
if !ok {
httputil.ErrorHandler(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
return
}
backConnCloseCh := make(chan bool)
go func() {
// Ensure that the cancellation of a request closes the backend.
// See issue https://golang.org/issue/35559.
select {
case <-req.Context().Done():
case <-backConnCloseCh:
}
_ = backConn.Close()
}()
defer close(backConnCloseCh)
conn, brw, err := hj.Hijack()
if err != nil {
httputil.ErrorHandler(rw, req, fmt.Errorf("hijack failed on protocol switch: %w", err))
return
}
defer conn.Close()
for k, values := range rw.Header() {
for _, v := range values {
res.Header.Add(k, v)
}
}
if err := res.Header.Write(brw.Writer); err != nil {
httputil.ErrorHandler(rw, req, fmt.Errorf("response write: %w", err))
return
}
if err := brw.Flush(); err != nil {
httputil.ErrorHandler(rw, req, fmt.Errorf("response flush: %w", err))
return
}
errc := make(chan error, 1)
spc := switchProtocolCopier{user: conn, backend: backConn}
go spc.copyToBackend(errc)
go spc.copyFromBackend(errc)
<-errc
}
func upgradeType(h http.Header) string {
if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
return ""
}
return h.Get("Upgrade")
}
func upgradeTypeFastHTTP(h fasthttpHeader) string {
if !bytes.Contains(h.Peek("Connection"), []byte("Upgrade")) {
return ""
}
return string(h.Peek("Upgrade"))
}

View file

@ -1,23 +1,25 @@
package service package httputil
import "sync" import "sync"
const bufferPoolSize = 32 * 1024 const bufferSize = 32 * 1024
func newBufferPool() *bufferPool {
return &bufferPool{
pool: sync.Pool{
New: func() interface{} {
return make([]byte, bufferPoolSize)
},
},
}
}
type bufferPool struct { type bufferPool struct {
pool sync.Pool pool sync.Pool
} }
func newBufferPool() *bufferPool {
b := &bufferPool{
pool: sync.Pool{},
}
b.pool.New = func() interface{} {
return make([]byte, bufferSize)
}
return b
}
func (b *bufferPool) Get() []byte { func (b *bufferPool) Get() []byte {
return b.pool.Get().([]byte) return b.pool.Get().([]byte)
} }

View file

@ -0,0 +1,54 @@
package httputil
import (
"crypto/tls"
"fmt"
"net/http"
"net/url"
"time"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
"github.com/traefik/traefik/v3/pkg/metrics"
)
// TransportManager manages transport used for backend communications.
type TransportManager interface {
Get(name string) (*dynamic.ServersTransport, error)
GetRoundTripper(name string) (http.RoundTripper, error)
GetTLSConfig(name string) (*tls.Config, error)
}
// ProxyBuilder handles the http.RoundTripper for httputil reverse proxies.
type ProxyBuilder struct {
bufferPool *bufferPool
transportManager TransportManager
semConvMetricsRegistry *metrics.SemConvMetricsRegistry
}
// NewProxyBuilder creates a new ProxyBuilder.
func NewProxyBuilder(transportManager TransportManager, semConvMetricsRegistry *metrics.SemConvMetricsRegistry) *ProxyBuilder {
return &ProxyBuilder{
bufferPool: newBufferPool(),
transportManager: transportManager,
semConvMetricsRegistry: semConvMetricsRegistry,
}
}
// Update does nothing.
func (r *ProxyBuilder) Update(_ map[string]*dynamic.ServersTransport) {}
// Build builds a new httputil.ReverseProxy with the given configuration.
func (r *ProxyBuilder) Build(cfgName string, targetURL *url.URL, shouldObserve, passHostHeader bool, flushInterval time.Duration) (http.Handler, error) {
roundTripper, err := r.transportManager.GetRoundTripper(cfgName)
if err != nil {
return nil, fmt.Errorf("getting RoundTripper: %w", err)
}
if shouldObserve {
// Wrapping the roundTripper with the Tracing roundTripper,
// to handle the reverseProxy client span creation.
roundTripper = newObservabilityRoundTripper(r.semConvMetricsRegistry, roundTripper)
}
return buildSingleHostProxy(targetURL, passHostHeader, flushInterval, roundTripper, r.bufferPool), nil
}

View file

@ -0,0 +1,56 @@
package httputil
import (
"crypto/tls"
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
"github.com/traefik/traefik/v3/pkg/testhelpers"
)
func TestEscapedPath(t *testing.T) {
var gotEscapedPath string
srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
gotEscapedPath = req.URL.EscapedPath()
}))
transportManager := &transportManagerMock{
roundTrippers: map[string]http.RoundTripper{"default": &http.Transport{}},
}
p, err := NewProxyBuilder(transportManager, nil).Build("default", testhelpers.MustParseURL(srv.URL), false, true, 0)
require.NoError(t, err)
proxy := httptest.NewServer(http.HandlerFunc(p.ServeHTTP))
_, err = http.Get(proxy.URL + "/%3A%2F%2F")
require.NoError(t, err)
assert.Equal(t, "/%3A%2F%2F", gotEscapedPath)
}
type transportManagerMock struct {
roundTrippers map[string]http.RoundTripper
}
func (t *transportManagerMock) GetRoundTripper(name string) (http.RoundTripper, error) {
roundTripper, ok := t.roundTrippers[name]
if !ok {
return nil, errors.New("no transport for " + name)
}
return roundTripper, nil
}
func (t *transportManagerMock) GetTLSConfig(_ string) (*tls.Config, error) {
panic("implement me")
}
func (t *transportManagerMock) Get(_ string) (*dynamic.ServersTransport, error) {
panic("implement me")
}

View file

@ -1,4 +1,4 @@
package service package httputil
import ( import (
"context" "context"
@ -23,6 +23,13 @@ type wrapper struct {
rt http.RoundTripper rt http.RoundTripper
} }
func newObservabilityRoundTripper(semConvMetricRegistry *metrics.SemConvMetricsRegistry, rt http.RoundTripper) http.RoundTripper {
return &wrapper{
semConvMetricRegistry: semConvMetricRegistry,
rt: rt,
}
}
func (t *wrapper) RoundTrip(req *http.Request) (*http.Response, error) { func (t *wrapper) RoundTrip(req *http.Request) (*http.Response, error) {
start := time.Now() start := time.Now()
var span trace.Span var span trace.Span
@ -42,7 +49,7 @@ func (t *wrapper) RoundTrip(req *http.Request) (*http.Response, error) {
var headers http.Header var headers http.Header
response, err := t.rt.RoundTrip(req) response, err := t.rt.RoundTrip(req)
if err != nil { if err != nil {
statusCode = computeStatusCode(err) statusCode = ComputeStatusCode(err)
} }
if response != nil { if response != nil {
statusCode = response.StatusCode statusCode = response.StatusCode
@ -96,10 +103,3 @@ func (t *wrapper) RoundTrip(req *http.Request) (*http.Response, error) {
return response, err return response, err
} }
func newObservabilityRoundTripper(semConvMetricRegistry *metrics.SemConvMetricsRegistry, rt http.RoundTripper) http.RoundTripper {
return &wrapper{
semConvMetricRegistry: semConvMetricRegistry,
rt: rt,
}
}

View file

@ -1,4 +1,4 @@
package service package httputil
import ( import (
"context" "context"

View file

@ -1,4 +1,4 @@
package service package httputil
import ( import (
"context" "context"
@ -27,7 +27,7 @@ func buildSingleHostProxy(target *url.URL, passHostHeader bool, flushInterval ti
Transport: roundTripper, Transport: roundTripper,
FlushInterval: flushInterval, FlushInterval: flushInterval,
BufferPool: bufferPool, BufferPool: bufferPool,
ErrorHandler: errorHandler, ErrorHandler: ErrorHandler,
} }
} }
@ -93,8 +93,9 @@ func isWebSocketUpgrade(req *http.Request) bool {
strings.EqualFold(req.Header.Get("Upgrade"), "websocket") strings.EqualFold(req.Header.Get("Upgrade"), "websocket")
} }
func errorHandler(w http.ResponseWriter, req *http.Request, err error) { // ErrorHandler is the http.Handler called when something goes wrong when forwarding the request.
statusCode := computeStatusCode(err) func ErrorHandler(w http.ResponseWriter, req *http.Request, err error) {
statusCode := ComputeStatusCode(err)
logger := log.Ctx(req.Context()) logger := log.Ctx(req.Context())
logger.Debug().Err(err).Msgf("%d %s", statusCode, statusText(statusCode)) logger.Debug().Err(err).Msgf("%d %s", statusCode, statusText(statusCode))
@ -105,7 +106,8 @@ func errorHandler(w http.ResponseWriter, req *http.Request, err error) {
} }
} }
func computeStatusCode(err error) int { // ComputeStatusCode computes the HTTP status code according to the given error.
func ComputeStatusCode(err error) int {
switch { switch {
case errors.Is(err, io.EOF): case errors.Is(err, io.EOF):
return http.StatusBadGateway return http.StatusBadGateway

View file

@ -1,4 +1,4 @@
package service package httputil
import ( import (
"bufio" "bufio"
@ -8,13 +8,13 @@ import (
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"testing" "testing"
"time" "time"
gorillawebsocket "github.com/gorilla/websocket" gorillawebsocket "github.com/gorilla/websocket"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/traefik/traefik/v3/pkg/testhelpers"
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
) )
@ -27,6 +27,7 @@ func TestWebSocketTCPClose(t *testing.T) {
return return
} }
defer c.Close() defer c.Close()
for { for {
_, _, err := c.ReadMessage() _, _, err := c.ReadMessage()
if err != nil { if err != nil {
@ -71,6 +72,7 @@ func TestWebSocketPingPong(t *testing.T) {
ws.SetPingHandler(func(appData string) error { ws.SetPingHandler(func(appData string) error {
err = ws.WriteMessage(gorillawebsocket.PongMessage, []byte(appData+"Pong")) err = ws.WriteMessage(gorillawebsocket.PongMessage, []byte(appData+"Pong"))
require.NoError(t, err) require.NoError(t, err)
return nil return nil
}) })
@ -97,6 +99,7 @@ func TestWebSocketPingPong(t *testing.T) {
if data == "PingPong" { if data == "PingPong" {
return goodErr return goodErr
} }
return badErr return badErr
}) })
@ -104,7 +107,6 @@ func TestWebSocketPingPong(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
_, _, err = conn.ReadMessage() _, _, err = conn.ReadMessage()
if !errors.Is(err, goodErr) { if !errors.Is(err, goodErr) {
require.NoError(t, err) require.NoError(t, err)
} }
@ -114,12 +116,10 @@ func TestWebSocketEcho(t *testing.T) {
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
msg := make([]byte, 4) msg := make([]byte, 4)
_, err := conn.Read(msg) n, err := conn.Read(msg)
require.NoError(t, err) require.NoError(t, err)
fmt.Println(string(msg)) _, err = conn.Write(msg[:n])
_, err = conn.Write(msg)
require.NoError(t, err) require.NoError(t, err)
err = conn.Close() err = conn.Close()
@ -142,7 +142,10 @@ func TestWebSocketEcho(t *testing.T) {
err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK")) err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
require.NoError(t, err) require.NoError(t, err)
fmt.Println(conn.ReadMessage()) _, msg, err := conn.ReadMessage()
require.NoError(t, err)
assert.Equal(t, "OK", string(msg))
err = conn.Close() err = conn.Close()
require.NoError(t, err) require.NoError(t, err)
@ -178,11 +181,10 @@ func TestWebSocketPassHost(t *testing.T) {
} }
msg := make([]byte, 4) msg := make([]byte, 4)
_, err := conn.Read(msg) n, err := conn.Read(msg)
require.NoError(t, err) require.NoError(t, err)
fmt.Println(string(msg)) _, err = conn.Write(msg[:n])
_, err = conn.Write(msg)
require.NoError(t, err) require.NoError(t, err)
err = conn.Close() err = conn.Close()
@ -207,7 +209,10 @@ func TestWebSocketPassHost(t *testing.T) {
err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK")) err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
require.NoError(t, err) require.NoError(t, err)
fmt.Println(conn.ReadMessage()) _, msg, err := conn.ReadMessage()
require.NoError(t, err)
assert.Equal(t, "OK", string(msg))
err = conn.Close() err = conn.Close()
require.NoError(t, err) require.NoError(t, err)
@ -216,27 +221,8 @@ func TestWebSocketPassHost(t *testing.T) {
} }
func TestWebSocketServerWithoutCheckOrigin(t *testing.T) { func TestWebSocketServerWithoutCheckOrigin(t *testing.T) {
upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(r *http.Request) bool { upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
return true srv := createServer(t, upgrader, func(*http.Request) {})
}}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer c.Close()
for {
mt, message, err := c.ReadMessage()
if err != nil {
break
}
err = c.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport) proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
defer proxy.Close() defer proxy.Close()
@ -254,25 +240,7 @@ func TestWebSocketServerWithoutCheckOrigin(t *testing.T) {
} }
func TestWebSocketRequestWithOrigin(t *testing.T) { func TestWebSocketRequestWithOrigin(t *testing.T) {
upgrader := gorillawebsocket.Upgrader{} srv := createServer(t, gorillawebsocket.Upgrader{}, func(*http.Request) {})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer c.Close()
for {
mt, message, err := c.ReadMessage()
if err != nil {
break
}
err = c.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport) proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
defer proxy.Close() defer proxy.Close()
@ -297,26 +265,9 @@ func TestWebSocketRequestWithOrigin(t *testing.T) {
} }
func TestWebSocketRequestWithQueryParams(t *testing.T) { func TestWebSocketRequestWithQueryParams(t *testing.T) {
upgrader := gorillawebsocket.Upgrader{} srv := createServer(t, gorillawebsocket.Upgrader{}, func(r *http.Request) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
assert.Equal(t, "test", r.URL.Query().Get("query")) assert.Equal(t, "test", r.URL.Query().Get("query"))
for { })
mt, message, err := conn.ReadMessage()
if err != nil {
break
}
err = conn.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport) proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
defer proxy.Close() defer proxy.Close()
@ -341,11 +292,19 @@ func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) {
srv := httptest.NewServer(mux) srv := httptest.NewServer(mux)
defer srv.Close() defer srv.Close()
f := buildSingleHostProxy(parseURI(t, srv.URL), true, 0, http.DefaultTransport, nil) transportManager := &transportManagerMock{
roundTrippers: map[string]http.RoundTripper{
"default@internal": &http.Transport{},
},
}
p, err := NewProxyBuilder(transportManager, nil).Build("default@internal", testhelpers.MustParseURL(srv.URL), false, true, 0)
require.NoError(t, err)
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req.URL = parseURI(t, srv.URL) req.URL = testhelpers.MustParseURL(srv.URL)
w.Header().Set("HEADER-KEY", "HEADER-VALUE") w.Header().Set("HEADER-KEY", "HEADER-VALUE")
f.ServeHTTP(w, req) p.ServeHTTP(w, req)
})) }))
defer proxy.Close() defer proxy.Close()
@ -363,26 +322,9 @@ func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) {
} }
func TestWebSocketRequestWithEncodedChar(t *testing.T) { func TestWebSocketRequestWithEncodedChar(t *testing.T) {
upgrader := gorillawebsocket.Upgrader{} srv := createServer(t, gorillawebsocket.Upgrader{}, func(r *http.Request) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
assert.Equal(t, "/%3A%2F%2F", r.URL.EscapedPath()) assert.Equal(t, "/%3A%2F%2F", r.URL.EscapedPath())
for { })
mt, message, err := conn.ReadMessage()
if err != nil {
break
}
err = conn.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport) proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
defer proxy.Close() defer proxy.Close()
@ -407,15 +349,23 @@ func TestWebSocketUpgradeFailed(t *testing.T) {
srv := httptest.NewServer(mux) srv := httptest.NewServer(mux)
defer srv.Close() defer srv.Close()
f := buildSingleHostProxy(parseURI(t, srv.URL), true, 0, http.DefaultTransport, nil) transportManager := &transportManagerMock{
roundTrippers: map[string]http.RoundTripper{
"default@internal": &http.Transport{},
},
}
p, err := NewProxyBuilder(transportManager, nil).Build("default@internal", testhelpers.MustParseURL(srv.URL), false, true, 0)
require.NoError(t, err)
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
path := req.URL.Path // keep the original path path := req.URL.Path // keep the original path
if path == "/ws" { if path == "/ws" {
// Set new backend URL // Set new backend URL
req.URL = parseURI(t, srv.URL) req.URL = testhelpers.MustParseURL(srv.URL)
req.URL.Path = path req.URL.Path = path
f.ServeHTTP(w, req) p.ServeHTTP(w, req)
} else { } else {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
} }
@ -629,27 +579,60 @@ func (w *websocketRequest) open() (*websocket.Conn, net.Conn, error) {
return conn, client, err return conn, client, err
} }
func parseURI(t *testing.T, uri string) *url.URL {
t.Helper()
out, err := url.ParseRequestURI(uri)
require.NoError(t, err)
return out
}
func createProxyWithForwarder(t *testing.T, uri string, transport http.RoundTripper) *httptest.Server { func createProxyWithForwarder(t *testing.T, uri string, transport http.RoundTripper) *httptest.Server {
t.Helper() t.Helper()
u := parseURI(t, uri) u := testhelpers.MustParseURL(uri)
proxy := buildSingleHostProxy(u, true, 0, transport, nil)
transportManager := &transportManagerMock{
roundTrippers: map[string]http.RoundTripper{"fwd": transport},
}
p, err := NewProxyBuilder(transportManager, nil).Build("fwd", u, false, true, 0)
require.NoError(t, err)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
path := req.URL.Path // keep the original path // keep the original path
path := req.URL.Path
// Set new backend URL // Set new backend URL
req.URL = u req.URL = u
req.URL.Path = path req.URL.Path = path
proxy.ServeHTTP(w, req) p.ServeHTTP(w, req)
})) }))
t.Cleanup(srv.Close) t.Cleanup(srv.Close)
return srv
}
func createServer(t *testing.T, upgrader gorillawebsocket.Upgrader, check func(*http.Request)) *httptest.Server {
t.Helper()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Logf("Error during upgrade: %v", err)
return
}
defer conn.Close()
check(r)
for {
mt, message, err := conn.ReadMessage()
if err != nil {
t.Logf("Error during read: %v", err)
break
}
err = conn.WriteMessage(mt, message)
if err != nil {
t.Logf("Error during write: %v", err)
break
}
}
}))
t.Cleanup(srv.Close)
return srv return srv
} }

View file

@ -0,0 +1,61 @@
package proxy
import (
"crypto/tls"
"fmt"
"net/http"
"net/url"
"time"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
"github.com/traefik/traefik/v3/pkg/config/static"
"github.com/traefik/traefik/v3/pkg/proxy/fast"
"github.com/traefik/traefik/v3/pkg/proxy/httputil"
"github.com/traefik/traefik/v3/pkg/server/service"
)
// TransportManager manages transport used for backend communications.
type TransportManager interface {
Get(name string) (*dynamic.ServersTransport, error)
GetRoundTripper(name string) (http.RoundTripper, error)
GetTLSConfig(name string) (*tls.Config, error)
}
// SmartBuilder is a proxy builder which returns a fast proxy or httputil proxy corresponding
// to the ServersTransport configuration.
type SmartBuilder struct {
fastProxyBuilder *fast.ProxyBuilder
proxyBuilder service.ProxyBuilder
transportManager httputil.TransportManager
}
// NewSmartBuilder creates and returns a new SmartBuilder instance.
func NewSmartBuilder(transportManager TransportManager, proxyBuilder service.ProxyBuilder, fastProxyConfig static.FastProxyConfig) *SmartBuilder {
return &SmartBuilder{
fastProxyBuilder: fast.NewProxyBuilder(transportManager, fastProxyConfig),
proxyBuilder: proxyBuilder,
transportManager: transportManager,
}
}
// Update is the handler called when the dynamic configuration is updated.
func (b *SmartBuilder) Update(newConfigs map[string]*dynamic.ServersTransport) {
b.fastProxyBuilder.Update(newConfigs)
}
// Build builds an HTTP proxy for the given URL using the ServersTransport with the given name.
func (b *SmartBuilder) Build(configName string, targetURL *url.URL, shouldObserve, passHostHeader bool, flushInterval time.Duration) (http.Handler, error) {
serversTransport, err := b.transportManager.Get(configName)
if err != nil {
return nil, fmt.Errorf("getting ServersTransport: %w", err)
}
// The fast proxy implementation cannot handle HTTP/2 requests for now.
// For the https scheme we cannot guess if the backend communication will use HTTP2,
// thus we check if HTTP/2 is disabled to use the fast proxy implementation when this is possible.
if targetURL.Scheme == "h2c" || (targetURL.Scheme == "https" && !serversTransport.DisableHTTP2) {
return b.proxyBuilder.Build(configName, targetURL, shouldObserve, passHostHeader, flushInterval)
}
return b.fastProxyBuilder.Build(configName, targetURL, passHostHeader)
}

View file

@ -0,0 +1,113 @@
package proxy
import (
"encoding/pem"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
"github.com/traefik/traefik/v3/pkg/config/static"
"github.com/traefik/traefik/v3/pkg/proxy/httputil"
"github.com/traefik/traefik/v3/pkg/server/service"
"github.com/traefik/traefik/v3/pkg/testhelpers"
"github.com/traefik/traefik/v3/pkg/types"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)
func TestSmartBuilder_Build(t *testing.T) {
tests := []struct {
desc string
serversTransport dynamic.ServersTransport
fastProxyConfig static.FastProxyConfig
https bool
h2c bool
wantFastProxy bool
}{
{
desc: "fastproxy",
fastProxyConfig: static.FastProxyConfig{Debug: true},
wantFastProxy: true,
},
{
desc: "fastproxy with https and without DisableHTTP2",
https: true,
fastProxyConfig: static.FastProxyConfig{Debug: true},
wantFastProxy: false,
},
{
desc: "fastproxy with https and DisableHTTP2",
https: true,
serversTransport: dynamic.ServersTransport{DisableHTTP2: true},
fastProxyConfig: static.FastProxyConfig{Debug: true},
wantFastProxy: true,
},
{
desc: "fastproxy with h2c",
h2c: true,
fastProxyConfig: static.FastProxyConfig{Debug: true},
wantFastProxy: false,
},
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
var callCount int
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
if test.wantFastProxy {
assert.Contains(t, r.Header, "X-Traefik-Fast-Proxy")
} else {
assert.NotContains(t, r.Header, "X-Traefik-Fast-Proxy")
}
})
var server *httptest.Server
if test.https {
server = httptest.NewUnstartedServer(handler)
server.EnableHTTP2 = false
server.StartTLS()
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: server.TLS.Certificates[0].Certificate[0]})
test.serversTransport.RootCAs = []types.FileOrContent{
types.FileOrContent(certPEM),
}
} else {
server = httptest.NewServer(h2c.NewHandler(handler, &http2.Server{}))
}
t.Cleanup(func() {
server.Close()
})
targetURL := testhelpers.MustParseURL(server.URL)
if test.h2c {
targetURL.Scheme = "h2c"
}
serversTransports := map[string]*dynamic.ServersTransport{
"test": &test.serversTransport,
}
transportManager := service.NewTransportManager(nil)
transportManager.Update(serversTransports)
httpProxyBuilder := httputil.NewProxyBuilder(transportManager, nil)
proxyBuilder := NewSmartBuilder(transportManager, httpProxyBuilder, test.fastProxyConfig)
proxyHandler, err := proxyBuilder.Build("test", targetURL, false, false, time.Second)
require.NoError(t, err)
rw := httptest.NewRecorder()
proxyHandler.ServeHTTP(rw, httptest.NewRequest(http.MethodGet, "/", http.NoBody))
assert.Equal(t, 1, callCount)
})
}
}

View file

@ -2,10 +2,12 @@ package router
import ( import (
"context" "context"
"crypto/tls"
"io" "io"
"math" "math"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -18,7 +20,7 @@ import (
"github.com/traefik/traefik/v3/pkg/server/middleware" "github.com/traefik/traefik/v3/pkg/server/middleware"
"github.com/traefik/traefik/v3/pkg/server/service" "github.com/traefik/traefik/v3/pkg/server/service"
"github.com/traefik/traefik/v3/pkg/testhelpers" "github.com/traefik/traefik/v3/pkg/testhelpers"
"github.com/traefik/traefik/v3/pkg/tls" traefiktls "github.com/traefik/traefik/v3/pkg/tls"
) )
func TestRouterManager_Get(t *testing.T) { func TestRouterManager_Get(t *testing.T) {
@ -309,11 +311,12 @@ func TestRouterManager_Get(t *testing.T) {
}, },
}) })
roundTripperManager := service.NewRoundTripperManager(nil) transportManager := service.NewTransportManager(nil)
roundTripperManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}}) transportManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}})
serviceManager := service.NewManager(rtConf.Services, nil, nil, roundTripperManager)
serviceManager := service.NewManager(rtConf.Services, nil, nil, transportManager, proxyBuilderMock{})
middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager, nil) middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager, nil)
tlsManager := tls.NewManager() tlsManager := traefiktls.NewManager()
routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, nil, tlsManager) routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, nil, tlsManager)
@ -340,7 +343,7 @@ func TestRuntimeConfiguration(t *testing.T) {
serviceConfig map[string]*dynamic.Service serviceConfig map[string]*dynamic.Service
routerConfig map[string]*dynamic.Router routerConfig map[string]*dynamic.Router
middlewareConfig map[string]*dynamic.Middleware middlewareConfig map[string]*dynamic.Middleware
tlsOptions map[string]tls.Options tlsOptions map[string]traefiktls.Options
expectedError int expectedError int
}{ }{
{ {
@ -597,7 +600,7 @@ func TestRuntimeConfiguration(t *testing.T) {
TLS: &dynamic.RouterTLSConfig{}, TLS: &dynamic.RouterTLSConfig{},
}, },
}, },
tlsOptions: map[string]tls.Options{}, tlsOptions: map[string]traefiktls.Options{},
expectedError: 1, expectedError: 1,
}, },
{ {
@ -624,9 +627,9 @@ func TestRuntimeConfiguration(t *testing.T) {
}, },
}, },
}, },
tlsOptions: map[string]tls.Options{ tlsOptions: map[string]traefiktls.Options{
"broken-tlsOption": { "broken-tlsOption": {
ClientAuth: tls.ClientAuth{ ClientAuth: traefiktls.ClientAuth{
ClientAuthType: "foobar", ClientAuthType: "foobar",
}, },
}, },
@ -655,9 +658,9 @@ func TestRuntimeConfiguration(t *testing.T) {
TLS: &dynamic.RouterTLSConfig{}, TLS: &dynamic.RouterTLSConfig{},
}, },
}, },
tlsOptions: map[string]tls.Options{ tlsOptions: map[string]traefiktls.Options{
"default": { "default": {
ClientAuth: tls.ClientAuth{ ClientAuth: traefiktls.ClientAuth{
ClientAuthType: "foobar", ClientAuthType: "foobar",
}, },
}, },
@ -682,11 +685,12 @@ func TestRuntimeConfiguration(t *testing.T) {
}, },
}) })
roundTripperManager := service.NewRoundTripperManager(nil) transportManager := service.NewTransportManager(nil)
roundTripperManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}}) transportManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}})
serviceManager := service.NewManager(rtConf.Services, nil, nil, roundTripperManager)
serviceManager := service.NewManager(rtConf.Services, nil, nil, transportManager, proxyBuilderMock{})
middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager, nil) middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager, nil)
tlsManager := tls.NewManager() tlsManager := traefiktls.NewManager()
tlsManager.UpdateConfigs(context.Background(), nil, test.tlsOptions, nil) tlsManager.UpdateConfigs(context.Background(), nil, test.tlsOptions, nil)
routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, nil, tlsManager) routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, nil, tlsManager)
@ -759,11 +763,12 @@ func TestProviderOnMiddlewares(t *testing.T) {
}, },
}) })
roundTripperManager := service.NewRoundTripperManager(nil) transportManager := service.NewTransportManager(nil)
roundTripperManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}}) transportManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}})
serviceManager := service.NewManager(rtConf.Services, nil, nil, roundTripperManager)
serviceManager := service.NewManager(rtConf.Services, nil, nil, transportManager, nil)
middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager, nil) middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager, nil)
tlsManager := tls.NewManager() tlsManager := traefiktls.NewManager()
routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, nil, tlsManager) routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, nil, tlsManager)
@ -775,14 +780,22 @@ func TestProviderOnMiddlewares(t *testing.T) {
assert.Equal(t, []string{"m1@docker", "m2@docker", "m1@file"}, rtConf.Middlewares["chain@docker"].Chain.Middlewares) assert.Equal(t, []string{"m1@docker", "m2@docker", "m1@file"}, rtConf.Middlewares["chain@docker"].Chain.Middlewares)
} }
type staticRoundTripperGetter struct { type staticTransportManager struct {
res *http.Response res *http.Response
} }
func (s staticRoundTripperGetter) Get(name string) (http.RoundTripper, error) { func (s staticTransportManager) GetRoundTripper(_ string) (http.RoundTripper, error) {
return &staticTransport{res: s.res}, nil return &staticTransport{res: s.res}, nil
} }
func (s staticTransportManager) GetTLSConfig(_ string) (*tls.Config, error) {
panic("implement me")
}
func (s staticTransportManager) Get(_ string) (*dynamic.ServersTransport, error) {
panic("implement me")
}
type staticTransport struct { type staticTransport struct {
res *http.Response res *http.Response
} }
@ -829,9 +842,9 @@ func BenchmarkRouterServe(b *testing.B) {
}, },
}) })
serviceManager := service.NewManager(rtConf.Services, nil, nil, staticRoundTripperGetter{res}) serviceManager := service.NewManager(rtConf.Services, nil, nil, staticTransportManager{res}, nil)
middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager, nil) middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager, nil)
tlsManager := tls.NewManager() tlsManager := traefiktls.NewManager()
routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, nil, tlsManager) routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, nil, tlsManager)
@ -871,7 +884,7 @@ func BenchmarkService(b *testing.B) {
}, },
}) })
serviceManager := service.NewManager(rtConf.Services, nil, nil, staticRoundTripperGetter{res}) serviceManager := service.NewManager(rtConf.Services, nil, nil, staticTransportManager{res}, nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil) req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
@ -881,3 +894,13 @@ func BenchmarkService(b *testing.B) {
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
} }
} }
type proxyBuilderMock struct{}
func (p proxyBuilderMock) Build(_ string, _ *url.URL, _, _ bool, _ time.Duration) (http.Handler, error) {
return http.HandlerFunc(func(responseWriter http.ResponseWriter, req *http.Request) {}), nil
}
func (p proxyBuilderMock) Update(_ map[string]*dynamic.ServersTransport) {
panic("implement me")
}

View file

@ -3,7 +3,9 @@ package server
import ( import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/traefik/traefik/v3/pkg/config/dynamic" "github.com/traefik/traefik/v3/pkg/config/dynamic"
@ -48,9 +50,10 @@ func TestReuseService(t *testing.T) {
), ),
) )
roundTripperManager := service.NewRoundTripperManager(nil) transportManager := service.NewTransportManager(nil)
roundTripperManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}}) transportManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}})
managerFactory := service.NewManagerFactory(staticConfig, nil, nil, roundTripperManager, nil)
managerFactory := service.NewManagerFactory(staticConfig, nil, nil, transportManager, proxyBuilderMock{}, nil)
tlsManager := tls.NewManager() tlsManager := tls.NewManager()
dialerManager := tcp.NewDialerManager(nil) dialerManager := tcp.NewDialerManager(nil)
@ -184,9 +187,10 @@ func TestServerResponseEmptyBackend(t *testing.T) {
}, },
} }
roundTripperManager := service.NewRoundTripperManager(nil) transportManager := service.NewTransportManager(nil)
roundTripperManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}}) transportManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}})
managerFactory := service.NewManagerFactory(staticConfig, nil, nil, roundTripperManager, nil)
managerFactory := service.NewManagerFactory(staticConfig, nil, nil, transportManager, proxyBuilderMock{}, nil)
tlsManager := tls.NewManager() tlsManager := tls.NewManager()
dialerManager := tcp.NewDialerManager(nil) dialerManager := tcp.NewDialerManager(nil)
@ -228,9 +232,10 @@ func TestInternalServices(t *testing.T) {
), ),
) )
roundTripperManager := service.NewRoundTripperManager(nil) transportManager := service.NewTransportManager(nil)
roundTripperManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}}) transportManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}})
managerFactory := service.NewManagerFactory(staticConfig, nil, nil, roundTripperManager, nil)
managerFactory := service.NewManagerFactory(staticConfig, nil, nil, transportManager, nil, nil)
tlsManager := tls.NewManager() tlsManager := tls.NewManager()
dialerManager := tcp.NewDialerManager(nil) dialerManager := tcp.NewDialerManager(nil)
@ -246,3 +251,13 @@ func TestInternalServices(t *testing.T) {
assert.Equal(t, http.StatusOK, responseRecorderOk.Result().StatusCode, "status code") assert.Equal(t, http.StatusOK, responseRecorderOk.Result().StatusCode, "status code")
} }
type proxyBuilderMock struct{}
func (p proxyBuilderMock) Build(_ string, _ *url.URL, _, _ bool, _ time.Duration) (http.Handler, error) {
return http.HandlerFunc(func(responseWriter http.ResponseWriter, req *http.Request) {}), nil
}
func (p proxyBuilderMock) Update(_ map[string]*dynamic.ServersTransport) {
panic("implement me")
}

View file

@ -17,7 +17,8 @@ import (
type ManagerFactory struct { type ManagerFactory struct {
observabilityMgr *middleware.ObservabilityMgr observabilityMgr *middleware.ObservabilityMgr
roundTripperManager *RoundTripperManager transportManager *TransportManager
proxyBuilder ProxyBuilder
api func(configuration *runtime.Configuration) http.Handler api func(configuration *runtime.Configuration) http.Handler
restHandler http.Handler restHandler http.Handler
@ -30,11 +31,12 @@ type ManagerFactory struct {
} }
// NewManagerFactory creates a new ManagerFactory. // NewManagerFactory creates a new ManagerFactory.
func NewManagerFactory(staticConfiguration static.Configuration, routinesPool *safe.Pool, observabilityMgr *middleware.ObservabilityMgr, roundTripperManager *RoundTripperManager, acmeHTTPHandler http.Handler) *ManagerFactory { func NewManagerFactory(staticConfiguration static.Configuration, routinesPool *safe.Pool, observabilityMgr *middleware.ObservabilityMgr, transportManager *TransportManager, proxyBuilder ProxyBuilder, acmeHTTPHandler http.Handler) *ManagerFactory {
factory := &ManagerFactory{ factory := &ManagerFactory{
observabilityMgr: observabilityMgr, observabilityMgr: observabilityMgr,
routinesPool: routinesPool, routinesPool: routinesPool,
roundTripperManager: roundTripperManager, transportManager: transportManager,
proxyBuilder: proxyBuilder,
acmeHTTPHandler: acmeHTTPHandler, acmeHTTPHandler: acmeHTTPHandler,
} }
@ -73,7 +75,7 @@ func NewManagerFactory(staticConfiguration static.Configuration, routinesPool *s
// Build creates a service manager. // Build creates a service manager.
func (f *ManagerFactory) Build(configuration *runtime.Configuration) *InternalHandlers { func (f *ManagerFactory) Build(configuration *runtime.Configuration) *InternalHandlers {
svcManager := NewManager(configuration.Services, f.observabilityMgr, f.routinesPool, f.roundTripperManager) svcManager := NewManager(configuration.Services, f.observabilityMgr, f.routinesPool, f.transportManager, f.proxyBuilder)
var apiHandler http.Handler var apiHandler http.Handler
if f.api != nil { if f.api != nil {

View file

@ -1,37 +0,0 @@
package service
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/traefik/traefik/v3/pkg/testhelpers"
)
type staticTransport struct {
res *http.Response
}
func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
return t.res, nil
}
func BenchmarkProxy(b *testing.B) {
res := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("")),
}
w := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
pool := newBufferPool()
handler := buildSingleHostProxy(req.URL, false, 0, &staticTransport{res}, pool)
b.ReportAllocs()
for range b.N {
handler.ServeHTTP(w, req)
}
}

View file

@ -9,7 +9,6 @@ import (
"hash/fnv" "hash/fnv"
"math/rand" "math/rand"
"net/http" "net/http"
"net/http/httputil"
"net/url" "net/url"
"reflect" "reflect"
"strings" "strings"
@ -25,6 +24,7 @@ import (
"github.com/traefik/traefik/v3/pkg/middlewares/capture" "github.com/traefik/traefik/v3/pkg/middlewares/capture"
metricsMiddle "github.com/traefik/traefik/v3/pkg/middlewares/metrics" metricsMiddle "github.com/traefik/traefik/v3/pkg/middlewares/metrics"
"github.com/traefik/traefik/v3/pkg/middlewares/observability" "github.com/traefik/traefik/v3/pkg/middlewares/observability"
"github.com/traefik/traefik/v3/pkg/proxy/httputil"
"github.com/traefik/traefik/v3/pkg/safe" "github.com/traefik/traefik/v3/pkg/safe"
"github.com/traefik/traefik/v3/pkg/server/cookie" "github.com/traefik/traefik/v3/pkg/server/cookie"
"github.com/traefik/traefik/v3/pkg/server/middleware" "github.com/traefik/traefik/v3/pkg/server/middleware"
@ -40,17 +40,18 @@ const (
defaultMaxBodySize int64 = -1 defaultMaxBodySize int64 = -1
) )
// RoundTripperGetter is a roundtripper getter interface. // ProxyBuilder builds reverse proxy handlers.
type RoundTripperGetter interface { type ProxyBuilder interface {
Get(name string) (http.RoundTripper, error) Build(cfgName string, targetURL *url.URL, shouldObserve, passHostHeader bool, flushInterval time.Duration) (http.Handler, error)
Update(configs map[string]*dynamic.ServersTransport)
} }
// Manager The service manager. // Manager The service manager.
type Manager struct { type Manager struct {
routinePool *safe.Pool routinePool *safe.Pool
observabilityMgr *middleware.ObservabilityMgr observabilityMgr *middleware.ObservabilityMgr
bufferPool httputil.BufferPool transportManager httputil.TransportManager
roundTripperManager RoundTripperGetter proxyBuilder ProxyBuilder
services map[string]http.Handler services map[string]http.Handler
configs map[string]*runtime.ServiceInfo configs map[string]*runtime.ServiceInfo
@ -59,12 +60,12 @@ type Manager struct {
} }
// NewManager creates a new Manager. // NewManager creates a new Manager.
func NewManager(configs map[string]*runtime.ServiceInfo, observabilityMgr *middleware.ObservabilityMgr, routinePool *safe.Pool, roundTripperManager RoundTripperGetter) *Manager { func NewManager(configs map[string]*runtime.ServiceInfo, observabilityMgr *middleware.ObservabilityMgr, routinePool *safe.Pool, transportManager httputil.TransportManager, proxyBuilder ProxyBuilder) *Manager {
return &Manager{ return &Manager{
routinePool: routinePool, routinePool: routinePool,
observabilityMgr: observabilityMgr, observabilityMgr: observabilityMgr,
bufferPool: newBufferPool(), transportManager: transportManager,
roundTripperManager: roundTripperManager, proxyBuilder: proxyBuilder,
services: make(map[string]http.Handler), services: make(map[string]http.Handler),
configs: configs, configs: configs,
healthCheckers: make(map[string]*healthcheck.ServiceHealthChecker), healthCheckers: make(map[string]*healthcheck.ServiceHealthChecker),
@ -298,9 +299,9 @@ func (m *Manager) getLoadBalancerServiceHandler(ctx context.Context, serviceName
logger.Debug().Msg("Creating load-balancer") logger.Debug().Msg("Creating load-balancer")
// TODO: should we keep this config value as Go is now handling stream response correctly? // TODO: should we keep this config value as Go is now handling stream response correctly?
flushInterval := dynamic.DefaultFlushInterval flushInterval := time.Duration(dynamic.DefaultFlushInterval)
if service.ResponseForwarding != nil { if service.ResponseForwarding != nil {
flushInterval = service.ResponseForwarding.FlushInterval flushInterval = time.Duration(service.ResponseForwarding.FlushInterval)
} }
if len(service.ServersTransport) > 0 { if len(service.ServersTransport) > 0 {
@ -317,11 +318,6 @@ func (m *Manager) getLoadBalancerServiceHandler(ctx context.Context, serviceName
passHostHeader = *service.PassHostHeader passHostHeader = *service.PassHostHeader
} }
roundTripper, err := m.roundTripperManager.Get(service.ServersTransport)
if err != nil {
return nil, err
}
lb := wrr.New(service.Sticky, service.HealthCheck != nil) lb := wrr.New(service.Sticky, service.HealthCheck != nil)
healthCheckTargets := make(map[string]*url.URL) healthCheckTargets := make(map[string]*url.URL)
@ -341,14 +337,12 @@ func (m *Manager) getLoadBalancerServiceHandler(ctx context.Context, serviceName
qualifiedSvcName := provider.GetQualifiedName(ctx, serviceName) qualifiedSvcName := provider.GetQualifiedName(ctx, serviceName)
if m.observabilityMgr.ShouldAddTracing(qualifiedSvcName) || m.observabilityMgr.ShouldAddMetrics(qualifiedSvcName) { shouldObserve := m.observabilityMgr.ShouldAddTracing(qualifiedSvcName) || m.observabilityMgr.ShouldAddMetrics(qualifiedSvcName)
// Wrapping the roundTripper with the Tracing roundTripper, proxy, err := m.proxyBuilder.Build(service.ServersTransport, target, shouldObserve, passHostHeader, flushInterval)
// to handle the reverseProxy client span creation. if err != nil {
roundTripper = newObservabilityRoundTripper(m.observabilityMgr.SemConvMetricsRegistry(), roundTripper) return nil, fmt.Errorf("error building proxy for server URL %s: %w", server.URL, err)
} }
proxy := buildSingleHostProxy(target, passHostHeader, time.Duration(flushInterval), roundTripper, m.bufferPool)
// Prevents from enabling observability for internal resources. // Prevents from enabling observability for internal resources.
if m.observabilityMgr.ShouldAddAccessLogs(qualifiedSvcName) { if m.observabilityMgr.ShouldAddAccessLogs(qualifiedSvcName) {
@ -393,6 +387,11 @@ func (m *Manager) getLoadBalancerServiceHandler(ctx context.Context, serviceName
} }
if service.HealthCheck != nil { if service.HealthCheck != nil {
roundTripper, err := m.transportManager.GetRoundTripper(service.ServersTransport)
if err != nil {
return nil, fmt.Errorf("getting RoundTripper: %w", err)
}
m.healthCheckers[serviceName] = healthcheck.NewServiceHealthChecker( m.healthCheckers[serviceName] = healthcheck.NewServiceHealthChecker(
ctx, ctx,
m.observabilityMgr.MetricsRegistry(), m.observabilityMgr.MetricsRegistry(),

View file

@ -2,6 +2,7 @@ package service
import ( import (
"context" "context"
"crypto/tls"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -14,13 +15,14 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/traefik/traefik/v3/pkg/config/dynamic" "github.com/traefik/traefik/v3/pkg/config/dynamic"
"github.com/traefik/traefik/v3/pkg/config/runtime" "github.com/traefik/traefik/v3/pkg/config/runtime"
"github.com/traefik/traefik/v3/pkg/proxy/httputil"
"github.com/traefik/traefik/v3/pkg/server/provider" "github.com/traefik/traefik/v3/pkg/server/provider"
"github.com/traefik/traefik/v3/pkg/testhelpers" "github.com/traefik/traefik/v3/pkg/testhelpers"
) )
func TestGetLoadBalancer(t *testing.T) { func TestGetLoadBalancer(t *testing.T) {
sm := Manager{ sm := Manager{
roundTripperManager: newRtMock(), transportManager: &transportManagerMock{},
} }
testCases := []struct { testCases := []struct {
@ -40,14 +42,14 @@ func TestGetLoadBalancer(t *testing.T) {
}, },
}, },
}, },
fwd: &MockForwarder{}, fwd: &forwarderMock{},
expectError: true, expectError: true,
}, },
{ {
desc: "Succeeds when there are no servers", desc: "Succeeds when there are no servers",
serviceName: "test", serviceName: "test",
service: &dynamic.ServersLoadBalancer{}, service: &dynamic.ServersLoadBalancer{},
fwd: &MockForwarder{}, fwd: &forwarderMock{},
expectError: false, expectError: false,
}, },
{ {
@ -56,7 +58,7 @@ func TestGetLoadBalancer(t *testing.T) {
service: &dynamic.ServersLoadBalancer{ service: &dynamic.ServersLoadBalancer{
Sticky: &dynamic.Sticky{Cookie: &dynamic.Cookie{}}, Sticky: &dynamic.Sticky{Cookie: &dynamic.Cookie{}},
}, },
fwd: &MockForwarder{}, fwd: &forwarderMock{},
expectError: false, expectError: false,
}, },
} }
@ -79,11 +81,8 @@ func TestGetLoadBalancer(t *testing.T) {
} }
func TestGetLoadBalancerServiceHandler(t *testing.T) { func TestGetLoadBalancerServiceHandler(t *testing.T) {
sm := NewManager(nil, nil, nil, &RoundTripperManager{ pb := httputil.NewProxyBuilder(&transportManagerMock{}, nil)
roundTrippers: map[string]http.RoundTripper{ sm := NewManager(nil, nil, nil, transportManagerMock{}, pb)
"default@internal": http.DefaultTransport,
},
})
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-From", "first") w.Header().Set("X-From", "first")
@ -139,7 +138,7 @@ func TestGetLoadBalancerServiceHandler(t *testing.T) {
desc: "Load balances between the two servers", desc: "Load balances between the two servers",
serviceName: "test", serviceName: "test",
service: &dynamic.ServersLoadBalancer{ service: &dynamic.ServersLoadBalancer{
PassHostHeader: Bool(true), PassHostHeader: boolPtr(true),
Servers: []dynamic.Server{ Servers: []dynamic.Server{
{ {
URL: server1.URL, URL: server1.URL,
@ -254,7 +253,7 @@ func TestGetLoadBalancerServiceHandler(t *testing.T) {
desc: "PassHost doesn't pass the host instead of the IP", desc: "PassHost doesn't pass the host instead of the IP",
serviceName: "test", serviceName: "test",
service: &dynamic.ServersLoadBalancer{ service: &dynamic.ServersLoadBalancer{
PassHostHeader: Bool(false), PassHostHeader: boolPtr(false),
Sticky: &dynamic.Sticky{Cookie: &dynamic.Cookie{}}, Sticky: &dynamic.Sticky{Cookie: &dynamic.Cookie{}},
Servers: []dynamic.Server{ Servers: []dynamic.Server{
{ {
@ -359,11 +358,8 @@ func TestGetLoadBalancerServiceHandler(t *testing.T) {
// This test is an adapted version of net/http/httputil.Test1xxResponses test. // This test is an adapted version of net/http/httputil.Test1xxResponses test.
func Test1xxResponses(t *testing.T) { func Test1xxResponses(t *testing.T) {
sm := NewManager(nil, nil, nil, &RoundTripperManager{ pb := httputil.NewProxyBuilder(&transportManagerMock{}, nil)
roundTrippers: map[string]http.RoundTripper{ sm := NewManager(nil, nil, nil, &transportManagerMock{}, pb)
"default@internal": http.DefaultTransport,
},
})
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := w.Header() h := w.Header()
@ -499,11 +495,7 @@ func TestManager_Build(t *testing.T) {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
t.Parallel() t.Parallel()
manager := NewManager(test.configs, nil, nil, &RoundTripperManager{ manager := NewManager(test.configs, nil, nil, &transportManagerMock{}, nil)
roundTrippers: map[string]http.RoundTripper{
"default@internal": http.DefaultTransport,
},
})
ctx := context.Background() ctx := context.Background()
if len(test.providerName) > 0 { if len(test.providerName) > 0 {
@ -526,30 +518,30 @@ func TestMultipleTypeOnBuildHTTP(t *testing.T) {
}, },
} }
manager := NewManager(services, nil, nil, &RoundTripperManager{ manager := NewManager(services, nil, nil, &transportManagerMock{}, nil)
roundTrippers: map[string]http.RoundTripper{
"default@internal": http.DefaultTransport,
},
})
_, err := manager.BuildHTTP(context.Background(), "test@file") _, err := manager.BuildHTTP(context.Background(), "test@file")
assert.Error(t, err, "cannot create service: multi-types service not supported, consider declaring two different pieces of service instead") assert.Error(t, err, "cannot create service: multi-types service not supported, consider declaring two different pieces of service instead")
} }
func Bool(v bool) *bool { return &v } func boolPtr(v bool) *bool { return &v }
type MockForwarder struct{} type forwarderMock struct{}
func (MockForwarder) ServeHTTP(http.ResponseWriter, *http.Request) { func (forwarderMock) ServeHTTP(http.ResponseWriter, *http.Request) {
panic("not available") panic("not available")
} }
type rtMock struct{} type transportManagerMock struct{}
func newRtMock() RoundTripperGetter { func (t transportManagerMock) GetRoundTripper(_ string) (http.RoundTripper, error) {
return &rtMock{} return &http.Transport{}, nil
} }
func (r *rtMock) Get(_ string) (http.RoundTripper, error) { func (t transportManagerMock) GetTLSConfig(_ string) (*tls.Config, error) {
return http.DefaultTransport, nil return nil, nil
}
func (t transportManagerMock) Get(_ string) (*dynamic.ServersTransport, error) {
return &dynamic.ServersTransport{}, nil
} }

View file

@ -11,6 +11,15 @@ import (
"golang.org/x/net/http2" "golang.org/x/net/http2"
) )
type h2cTransportWrapper struct {
*http2.Transport
}
func (t *h2cTransportWrapper) RoundTrip(req *http.Request) (*http.Response, error) {
req.URL.Scheme = "http"
return t.Transport.RoundTrip(req)
}
func newSmartRoundTripper(transport *http.Transport, forwardingTimeouts *dynamic.ForwardingTimeouts) (*smartRoundTripper, error) { func newSmartRoundTripper(transport *http.Transport, forwardingTimeouts *dynamic.ForwardingTimeouts) (*smartRoundTripper, error) {
transportHTTP1 := transport.Clone() transportHTTP1 := transport.Clone()

View file

@ -22,52 +22,45 @@ import (
"github.com/traefik/traefik/v3/pkg/config/dynamic" "github.com/traefik/traefik/v3/pkg/config/dynamic"
traefiktls "github.com/traefik/traefik/v3/pkg/tls" traefiktls "github.com/traefik/traefik/v3/pkg/tls"
"github.com/traefik/traefik/v3/pkg/types" "github.com/traefik/traefik/v3/pkg/types"
"golang.org/x/net/http2"
) )
type h2cTransportWrapper struct {
*http2.Transport
}
func (t *h2cTransportWrapper) RoundTrip(req *http.Request) (*http.Response, error) {
req.URL.Scheme = "http"
return t.Transport.RoundTrip(req)
}
// SpiffeX509Source allows to retrieve a x509 SVID and bundle. // SpiffeX509Source allows to retrieve a x509 SVID and bundle.
type SpiffeX509Source interface { type SpiffeX509Source interface {
x509svid.Source x509svid.Source
x509bundle.Source x509bundle.Source
} }
// NewRoundTripperManager creates a new RoundTripperManager. // TransportManager handles transports for backend communication.
func NewRoundTripperManager(spiffeX509Source SpiffeX509Source) *RoundTripperManager { type TransportManager struct {
return &RoundTripperManager{
roundTrippers: make(map[string]http.RoundTripper),
configs: make(map[string]*dynamic.ServersTransport),
spiffeX509Source: spiffeX509Source,
}
}
// RoundTripperManager handles roundtripper for the reverse proxy.
type RoundTripperManager struct {
rtLock sync.RWMutex rtLock sync.RWMutex
roundTrippers map[string]http.RoundTripper roundTrippers map[string]http.RoundTripper
configs map[string]*dynamic.ServersTransport configs map[string]*dynamic.ServersTransport
tlsConfigs map[string]*tls.Config
spiffeX509Source SpiffeX509Source spiffeX509Source SpiffeX509Source
} }
// Update updates the roundtrippers configurations. // NewTransportManager creates a new TransportManager.
func (r *RoundTripperManager) Update(newConfigs map[string]*dynamic.ServersTransport) { func NewTransportManager(spiffeX509Source SpiffeX509Source) *TransportManager {
r.rtLock.Lock() return &TransportManager{
defer r.rtLock.Unlock() roundTrippers: make(map[string]http.RoundTripper),
configs: make(map[string]*dynamic.ServersTransport),
tlsConfigs: make(map[string]*tls.Config),
spiffeX509Source: spiffeX509Source,
}
}
for configName, config := range r.configs { // Update updates the transport configurations.
func (t *TransportManager) Update(newConfigs map[string]*dynamic.ServersTransport) {
t.rtLock.Lock()
defer t.rtLock.Unlock()
for configName, config := range t.configs {
newConfig, ok := newConfigs[configName] newConfig, ok := newConfigs[configName]
if !ok { if !ok {
delete(r.configs, configName) delete(t.configs, configName)
delete(r.roundTrippers, configName) delete(t.roundTrippers, configName)
delete(t.tlsConfigs, configName)
continue continue
} }
@ -76,50 +69,133 @@ func (r *RoundTripperManager) Update(newConfigs map[string]*dynamic.ServersTrans
} }
var err error var err error
r.roundTrippers[configName], err = r.createRoundTripper(newConfig)
var tlsConfig *tls.Config
if tlsConfig, err = t.createTLSConfig(newConfig); err != nil {
log.Error().Err(err).Msgf("Could not configure HTTP Transport %s TLS configuration, fallback on default TLS config", configName)
}
t.tlsConfigs[configName] = tlsConfig
t.roundTrippers[configName], err = t.createRoundTripper(newConfig, tlsConfig)
if err != nil { if err != nil {
log.Error().Err(err).Msgf("Could not configure HTTP Transport %s, fallback on default transport", configName) log.Error().Err(err).Msgf("Could not configure HTTP Transport %s, fallback on default transport", configName)
r.roundTrippers[configName] = http.DefaultTransport t.roundTrippers[configName] = http.DefaultTransport
} }
} }
for newConfigName, newConfig := range newConfigs { for newConfigName, newConfig := range newConfigs {
if _, ok := r.configs[newConfigName]; ok { if _, ok := t.configs[newConfigName]; ok {
continue continue
} }
var err error var err error
r.roundTrippers[newConfigName], err = r.createRoundTripper(newConfig)
var tlsConfig *tls.Config
if tlsConfig, err = t.createTLSConfig(newConfig); err != nil {
log.Error().Err(err).Msgf("Could not configure HTTP Transport %s TLS configuration, fallback on default TLS config", newConfigName)
}
t.tlsConfigs[newConfigName] = tlsConfig
t.roundTrippers[newConfigName], err = t.createRoundTripper(newConfig, tlsConfig)
if err != nil { if err != nil {
log.Error().Err(err).Msgf("Could not configure HTTP Transport %s, fallback on default transport", newConfigName) log.Error().Err(err).Msgf("Could not configure HTTP Transport %s, fallback on default transport", newConfigName)
r.roundTrippers[newConfigName] = http.DefaultTransport t.roundTrippers[newConfigName] = http.DefaultTransport
} }
} }
r.configs = newConfigs t.configs = newConfigs
} }
// Get gets a roundtripper by name. // GetRoundTripper gets a roundtripper corresponding to the given transport name.
func (r *RoundTripperManager) Get(name string) (http.RoundTripper, error) { func (t *TransportManager) GetRoundTripper(name string) (http.RoundTripper, error) {
if len(name) == 0 { if len(name) == 0 {
name = "default@internal" name = "default@internal"
} }
r.rtLock.RLock() t.rtLock.RLock()
defer r.rtLock.RUnlock() defer t.rtLock.RUnlock()
if rt, ok := r.roundTrippers[name]; ok { if rt, ok := t.roundTrippers[name]; ok {
return rt, nil return rt, nil
} }
return nil, fmt.Errorf("servers transport not found %s", name) return nil, fmt.Errorf("servers transport not found %s", name)
} }
// Get gets transport by name.
func (t *TransportManager) Get(name string) (*dynamic.ServersTransport, error) {
if len(name) == 0 {
name = "default@internal"
}
t.rtLock.RLock()
defer t.rtLock.RUnlock()
if rt, ok := t.configs[name]; ok {
return rt, nil
}
return nil, fmt.Errorf("servers transport not found %s", name)
}
// GetTLSConfig gets a TLS config corresponding to the given transport name.
func (t *TransportManager) GetTLSConfig(name string) (*tls.Config, error) {
if len(name) == 0 {
name = "default@internal"
}
t.rtLock.RLock()
defer t.rtLock.RUnlock()
if rt, ok := t.tlsConfigs[name]; ok {
return rt, nil
}
return nil, fmt.Errorf("tls config not found %s", name)
}
func (t *TransportManager) createTLSConfig(cfg *dynamic.ServersTransport) (*tls.Config, error) {
var config *tls.Config
if cfg.Spiffe != nil {
if t.spiffeX509Source == nil {
return nil, errors.New("SPIFFE is enabled for this transport, but not configured")
}
spiffeAuthorizer, err := buildSpiffeAuthorizer(cfg.Spiffe)
if err != nil {
return nil, fmt.Errorf("unable to build SPIFFE authorizer: %w", err)
}
config = tlsconfig.MTLSClientConfig(t.spiffeX509Source, t.spiffeX509Source, spiffeAuthorizer)
}
if cfg.InsecureSkipVerify || len(cfg.RootCAs) > 0 || len(cfg.ServerName) > 0 || len(cfg.Certificates) > 0 || cfg.PeerCertURI != "" {
if config != nil {
return nil, errors.New("TLS and SPIFFE configuration cannot be defined at the same time")
}
config = &tls.Config{
ServerName: cfg.ServerName,
InsecureSkipVerify: cfg.InsecureSkipVerify,
RootCAs: createRootCACertPool(cfg.RootCAs),
Certificates: cfg.Certificates.GetCertificates(),
}
if cfg.PeerCertURI != "" {
config.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
return traefiktls.VerifyPeerCertificate(cfg.PeerCertURI, config, rawCerts)
}
}
}
return config, nil
}
// createRoundTripper creates an http.RoundTripper configured with the Transport configuration settings. // createRoundTripper creates an http.RoundTripper configured with the Transport configuration settings.
// For the settings that can't be configured in Traefik it uses the default http.Transport settings. // For the settings that can't be configured in Traefik it uses the default http.Transport settings.
// An exception to this is the MaxIdleConns setting as we only provide the option MaxIdleConnsPerHost in Traefik at this point in time. // An exception to this is the MaxIdleConns setting as we only provide the option MaxIdleConnsPerHost in Traefik at this point in time.
// Setting this value to the default of 100 could lead to confusing behavior and backwards compatibility issues. // Setting this value to the default of 100 could lead to confusing behavior and backwards compatibility issues.
func (r *RoundTripperManager) createRoundTripper(cfg *dynamic.ServersTransport) (http.RoundTripper, error) { func (t *TransportManager) createRoundTripper(cfg *dynamic.ServersTransport, tlsConfig *tls.Config) (http.RoundTripper, error) {
if cfg == nil { if cfg == nil {
return nil, errors.New("no transport configuration given") return nil, errors.New("no transport configuration given")
} }
@ -142,6 +218,7 @@ func (r *RoundTripperManager) createRoundTripper(cfg *dynamic.ServersTransport)
ExpectContinueTimeout: 1 * time.Second, ExpectContinueTimeout: 1 * time.Second,
ReadBufferSize: 64 * 1024, ReadBufferSize: 64 * 1024,
WriteBufferSize: 64 * 1024, WriteBufferSize: 64 * 1024,
TLSClientConfig: tlsConfig,
} }
if cfg.ForwardingTimeouts != nil { if cfg.ForwardingTimeouts != nil {
@ -149,41 +226,9 @@ func (r *RoundTripperManager) createRoundTripper(cfg *dynamic.ServersTransport)
transport.IdleConnTimeout = time.Duration(cfg.ForwardingTimeouts.IdleConnTimeout) transport.IdleConnTimeout = time.Duration(cfg.ForwardingTimeouts.IdleConnTimeout)
} }
if cfg.Spiffe != nil {
if r.spiffeX509Source == nil {
return nil, errors.New("SPIFFE is enabled for this transport, but not configured")
}
spiffeAuthorizer, err := buildSpiffeAuthorizer(cfg.Spiffe)
if err != nil {
return nil, fmt.Errorf("unable to build SPIFFE authorizer: %w", err)
}
transport.TLSClientConfig = tlsconfig.MTLSClientConfig(r.spiffeX509Source, r.spiffeX509Source, spiffeAuthorizer)
}
if cfg.InsecureSkipVerify || len(cfg.RootCAs) > 0 || len(cfg.ServerName) > 0 || len(cfg.Certificates) > 0 || cfg.PeerCertURI != "" {
if transport.TLSClientConfig != nil {
return nil, errors.New("TLS and SPIFFE configuration cannot be defined at the same time")
}
transport.TLSClientConfig = &tls.Config{
ServerName: cfg.ServerName,
InsecureSkipVerify: cfg.InsecureSkipVerify,
RootCAs: createRootCACertPool(cfg.RootCAs),
Certificates: cfg.Certificates.GetCertificates(),
}
if cfg.PeerCertURI != "" {
transport.TLSClientConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
return traefiktls.VerifyPeerCertificate(cfg.PeerCertURI, transport.TLSClientConfig, rawCerts)
}
}
}
// Return directly HTTP/1.1 transport when HTTP/2 is disabled // Return directly HTTP/1.1 transport when HTTP/2 is disabled
if cfg.DisableHTTP2 { if cfg.DisableHTTP2 {
return &KerberosRoundTripper{ return &kerberosRoundTripper{
OriginalRoundTripper: transport, OriginalRoundTripper: transport,
new: func() http.RoundTripper { new: func() http.RoundTripper {
return transport.Clone() return transport.Clone()
@ -195,7 +240,7 @@ func (r *RoundTripperManager) createRoundTripper(cfg *dynamic.ServersTransport)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &KerberosRoundTripper{ return &kerberosRoundTripper{
OriginalRoundTripper: rt, OriginalRoundTripper: rt,
new: func() http.RoundTripper { new: func() http.RoundTripper {
return rt.Clone() return rt.Clone()
@ -203,11 +248,6 @@ func (r *RoundTripperManager) createRoundTripper(cfg *dynamic.ServersTransport)
}, nil }, nil
} }
type KerberosRoundTripper struct {
new func() http.RoundTripper
OriginalRoundTripper http.RoundTripper
}
type stickyRoundTripper struct { type stickyRoundTripper struct {
RoundTripper http.RoundTripper RoundTripper http.RoundTripper
} }
@ -220,7 +260,12 @@ func AddTransportOnContext(ctx context.Context) context.Context {
return context.WithValue(ctx, transportKey, &stickyRoundTripper{}) return context.WithValue(ctx, transportKey, &stickyRoundTripper{})
} }
func (k *KerberosRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) { type kerberosRoundTripper struct {
new func() http.RoundTripper
OriginalRoundTripper http.RoundTripper
}
func (k *kerberosRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
value, ok := request.Context().Value(transportKey).(*stickyRoundTripper) value, ok := request.Context().Value(transportKey).(*stickyRoundTripper)
if !ok { if !ok {
return k.OriginalRoundTripper.RoundTrip(request) return k.OriginalRoundTripper.RoundTrip(request)

View file

@ -141,7 +141,7 @@ func TestKeepConnectionWhenSameConfiguration(t *testing.T) {
srv.TLS = &tls.Config{Certificates: []tls.Certificate{cert}} srv.TLS = &tls.Config{Certificates: []tls.Certificate{cert}}
srv.StartTLS() srv.StartTLS()
rtManager := NewRoundTripperManager(nil) transportManager := NewTransportManager(nil)
dynamicConf := map[string]*dynamic.ServersTransport{ dynamicConf := map[string]*dynamic.ServersTransport{
"test": { "test": {
@ -151,9 +151,9 @@ func TestKeepConnectionWhenSameConfiguration(t *testing.T) {
} }
for range 10 { for range 10 {
rtManager.Update(dynamicConf) transportManager.Update(dynamicConf)
tr, err := rtManager.Get("test") tr, err := transportManager.GetRoundTripper("test")
require.NoError(t, err) require.NoError(t, err)
client := http.Client{Transport: tr} client := http.Client{Transport: tr}
@ -173,9 +173,9 @@ func TestKeepConnectionWhenSameConfiguration(t *testing.T) {
}, },
} }
rtManager.Update(dynamicConf) transportManager.Update(dynamicConf)
tr, err := rtManager.Get("test") tr, err := transportManager.GetRoundTripper("test")
require.NoError(t, err) require.NoError(t, err)
client := http.Client{Transport: tr} client := http.Client{Transport: tr}
@ -209,7 +209,7 @@ func TestMTLS(t *testing.T) {
} }
srv.StartTLS() srv.StartTLS()
rtManager := NewRoundTripperManager(nil) transportManager := NewTransportManager(nil)
dynamicConf := map[string]*dynamic.ServersTransport{ dynamicConf := map[string]*dynamic.ServersTransport{
"test": { "test": {
@ -227,9 +227,9 @@ func TestMTLS(t *testing.T) {
}, },
} }
rtManager.Update(dynamicConf) transportManager.Update(dynamicConf)
tr, err := rtManager.Get("test") tr, err := transportManager.GetRoundTripper("test")
require.NoError(t, err) require.NoError(t, err)
client := http.Client{Transport: tr} client := http.Client{Transport: tr}
@ -348,7 +348,7 @@ func TestSpiffeMTLS(t *testing.T) {
for _, test := range testCases { for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
rtManager := NewRoundTripperManager(test.clientSource) transportManager := NewTransportManager(test.clientSource)
dynamicConf := map[string]*dynamic.ServersTransport{ dynamicConf := map[string]*dynamic.ServersTransport{
"test": { "test": {
@ -356,9 +356,9 @@ func TestSpiffeMTLS(t *testing.T) {
}, },
} }
rtManager.Update(dynamicConf) transportManager.Update(dynamicConf)
tr, err := rtManager.Get("test") tr, err := transportManager.GetRoundTripper("test")
require.NoError(t, err) require.NoError(t, err)
client := http.Client{Transport: tr} client := http.Client{Transport: tr}
@ -415,7 +415,7 @@ func TestDisableHTTP2(t *testing.T) {
srv.EnableHTTP2 = test.serverHTTP2 srv.EnableHTTP2 = test.serverHTTP2
srv.StartTLS() srv.StartTLS()
rtManager := NewRoundTripperManager(nil) transportManager := NewTransportManager(nil)
dynamicConf := map[string]*dynamic.ServersTransport{ dynamicConf := map[string]*dynamic.ServersTransport{
"test": { "test": {
@ -424,9 +424,9 @@ func TestDisableHTTP2(t *testing.T) {
}, },
} }
rtManager.Update(dynamicConf) transportManager.Update(dynamicConf)
tr, err := rtManager.Get("test") tr, err := transportManager.GetRoundTripper("test")
require.NoError(t, err) require.NoError(t, err)
client := http.Client{Transport: tr} client := http.Client{Transport: tr}
@ -593,7 +593,7 @@ func TestKerberosRoundTripper(t *testing.T) {
origCount := 0 origCount := 0
dedicatedCount := 0 dedicatedCount := 0
rt := KerberosRoundTripper{ rt := kerberosRoundTripper{
new: func() http.RoundTripper { new: func() http.RoundTripper {
return roundTripperFn(func(req *http.Request) (*http.Response, error) { return roundTripperFn(func(req *http.Request) (*http.Response, error) {
dedicatedCount++ dedicatedCount++