Introduce a fast proxy mode to improve HTTP/1.1 performances with backends
Co-authored-by: Romain <rtribotte@users.noreply.github.com> Co-authored-by: Julien Salleyron <julien.salleyron@gmail.com>
This commit is contained in:
parent
a6db1cac37
commit
f8a78b3b25
39 changed files with 3173 additions and 378 deletions
|
@ -229,7 +229,7 @@ issues:
|
|||
text: 'struct-tag: unknown option ''inline'' in JSON tag'
|
||||
linters:
|
||||
- revive
|
||||
- path: pkg/server/service/bufferpool.go
|
||||
- path: pkg/proxy/httputil/bufferpool.go
|
||||
text: 'SA6002: argument should be pointer-like to avoid allocations'
|
||||
- path: pkg/server/middleware/middlewares.go
|
||||
text: "Function 'buildConstructor' has too many statements"
|
||||
|
|
|
@ -37,6 +37,8 @@ import (
|
|||
"github.com/traefik/traefik/v3/pkg/provider/aggregator"
|
||||
"github.com/traefik/traefik/v3/pkg/provider/tailscale"
|
||||
"github.com/traefik/traefik/v3/pkg/provider/traefik"
|
||||
"github.com/traefik/traefik/v3/pkg/proxy"
|
||||
"github.com/traefik/traefik/v3/pkg/proxy/httputil"
|
||||
"github.com/traefik/traefik/v3/pkg/safe"
|
||||
"github.com/traefik/traefik/v3/pkg/server"
|
||||
"github.com/traefik/traefik/v3/pkg/server/middleware"
|
||||
|
@ -281,10 +283,16 @@ func setupServer(staticConfiguration *static.Configuration) (*server.Server, err
|
|||
log.Info().Msg("Successfully obtained SPIFFE SVID.")
|
||||
}
|
||||
|
||||
roundTripperManager := service.NewRoundTripperManager(spiffeX509Source)
|
||||
transportManager := service.NewTransportManager(spiffeX509Source)
|
||||
|
||||
var proxyBuilder service.ProxyBuilder = httputil.NewProxyBuilder(transportManager, semConvMetricRegistry)
|
||||
if staticConfiguration.Experimental != nil && staticConfiguration.Experimental.FastProxy != nil {
|
||||
proxyBuilder = proxy.NewSmartBuilder(transportManager, proxyBuilder, *staticConfiguration.Experimental.FastProxy)
|
||||
}
|
||||
|
||||
dialerManager := tcp.NewDialerManager(spiffeX509Source)
|
||||
acmeHTTPHandler := getHTTPChallengeHandler(acmeProviders, httpChallengeProvider)
|
||||
managerFactory := service.NewManagerFactory(*staticConfiguration, routinesPool, observabilityMgr, roundTripperManager, acmeHTTPHandler)
|
||||
managerFactory := service.NewManagerFactory(*staticConfiguration, routinesPool, observabilityMgr, transportManager, proxyBuilder, acmeHTTPHandler)
|
||||
|
||||
// Router factory
|
||||
|
||||
|
@ -318,7 +326,8 @@ func setupServer(staticConfiguration *static.Configuration) (*server.Server, err
|
|||
|
||||
// Server Transports
|
||||
watcher.AddListener(func(conf dynamic.Configuration) {
|
||||
roundTripperManager.Update(conf.HTTP.ServersTransports)
|
||||
transportManager.Update(conf.HTTP.ServersTransports)
|
||||
proxyBuilder.Update(conf.HTTP.ServersTransports)
|
||||
dialerManager.Update(conf.TCP.ServersTransports)
|
||||
})
|
||||
|
||||
|
|
|
@ -228,6 +228,12 @@ WriteTimeout is the maximum duration before timing out writes of the response. I
|
|||
`--entrypoints.<name>.udp.timeout`:
|
||||
Timeout defines how long to wait on an idle session before releasing the related resources. (Default: ```3```)
|
||||
|
||||
`--experimental.fastproxy`:
|
||||
Enable the FastProxy implementation. (Default: ```false```)
|
||||
|
||||
`--experimental.fastproxy.debug`:
|
||||
Enable debug mode for the FastProxy implementation. (Default: ```false```)
|
||||
|
||||
`--experimental.kubernetesgateway`:
|
||||
(Deprecated) Allow the Kubernetes gateway api provider usage. (Default: ```false```)
|
||||
|
||||
|
|
|
@ -228,6 +228,12 @@ WriteTimeout is the maximum duration before timing out writes of the response. I
|
|||
`TRAEFIK_ENTRYPOINTS_<NAME>_UDP_TIMEOUT`:
|
||||
Timeout defines how long to wait on an idle session before releasing the related resources. (Default: ```3```)
|
||||
|
||||
`TRAEFIK_EXPERIMENTAL_FASTPROXY`:
|
||||
Enable the FastProxy implementation. (Default: ```false```)
|
||||
|
||||
`TRAEFIK_EXPERIMENTAL_FASTPROXY_DEBUG`:
|
||||
Enable debug mode for the FastProxy implementation. (Default: ```false```)
|
||||
|
||||
`TRAEFIK_EXPERIMENTAL_KUBERNETESGATEWAY`:
|
||||
(Deprecated) Allow the Kubernetes gateway api provider usage. (Default: ```false```)
|
||||
|
||||
|
|
|
@ -509,6 +509,8 @@
|
|||
[experimental.localPlugins.LocalDescriptor1.settings]
|
||||
envs = ["foobar", "foobar"]
|
||||
mounts = ["foobar", "foobar"]
|
||||
[experimental.fastProxy]
|
||||
debug = true
|
||||
|
||||
[core]
|
||||
defaultRuleSyntax = "foobar"
|
||||
|
|
|
@ -572,6 +572,8 @@ experimental:
|
|||
mounts:
|
||||
- foobar
|
||||
- foobar
|
||||
fastProxy:
|
||||
debug: true
|
||||
kubernetesGateway: true
|
||||
core:
|
||||
defaultRuleSyntax: foobar
|
||||
|
|
41
docs/content/user-guides/fastproxy.md
Normal file
41
docs/content/user-guides/fastproxy.md
Normal file
|
@ -0,0 +1,41 @@
|
|||
---
|
||||
title: "Traefik FastProxy Experimental Configuration"
|
||||
description: "This section of the Traefik Proxy documentation explains how to use the new FastProxy option."
|
||||
---
|
||||
|
||||
# Traefik FastProxy Experimental Configuration
|
||||
|
||||
## Overview
|
||||
|
||||
This guide provides instructions on how to configure and use the new experimental `fastProxy` static configuration option in Traefik.
|
||||
The `fastProxy` option introduces a high-performance reverse proxy designed to enhance the performance of routing.
|
||||
|
||||
!!! info "Limitations"
|
||||
|
||||
Please note that the new fast proxy implementation does not work with HTTP/2.
|
||||
This means that when a H2C or HTTPS request with [HTTP2 enabled](../routing/services/index.md#disablehttp2) is sent to a backend, the fallback proxy is the regular one.
|
||||
|
||||
Additionnaly, observability features like tracing and OTEL semconv metrics are not supported for the moment.
|
||||
|
||||
!!! warning "Experimental"
|
||||
|
||||
The `fastProxy` option is currently experimental and subject to change in future releases.
|
||||
Use with caution in production environments.
|
||||
|
||||
### Enabling FastProxy
|
||||
|
||||
The fastProxy option is a static configuration parameter.
|
||||
To enable it, you need to configure it in your Traefik static configuration
|
||||
|
||||
```yaml tab="File (YAML)"
|
||||
experimental:
|
||||
fastProxy: {}
|
||||
```
|
||||
|
||||
```toml tab="File (TOML)"
|
||||
[experimental.fastProxy]
|
||||
```
|
||||
|
||||
```bash tab="CLI"
|
||||
--experimental.fastProxy
|
||||
```
|
|
@ -163,6 +163,7 @@ nav:
|
|||
- 'Overview': 'observability/tracing/overview.md'
|
||||
- 'OpenTelemetry': 'observability/tracing/opentelemetry.md'
|
||||
- 'User Guides':
|
||||
- 'FastProxy': 'user-guides/fastproxy.md'
|
||||
- 'Kubernetes and Let''s Encrypt': 'user-guides/crd-acme/index.md'
|
||||
- 'gRPC Examples': 'user-guides/grpc.md'
|
||||
- 'Docker':
|
||||
|
|
8
go.mod
8
go.mod
|
@ -6,7 +6,7 @@ require (
|
|||
github.com/BurntSushi/toml v1.4.0
|
||||
github.com/Masterminds/sprig/v3 v3.2.3
|
||||
github.com/abbot/go-http-auth v0.0.0-00010101000000-000000000000 // No tag on the repo.
|
||||
github.com/andybalholm/brotli v1.0.6
|
||||
github.com/andybalholm/brotli v1.1.0
|
||||
github.com/aws/aws-sdk-go v1.44.327
|
||||
github.com/cenkalti/backoff/v4 v4.3.0
|
||||
github.com/containous/alice v0.0.0-20181107144136-d83ebdd94cbd // No tag on the repo.
|
||||
|
@ -102,6 +102,11 @@ require (
|
|||
sigs.k8s.io/yaml v1.4.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5
|
||||
github.com/valyala/fasthttp v1.55.0
|
||||
)
|
||||
|
||||
require (
|
||||
cloud.google.com/go/compute/metadata v0.3.0 // indirect
|
||||
dario.cat/mergo v1.0.0 // indirect
|
||||
|
@ -315,6 +320,7 @@ require (
|
|||
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||
github.com/transip/gotransip/v6 v6.23.0 // indirect
|
||||
github.com/ultradns/ultradns-go-sdk v1.6.1-20231103022937-8589b6a // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
github.com/vinyldns/go-vinyldns v0.9.16 // indirect
|
||||
github.com/vultr/govultr/v3 v3.9.0 // indirect
|
||||
github.com/x448/float16 v0.8.4 // indirect
|
||||
|
|
7
go.sum
7
go.sum
|
@ -98,8 +98,8 @@ github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRF
|
|||
github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
|
||||
github.com/aliyun/alibaba-cloud-sdk-go v1.62.712 h1:lM7JnA9dEdDFH9XOgRNQMDTQnOjlLkDTNA7c0aWTQ30=
|
||||
github.com/aliyun/alibaba-cloud-sdk-go v1.62.712/go.mod h1:SOSDHfe1kX91v3W5QiBsWSLqeLxImobbMX1mxrFHsVQ=
|
||||
github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI=
|
||||
github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
|
||||
github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M=
|
||||
github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY=
|
||||
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
|
||||
github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o=
|
||||
github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY=
|
||||
|
@ -1041,7 +1041,10 @@ github.com/unrolled/secure v1.0.9 h1:BWRuEb1vDrBFFDdbCnKkof3gZ35I/bnHGyt0LB0TNyQ
|
|||
github.com/unrolled/secure v1.0.9/go.mod h1:fO+mEan+FLB0CdEnHf6Q4ZZVNqG+5fuLFnP8p0BXDPI=
|
||||
github.com/urfave/negroni v1.0.0 h1:kIimOitoypq34K7TG7DUaJ9kq/N4Ofuwi1sjz0KipXc=
|
||||
github.com/urfave/negroni v1.0.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasthttp v1.55.0 h1:Zkefzgt6a7+bVKHnu/YaYSOPfNYNisSVBo/unVCf8k8=
|
||||
github.com/valyala/fasthttp v1.55.0/go.mod h1:NkY9JtkrpPKmgwV3HTaS2HWaJss9RSIsRVfcxxoHiOM=
|
||||
github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8=
|
||||
github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
|
||||
github.com/vinyldns/go-vinyldns v0.9.16 h1:GZJStDkcCk1F1AcRc64LuuMh+ENL8pHA0CVd4ulRMcQ=
|
||||
|
|
35
integration/fixtures/simple_fastproxy.toml
Normal file
35
integration/fixtures/simple_fastproxy.toml
Normal file
|
@ -0,0 +1,35 @@
|
|||
[global]
|
||||
checkNewVersion = false
|
||||
sendAnonymousUsage = false
|
||||
|
||||
[log]
|
||||
level = "DEBUG"
|
||||
noColor = true
|
||||
|
||||
[entryPoints]
|
||||
[entryPoints.web]
|
||||
address = ":8000"
|
||||
|
||||
[api]
|
||||
insecure = true
|
||||
|
||||
[providers.file]
|
||||
filename = "{{ .SelfFilename }}"
|
||||
|
||||
[experimental]
|
||||
[experimental.fastProxy]
|
||||
debug = true
|
||||
|
||||
## dynamic configuration ##
|
||||
|
||||
[http.routers]
|
||||
[http.routers.router1]
|
||||
entrypoints = ["web"]
|
||||
service = "service1"
|
||||
rule = "PathPrefix(`/`)"
|
||||
|
||||
[http.services]
|
||||
[http.services.service1]
|
||||
[http.services.service1.loadBalancer]
|
||||
[[http.services.service1.loadBalancer.servers]]
|
||||
url = "{{ .Server }}"
|
|
@ -65,6 +65,32 @@ func (s *SimpleSuite) TestSimpleDefaultConfig() {
|
|||
require.NoError(s.T(), err)
|
||||
}
|
||||
|
||||
func (s *SimpleSuite) TestSimpleFastProxy() {
|
||||
var callCount int
|
||||
srv1 := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
assert.Contains(s.T(), req.Header, "X-Traefik-Fast-Proxy")
|
||||
callCount++
|
||||
}))
|
||||
defer srv1.Close()
|
||||
|
||||
file := s.adaptFile("fixtures/simple_fastproxy.toml", struct {
|
||||
Server string
|
||||
}{
|
||||
Server: srv1.URL,
|
||||
})
|
||||
|
||||
s.traefikCmd(withConfigFile(file), "--log.level=DEBUG")
|
||||
|
||||
// wait for traefik
|
||||
err := try.GetRequest("http://127.0.0.1:8080/api/rawdata", 10*time.Second, try.BodyContains("127.0.0.1"))
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
err = try.GetRequest("http://127.0.0.1:8000/", time.Second)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
assert.GreaterOrEqual(s.T(), 1, callCount)
|
||||
}
|
||||
|
||||
func (s *SimpleSuite) TestWithWebConfig() {
|
||||
s.cmdTraefik(withConfigFile("fixtures/simple_web.toml"))
|
||||
|
||||
|
|
|
@ -7,6 +7,13 @@ type Experimental struct {
|
|||
Plugins map[string]plugins.Descriptor `description:"Plugins configuration." json:"plugins,omitempty" toml:"plugins,omitempty" yaml:"plugins,omitempty" export:"true"`
|
||||
LocalPlugins map[string]plugins.LocalDescriptor `description:"Local plugins configuration." json:"localPlugins,omitempty" toml:"localPlugins,omitempty" yaml:"localPlugins,omitempty" export:"true"`
|
||||
|
||||
FastProxy *FastProxyConfig `description:"Enable the FastProxy implementation." json:"fastProxy,omitempty" toml:"fastProxy,omitempty" yaml:"fastProxy,omitempty" label:"allowEmpty" file:"allowEmpty" export:"true"`
|
||||
|
||||
// Deprecated: KubernetesGateway provider is not an experimental feature starting with v3.1. Please remove its usage from the static configuration.
|
||||
KubernetesGateway bool `description:"(Deprecated) Allow the Kubernetes gateway api provider usage." json:"kubernetesGateway,omitempty" toml:"kubernetesGateway,omitempty" yaml:"kubernetesGateway,omitempty" export:"true"`
|
||||
}
|
||||
|
||||
// FastProxyConfig holds the FastProxy configuration.
|
||||
type FastProxyConfig struct {
|
||||
Debug bool `description:"Enable debug mode for the FastProxy implementation." json:"debug,omitempty" toml:"debug,omitempty" yaml:"debug,omitempty" export:"true"`
|
||||
}
|
||||
|
|
129
pkg/proxy/fast/builder.go
Normal file
129
pkg/proxy/fast/builder.go
Normal file
|
@ -0,0 +1,129 @@
|
|||
package fast
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/traefik/traefik/v3/pkg/config/dynamic"
|
||||
"github.com/traefik/traefik/v3/pkg/config/static"
|
||||
)
|
||||
|
||||
// TransportManager manages transport used for backend communications.
|
||||
type TransportManager interface {
|
||||
Get(name string) (*dynamic.ServersTransport, error)
|
||||
GetTLSConfig(name string) (*tls.Config, error)
|
||||
}
|
||||
|
||||
// ProxyBuilder handles the connection pools for the FastProxy proxies.
|
||||
type ProxyBuilder struct {
|
||||
debug bool
|
||||
transportManager TransportManager
|
||||
|
||||
// lock isn't needed because ProxyBuilder is not called concurrently.
|
||||
pools map[string]map[string]*connPool
|
||||
proxy func(*http.Request) (*url.URL, error)
|
||||
|
||||
// not goroutine safe.
|
||||
configs map[string]*dynamic.ServersTransport
|
||||
}
|
||||
|
||||
// NewProxyBuilder creates a new ProxyBuilder.
|
||||
func NewProxyBuilder(transportManager TransportManager, config static.FastProxyConfig) *ProxyBuilder {
|
||||
return &ProxyBuilder{
|
||||
debug: config.Debug,
|
||||
transportManager: transportManager,
|
||||
pools: make(map[string]map[string]*connPool),
|
||||
proxy: http.ProxyFromEnvironment,
|
||||
configs: make(map[string]*dynamic.ServersTransport),
|
||||
}
|
||||
}
|
||||
|
||||
// Update updates all the round-tripper corresponding to the given configs.
|
||||
// This method must not be used concurrently.
|
||||
func (r *ProxyBuilder) Update(newConfigs map[string]*dynamic.ServersTransport) {
|
||||
for configName := range r.configs {
|
||||
if _, ok := newConfigs[configName]; !ok {
|
||||
for _, c := range r.pools[configName] {
|
||||
c.Close()
|
||||
}
|
||||
delete(r.pools, configName)
|
||||
}
|
||||
}
|
||||
|
||||
for newConfigName, newConfig := range newConfigs {
|
||||
if !reflect.DeepEqual(newConfig, r.configs[newConfigName]) {
|
||||
for _, c := range r.pools[newConfigName] {
|
||||
c.Close()
|
||||
}
|
||||
delete(r.pools, newConfigName)
|
||||
}
|
||||
}
|
||||
|
||||
r.configs = newConfigs
|
||||
}
|
||||
|
||||
// Build builds a new ReverseProxy with the given configuration.
|
||||
func (r *ProxyBuilder) Build(cfgName string, targetURL *url.URL, passHostHeader bool) (http.Handler, error) {
|
||||
proxyURL, err := r.proxy(&http.Request{URL: targetURL})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting proxy: %w", err)
|
||||
}
|
||||
|
||||
cfg, err := r.transportManager.Get(cfgName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting ServersTransport: %w", err)
|
||||
}
|
||||
|
||||
var responseHeaderTimeout time.Duration
|
||||
if cfg.ForwardingTimeouts != nil {
|
||||
responseHeaderTimeout = time.Duration(cfg.ForwardingTimeouts.ResponseHeaderTimeout)
|
||||
}
|
||||
|
||||
tlsConfig, err := r.transportManager.GetTLSConfig(cfgName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting TLS config: %w", err)
|
||||
}
|
||||
|
||||
pool := r.getPool(cfgName, cfg, tlsConfig, targetURL, proxyURL)
|
||||
return NewReverseProxy(targetURL, proxyURL, r.debug, passHostHeader, responseHeaderTimeout, pool)
|
||||
}
|
||||
|
||||
func (r *ProxyBuilder) getPool(cfgName string, config *dynamic.ServersTransport, tlsConfig *tls.Config, targetURL *url.URL, proxyURL *url.URL) *connPool {
|
||||
pool, ok := r.pools[cfgName]
|
||||
if !ok {
|
||||
pool = make(map[string]*connPool)
|
||||
r.pools[cfgName] = pool
|
||||
}
|
||||
|
||||
if connPool, ok := pool[targetURL.String()]; ok {
|
||||
return connPool
|
||||
}
|
||||
|
||||
idleConnTimeout := 90 * time.Second
|
||||
dialTimeout := 30 * time.Second
|
||||
if config.ForwardingTimeouts != nil {
|
||||
idleConnTimeout = time.Duration(config.ForwardingTimeouts.IdleConnTimeout)
|
||||
dialTimeout = time.Duration(config.ForwardingTimeouts.DialTimeout)
|
||||
}
|
||||
|
||||
proxyDialer := newDialer(dialerConfig{
|
||||
DialKeepAlive: 0,
|
||||
DialTimeout: dialTimeout,
|
||||
HTTP: true,
|
||||
TLS: targetURL.Scheme == "https",
|
||||
ProxyURL: proxyURL,
|
||||
}, tlsConfig)
|
||||
|
||||
connPool := newConnPool(config.MaxIdleConnsPerHost, idleConnTimeout, func() (net.Conn, error) {
|
||||
return proxyDialer.Dial("tcp", addrFromURL(targetURL))
|
||||
})
|
||||
|
||||
r.pools[cfgName][targetURL.String()] = connPool
|
||||
|
||||
return connPool
|
||||
}
|
163
pkg/proxy/fast/connpool.go
Normal file
163
pkg/proxy/fast/connpool.go
Normal file
|
@ -0,0 +1,163 @@
|
|||
package fast
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// conn is an enriched net.Conn.
|
||||
type conn struct {
|
||||
net.Conn
|
||||
|
||||
idleAt time.Time // the last time it was marked as idle.
|
||||
idleTimeout time.Duration
|
||||
}
|
||||
|
||||
func (c *conn) isExpired() bool {
|
||||
expTime := c.idleAt.Add(c.idleTimeout)
|
||||
return c.idleTimeout > 0 && time.Now().After(expTime)
|
||||
}
|
||||
|
||||
// connPool is a net.Conn pool implementation using channels.
|
||||
type connPool struct {
|
||||
dialer func() (net.Conn, error)
|
||||
idleConns chan *conn
|
||||
idleConnTimeout time.Duration
|
||||
ticker *time.Ticker
|
||||
doneCh chan struct{}
|
||||
}
|
||||
|
||||
// newConnPool creates a new connPool.
|
||||
func newConnPool(maxIdleConn int, idleConnTimeout time.Duration, dialer func() (net.Conn, error)) *connPool {
|
||||
c := &connPool{
|
||||
dialer: dialer,
|
||||
idleConns: make(chan *conn, maxIdleConn),
|
||||
idleConnTimeout: idleConnTimeout,
|
||||
doneCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
if idleConnTimeout > 0 {
|
||||
c.ticker = time.NewTicker(c.idleConnTimeout / 2)
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-c.ticker.C:
|
||||
c.cleanIdleConns()
|
||||
case <-c.doneCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// Close closes stop the cleanIdleConn goroutine.
|
||||
func (c *connPool) Close() {
|
||||
if c.idleConnTimeout > 0 {
|
||||
close(c.doneCh)
|
||||
c.ticker.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// AcquireConn returns an idle net.Conn from the pool.
|
||||
func (c *connPool) AcquireConn() (*conn, error) {
|
||||
for {
|
||||
co, err := c.acquireConn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !co.isExpired() {
|
||||
return co, nil
|
||||
}
|
||||
|
||||
// As the acquired conn is expired we can close it
|
||||
// without putting it again into the pool.
|
||||
if err := co.Close(); err != nil {
|
||||
log.Debug().
|
||||
Err(err).
|
||||
Msg("Unexpected error while releasing the connection")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ReleaseConn releases the given net.Conn to the pool.
|
||||
func (c *connPool) ReleaseConn(co *conn) {
|
||||
co.idleAt = time.Now()
|
||||
c.releaseConn(co)
|
||||
}
|
||||
|
||||
// cleanIdleConns is a routine cleaning the expired connections at a regular basis.
|
||||
func (c *connPool) cleanIdleConns() {
|
||||
for {
|
||||
select {
|
||||
case co := <-c.idleConns:
|
||||
if !co.isExpired() {
|
||||
c.releaseConn(co)
|
||||
return
|
||||
}
|
||||
|
||||
if err := co.Close(); err != nil {
|
||||
log.Debug().
|
||||
Err(err).
|
||||
Msg("Unexpected error while releasing the connection")
|
||||
}
|
||||
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connPool) acquireConn() (*conn, error) {
|
||||
select {
|
||||
case co := <-c.idleConns:
|
||||
return co, nil
|
||||
|
||||
default:
|
||||
errCh := make(chan error, 1)
|
||||
go c.askForNewConn(errCh)
|
||||
|
||||
select {
|
||||
case co := <-c.idleConns:
|
||||
return co, nil
|
||||
|
||||
case err := <-errCh:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connPool) releaseConn(co *conn) {
|
||||
select {
|
||||
case c.idleConns <- co:
|
||||
|
||||
// Hitting the default case means that we have reached the maximum number of idle
|
||||
// connections, so we can close it.
|
||||
default:
|
||||
if err := co.Close(); err != nil {
|
||||
log.Debug().
|
||||
Err(err).
|
||||
Msg("Unexpected error while releasing the connection")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connPool) askForNewConn(errCh chan<- error) {
|
||||
co, err := c.dialer()
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("create conn: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
c.releaseConn(&conn{
|
||||
Conn: co,
|
||||
idleAt: time.Now(),
|
||||
idleTimeout: c.idleConnTimeout,
|
||||
})
|
||||
}
|
184
pkg/proxy/fast/connpool_test.go
Normal file
184
pkg/proxy/fast/connpool_test.go
Normal file
|
@ -0,0 +1,184 @@
|
|||
package fast
|
||||
|
||||
import (
|
||||
"net"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConnPool_ConnReuse(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
poolFn func(pool *connPool)
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
desc: "One connection",
|
||||
poolFn: func(pool *connPool) {
|
||||
c1, _ := pool.AcquireConn()
|
||||
pool.ReleaseConn(c1)
|
||||
},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
desc: "Two connections with release",
|
||||
poolFn: func(pool *connPool) {
|
||||
c1, _ := pool.AcquireConn()
|
||||
pool.ReleaseConn(c1)
|
||||
|
||||
c2, _ := pool.AcquireConn()
|
||||
pool.ReleaseConn(c2)
|
||||
},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
desc: "Two concurrent connections",
|
||||
poolFn: func(pool *connPool) {
|
||||
c1, _ := pool.AcquireConn()
|
||||
c2, _ := pool.AcquireConn()
|
||||
|
||||
pool.ReleaseConn(c1)
|
||||
pool.ReleaseConn(c2)
|
||||
},
|
||||
expected: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var connAlloc int
|
||||
dialer := func() (net.Conn, error) {
|
||||
connAlloc++
|
||||
return &net.TCPConn{}, nil
|
||||
}
|
||||
|
||||
pool := newConnPool(2, 0, dialer)
|
||||
test.poolFn(pool)
|
||||
|
||||
assert.Equal(t, test.expected, connAlloc)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnPool_MaxIdleConn(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
poolFn func(pool *connPool)
|
||||
maxIdleConn int
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
desc: "One connection",
|
||||
poolFn: func(pool *connPool) {
|
||||
c1, _ := pool.AcquireConn()
|
||||
pool.ReleaseConn(c1)
|
||||
},
|
||||
maxIdleConn: 1,
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
desc: "Multiple connections with defered release",
|
||||
poolFn: func(pool *connPool) {
|
||||
for range 7 {
|
||||
c, _ := pool.AcquireConn()
|
||||
defer pool.ReleaseConn(c)
|
||||
}
|
||||
},
|
||||
maxIdleConn: 5,
|
||||
expected: 5,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var keepOpenedConn int
|
||||
dialer := func() (net.Conn, error) {
|
||||
keepOpenedConn++
|
||||
return &mockConn{closeFn: func() error {
|
||||
keepOpenedConn--
|
||||
return nil
|
||||
}}, nil
|
||||
}
|
||||
|
||||
pool := newConnPool(test.maxIdleConn, 0, dialer)
|
||||
test.poolFn(pool)
|
||||
|
||||
assert.Equal(t, test.expected, keepOpenedConn)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGC(t *testing.T) {
|
||||
var isDestroyed bool
|
||||
pools := map[string]*connPool{}
|
||||
dialer := func() (net.Conn, error) {
|
||||
c := &mockConn{closeFn: func() error {
|
||||
return nil
|
||||
}}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
pools["test"] = newConnPool(10, 1*time.Second, dialer)
|
||||
runtime.SetFinalizer(pools["test"], func(p *connPool) {
|
||||
isDestroyed = true
|
||||
})
|
||||
c, err := pools["test"].AcquireConn()
|
||||
require.NoError(t, err)
|
||||
|
||||
pools["test"].ReleaseConn(c)
|
||||
|
||||
pools["test"].Close()
|
||||
|
||||
delete(pools, "test")
|
||||
|
||||
runtime.GC()
|
||||
|
||||
require.True(t, isDestroyed)
|
||||
}
|
||||
|
||||
type mockConn struct {
|
||||
closeFn func() error
|
||||
}
|
||||
|
||||
func (m *mockConn) Read(_ []byte) (n int, err error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mockConn) Write(_ []byte) (n int, err error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mockConn) Close() error {
|
||||
if m.closeFn != nil {
|
||||
return m.closeFn()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConn) LocalAddr() net.Addr {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mockConn) RemoteAddr() net.Addr {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mockConn) SetDeadline(_ time.Time) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mockConn) SetReadDeadline(_ time.Time) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mockConn) SetWriteDeadline(_ time.Time) error {
|
||||
panic("implement me")
|
||||
}
|
195
pkg/proxy/fast/dialer.go
Normal file
195
pkg/proxy/fast/dialer.go
Normal file
|
@ -0,0 +1,195 @@
|
|||
package fast
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
const (
|
||||
schemeHTTP = "http"
|
||||
schemeHTTPS = "https"
|
||||
schemeSocks5 = "socks5"
|
||||
)
|
||||
|
||||
type dialer interface {
|
||||
Dial(network, addr string) (c net.Conn, err error)
|
||||
}
|
||||
|
||||
type dialerFunc func(network, addr string) (net.Conn, error)
|
||||
|
||||
func (d dialerFunc) Dial(network, addr string) (net.Conn, error) {
|
||||
return d(network, addr)
|
||||
}
|
||||
|
||||
type dialerConfig struct {
|
||||
DialKeepAlive time.Duration
|
||||
DialTimeout time.Duration
|
||||
ProxyURL *url.URL
|
||||
HTTP bool
|
||||
TLS bool
|
||||
}
|
||||
|
||||
func newDialer(cfg dialerConfig, tlsConfig *tls.Config) dialer {
|
||||
if cfg.ProxyURL == nil {
|
||||
return buildDialer(cfg, tlsConfig, cfg.TLS)
|
||||
}
|
||||
|
||||
proxyDialer := buildDialer(cfg, tlsConfig, cfg.ProxyURL.Scheme == "https")
|
||||
proxyAddr := addrFromURL(cfg.ProxyURL)
|
||||
|
||||
switch {
|
||||
case cfg.ProxyURL.Scheme == schemeSocks5:
|
||||
var auth *proxy.Auth
|
||||
if u := cfg.ProxyURL.User; u != nil {
|
||||
auth = &proxy.Auth{User: u.Username()}
|
||||
auth.Password, _ = u.Password()
|
||||
}
|
||||
|
||||
// SOCKS5 implementation do not return errors.
|
||||
socksDialer, _ := proxy.SOCKS5("tcp", proxyAddr, auth, proxyDialer)
|
||||
return dialerFunc(func(network, targetAddr string) (net.Conn, error) {
|
||||
co, err := socksDialer.Dial("tcp", targetAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cfg.TLS {
|
||||
c := &tls.Config{}
|
||||
if tlsConfig != nil {
|
||||
c = tlsConfig.Clone()
|
||||
}
|
||||
|
||||
if c.ServerName == "" {
|
||||
host, _, _ := net.SplitHostPort(targetAddr)
|
||||
c.ServerName = host
|
||||
}
|
||||
|
||||
return tls.Client(co, c), nil
|
||||
}
|
||||
|
||||
return co, nil
|
||||
})
|
||||
case cfg.HTTP && !cfg.TLS:
|
||||
// Nothing to do the Proxy-Authorization header will be added by the ReverseProxy.
|
||||
|
||||
default:
|
||||
hdr := make(http.Header)
|
||||
if u := cfg.ProxyURL.User; u != nil {
|
||||
username := u.Username()
|
||||
password, _ := u.Password()
|
||||
auth := username + ":" + password
|
||||
hdr.Set("Proxy-Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
|
||||
}
|
||||
|
||||
return dialerFunc(func(network, targetAddr string) (net.Conn, error) {
|
||||
conn, err := proxyDialer.Dial("tcp", proxyAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
connectReq := &http.Request{
|
||||
Method: http.MethodConnect,
|
||||
URL: &url.URL{Opaque: targetAddr},
|
||||
Host: targetAddr,
|
||||
Header: hdr,
|
||||
}
|
||||
|
||||
connectCtx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
didReadResponse := make(chan struct{}) // closed after CONNECT write+read is done or fails
|
||||
var resp *http.Response
|
||||
|
||||
// Write the CONNECT request & read the response.
|
||||
go func() {
|
||||
defer close(didReadResponse)
|
||||
err = connectReq.Write(conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// Okay to use and discard buffered reader here, because
|
||||
// TLS server will not speak until spoken to.
|
||||
br := bufio.NewReader(conn)
|
||||
resp, err = http.ReadResponse(br, connectReq)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-connectCtx.Done():
|
||||
conn.Close()
|
||||
<-didReadResponse
|
||||
return nil, connectCtx.Err()
|
||||
case <-didReadResponse:
|
||||
// resp or err now set
|
||||
}
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
_, statusText, ok := strings.Cut(resp.Status, " ")
|
||||
conn.Close()
|
||||
if !ok {
|
||||
return nil, errors.New("unknown status code")
|
||||
}
|
||||
|
||||
return nil, errors.New(statusText)
|
||||
}
|
||||
|
||||
c := &tls.Config{}
|
||||
if tlsConfig != nil {
|
||||
c = tlsConfig.Clone()
|
||||
}
|
||||
if c.ServerName == "" {
|
||||
host, _, _ := net.SplitHostPort(targetAddr)
|
||||
c.ServerName = host
|
||||
}
|
||||
|
||||
return tls.Client(conn, c), nil
|
||||
})
|
||||
}
|
||||
return dialerFunc(func(network, addr string) (net.Conn, error) {
|
||||
return proxyDialer.Dial("tcp", proxyAddr)
|
||||
})
|
||||
}
|
||||
|
||||
func buildDialer(cfg dialerConfig, tlsConfig *tls.Config, isTLS bool) dialer {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: cfg.DialTimeout,
|
||||
KeepAlive: cfg.DialKeepAlive,
|
||||
}
|
||||
|
||||
if !isTLS {
|
||||
return dialer
|
||||
}
|
||||
|
||||
return &tls.Dialer{
|
||||
NetDialer: dialer,
|
||||
Config: tlsConfig,
|
||||
}
|
||||
}
|
||||
|
||||
func addrFromURL(u *url.URL) string {
|
||||
addr := u.Host
|
||||
|
||||
if u.Port() == "" {
|
||||
switch u.Scheme {
|
||||
case schemeHTTP:
|
||||
return addr + ":80"
|
||||
case schemeHTTPS:
|
||||
return addr + ":443"
|
||||
}
|
||||
}
|
||||
|
||||
return addr
|
||||
}
|
553
pkg/proxy/fast/proxy.go
Normal file
553
pkg/proxy/fast/proxy.go
Normal file
|
@ -0,0 +1,553 @@
|
|||
package fast
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
proxyhttputil "github.com/traefik/traefik/v3/pkg/proxy/httputil"
|
||||
"github.com/valyala/fasthttp"
|
||||
"golang.org/x/net/http/httpguts"
|
||||
)
|
||||
|
||||
const (
|
||||
bufferSize = 32 * 1024
|
||||
bufioSize = 64 * 1024
|
||||
)
|
||||
|
||||
var hopHeaders = []string{
|
||||
"Connection",
|
||||
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Authorization",
|
||||
"Te", // canonicalized version of "TE"
|
||||
"Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
|
||||
"Transfer-Encoding",
|
||||
"Upgrade",
|
||||
}
|
||||
|
||||
type pool[T any] struct {
|
||||
pool sync.Pool
|
||||
}
|
||||
|
||||
func (p *pool[T]) Get() T {
|
||||
if tmp := p.pool.Get(); tmp != nil {
|
||||
return tmp.(T)
|
||||
}
|
||||
|
||||
var res T
|
||||
return res
|
||||
}
|
||||
|
||||
func (p *pool[T]) Put(x T) {
|
||||
p.pool.Put(x)
|
||||
}
|
||||
|
||||
type buffConn struct {
|
||||
*bufio.Reader
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (b buffConn) Read(p []byte) (int, error) {
|
||||
return b.Reader.Read(p)
|
||||
}
|
||||
|
||||
type writeDetector struct {
|
||||
net.Conn
|
||||
|
||||
written bool
|
||||
}
|
||||
|
||||
func (w *writeDetector) Write(p []byte) (int, error) {
|
||||
n, err := w.Conn.Write(p)
|
||||
if n > 0 {
|
||||
w.written = true
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
type writeFlusher struct {
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func (w *writeFlusher) Write(b []byte) (int, error) {
|
||||
n, err := w.Writer.Write(b)
|
||||
if f, ok := w.Writer.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
type timeoutError struct {
|
||||
error
|
||||
}
|
||||
|
||||
func (t timeoutError) Timeout() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t timeoutError) Temporary() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// ReverseProxy is the FastProxy reverse proxy implementation.
|
||||
type ReverseProxy struct {
|
||||
debug bool
|
||||
|
||||
connPool *connPool
|
||||
|
||||
bufferPool pool[[]byte]
|
||||
readerPool pool[*bufio.Reader]
|
||||
writerPool pool[*bufio.Writer]
|
||||
limitReaderPool pool[*io.LimitedReader]
|
||||
|
||||
proxyAuth string
|
||||
|
||||
targetURL *url.URL
|
||||
passHostHeader bool
|
||||
responseHeaderTimeout time.Duration
|
||||
}
|
||||
|
||||
// NewReverseProxy creates a new ReverseProxy.
|
||||
func NewReverseProxy(targetURL *url.URL, proxyURL *url.URL, debug, passHostHeader bool, responseHeaderTimeout time.Duration, connPool *connPool) (*ReverseProxy, error) {
|
||||
var proxyAuth string
|
||||
if proxyURL != nil && proxyURL.User != nil && targetURL.Scheme == "http" {
|
||||
username := proxyURL.User.Username()
|
||||
password, _ := proxyURL.User.Password()
|
||||
proxyAuth = "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
|
||||
}
|
||||
|
||||
return &ReverseProxy{
|
||||
debug: debug,
|
||||
passHostHeader: passHostHeader,
|
||||
targetURL: targetURL,
|
||||
proxyAuth: proxyAuth,
|
||||
connPool: connPool,
|
||||
responseHeaderTimeout: responseHeaderTimeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.Body != nil {
|
||||
defer req.Body.Close()
|
||||
}
|
||||
|
||||
outReq := fasthttp.AcquireRequest()
|
||||
defer fasthttp.ReleaseRequest(outReq)
|
||||
|
||||
// This is not required as the headers are already normalized by net/http.
|
||||
outReq.Header.DisableNormalizing()
|
||||
|
||||
for k, v := range req.Header {
|
||||
for _, s := range v {
|
||||
outReq.Header.Add(k, s)
|
||||
}
|
||||
}
|
||||
|
||||
removeConnectionHeaders(&outReq.Header)
|
||||
|
||||
for _, header := range hopHeaders {
|
||||
outReq.Header.Del(header)
|
||||
}
|
||||
|
||||
if p.proxyAuth != "" {
|
||||
outReq.Header.Set("Proxy-Authorization", p.proxyAuth)
|
||||
}
|
||||
|
||||
if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
|
||||
outReq.Header.Set("Te", "trailers")
|
||||
}
|
||||
|
||||
if p.debug {
|
||||
outReq.Header.Set("X-Traefik-Fast-Proxy", "enabled")
|
||||
}
|
||||
|
||||
reqUpType := upgradeType(req.Header)
|
||||
if !isGraphic(reqUpType) {
|
||||
proxyhttputil.ErrorHandler(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
|
||||
return
|
||||
}
|
||||
|
||||
if reqUpType != "" {
|
||||
outReq.Header.Set("Connection", "Upgrade")
|
||||
outReq.Header.Set("Upgrade", reqUpType)
|
||||
if reqUpType == "websocket" {
|
||||
cleanWebSocketHeaders(&outReq.Header)
|
||||
}
|
||||
}
|
||||
|
||||
u2 := new(url.URL)
|
||||
*u2 = *req.URL
|
||||
u2.Scheme = p.targetURL.Scheme
|
||||
u2.Host = p.targetURL.Host
|
||||
|
||||
u := req.URL
|
||||
if req.RequestURI != "" {
|
||||
parsedURL, err := url.ParseRequestURI(req.RequestURI)
|
||||
if err == nil {
|
||||
u = parsedURL
|
||||
}
|
||||
}
|
||||
|
||||
u2.Path = u.Path
|
||||
u2.RawPath = u.RawPath
|
||||
u2.RawQuery = strings.ReplaceAll(u.RawQuery, ";", "&")
|
||||
|
||||
outReq.SetHost(u2.Host)
|
||||
outReq.Header.SetHost(u2.Host)
|
||||
|
||||
if p.passHostHeader {
|
||||
outReq.Header.SetHost(req.Host)
|
||||
}
|
||||
|
||||
outReq.SetRequestURI(u2.RequestURI())
|
||||
|
||||
outReq.SetBodyStream(req.Body, int(req.ContentLength))
|
||||
|
||||
outReq.Header.SetMethod(req.Method)
|
||||
|
||||
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
|
||||
// If we aren't the first proxy retain prior
|
||||
// X-Forwarded-For information as a comma+space
|
||||
// separated list and fold multiple headers into one.
|
||||
prior, ok := req.Header["X-Forwarded-For"]
|
||||
if len(prior) > 0 {
|
||||
clientIP = strings.Join(prior, ", ") + ", " + clientIP
|
||||
}
|
||||
|
||||
omit := ok && prior == nil // Go Issue 38079: nil now means don't populate the header
|
||||
if !omit {
|
||||
outReq.Header.Set("X-Forwarded-For", clientIP)
|
||||
}
|
||||
}
|
||||
|
||||
if err := p.roundTrip(rw, req, outReq, reqUpType); err != nil {
|
||||
proxyhttputil.ErrorHandler(rw, req, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Note that unlike the net/http RoundTrip:
|
||||
// - we are not supporting "100 Continue" response to forward them as-is to the client.
|
||||
// - we are not asking for compressed response automatically. That is because this will add an extra cost when the
|
||||
// client is asking for an uncompressed response, as we will have to un-compress it, and nowadays most clients are
|
||||
// already asking for compressed response (allowing "passthrough" compression).
|
||||
func (p *ReverseProxy) roundTrip(rw http.ResponseWriter, req *http.Request, outReq *fasthttp.Request, reqUpType string) error {
|
||||
ctx := req.Context()
|
||||
trace := httptrace.ContextClientTrace(ctx)
|
||||
|
||||
var co *conn
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
|
||||
default:
|
||||
}
|
||||
|
||||
var err error
|
||||
co, err = p.connPool.AcquireConn()
|
||||
if err != nil {
|
||||
return fmt.Errorf("acquire connection: %w", err)
|
||||
}
|
||||
|
||||
wd := &writeDetector{Conn: co}
|
||||
|
||||
err = p.writeRequest(wd, outReq)
|
||||
if wd.written && trace != nil && trace.WroteRequest != nil {
|
||||
// WroteRequest hook is used by the tracing middleware to detect if the request has been written.
|
||||
trace.WroteRequest(httptrace.WroteRequestInfo{})
|
||||
}
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
|
||||
log.Ctx(ctx).Debug().Err(err).Msg("Error while writing request")
|
||||
|
||||
co.Close()
|
||||
|
||||
if wd.written && !isReplayable(req) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
br := p.readerPool.Get()
|
||||
if br == nil {
|
||||
br = bufio.NewReaderSize(co, bufioSize)
|
||||
}
|
||||
defer p.readerPool.Put(br)
|
||||
|
||||
br.Reset(co)
|
||||
|
||||
res := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseResponse(res)
|
||||
|
||||
res.Header.SetNoDefaultContentType(true)
|
||||
|
||||
for {
|
||||
var timer *time.Timer
|
||||
errTimeout := atomic.Pointer[timeoutError]{}
|
||||
if p.responseHeaderTimeout > 0 {
|
||||
timer = time.AfterFunc(p.responseHeaderTimeout, func() {
|
||||
errTimeout.Store(&timeoutError{errors.New("timeout awaiting response headers")})
|
||||
co.Close()
|
||||
})
|
||||
}
|
||||
|
||||
res.Header.SetNoDefaultContentType(true)
|
||||
if err := res.Header.Read(br); err != nil {
|
||||
if p.responseHeaderTimeout > 0 {
|
||||
if errT := errTimeout.Load(); errT != nil {
|
||||
return errT
|
||||
}
|
||||
}
|
||||
co.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
|
||||
fixPragmaCacheControl(&res.Header)
|
||||
|
||||
resCode := res.StatusCode()
|
||||
is1xx := 100 <= resCode && resCode <= 199
|
||||
// treat 101 as a terminal status, see issue 26161
|
||||
is1xxNonTerminal := is1xx && resCode != http.StatusSwitchingProtocols
|
||||
if is1xxNonTerminal {
|
||||
removeConnectionHeaders(&res.Header)
|
||||
h := rw.Header()
|
||||
|
||||
for _, header := range hopHeaders {
|
||||
res.Header.Del(header)
|
||||
}
|
||||
|
||||
res.Header.VisitAll(func(key, value []byte) {
|
||||
rw.Header().Add(string(key), string(value))
|
||||
})
|
||||
|
||||
rw.WriteHeader(res.StatusCode())
|
||||
// Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
|
||||
for k := range h {
|
||||
delete(h, k)
|
||||
}
|
||||
|
||||
res.Reset()
|
||||
res.Header.Reset()
|
||||
res.Header.SetNoDefaultContentType(true)
|
||||
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
announcedTrailers := res.Header.Peek("Trailer")
|
||||
|
||||
// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
|
||||
if res.StatusCode() == http.StatusSwitchingProtocols {
|
||||
// As the connection has been hijacked, it cannot be added back to the pool.
|
||||
handleUpgradeResponse(rw, req, reqUpType, res, buffConn{Conn: co, Reader: br})
|
||||
return nil
|
||||
}
|
||||
|
||||
removeConnectionHeaders(&res.Header)
|
||||
|
||||
for _, header := range hopHeaders {
|
||||
res.Header.Del(header)
|
||||
}
|
||||
|
||||
if len(announcedTrailers) > 0 {
|
||||
res.Header.Add("Trailer", string(announcedTrailers))
|
||||
}
|
||||
|
||||
res.Header.VisitAll(func(key, value []byte) {
|
||||
rw.Header().Add(string(key), string(value))
|
||||
})
|
||||
|
||||
rw.WriteHeader(res.StatusCode())
|
||||
|
||||
// Chunked response, Content-Length is set to -1 by FastProxy when "Transfer-Encoding: chunked" header is received.
|
||||
if res.Header.ContentLength() == -1 {
|
||||
cbr := httputil.NewChunkedReader(br)
|
||||
|
||||
b := p.bufferPool.Get()
|
||||
if b == nil {
|
||||
b = make([]byte, bufferSize)
|
||||
}
|
||||
defer p.bufferPool.Put(b)
|
||||
|
||||
if _, err := io.CopyBuffer(&writeFlusher{rw}, cbr, b); err != nil {
|
||||
co.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
res.Header.Reset()
|
||||
res.Header.SetNoDefaultContentType(true)
|
||||
if err := res.Header.ReadTrailer(br); err != nil {
|
||||
co.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
if res.Header.Len() > 0 {
|
||||
var announcedTrailersKey []string
|
||||
if len(announcedTrailers) > 0 {
|
||||
announcedTrailersKey = strings.Split(string(announcedTrailers), ",")
|
||||
}
|
||||
|
||||
res.Header.VisitAll(func(key, value []byte) {
|
||||
for _, s := range announcedTrailersKey {
|
||||
if strings.EqualFold(s, strings.TrimSpace(string(key))) {
|
||||
rw.Header().Add(string(key), string(value))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
rw.Header().Add(http.TrailerPrefix+string(key), string(value))
|
||||
})
|
||||
}
|
||||
|
||||
p.connPool.ReleaseConn(co)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
brl := p.limitReaderPool.Get()
|
||||
if brl == nil {
|
||||
brl = &io.LimitedReader{}
|
||||
}
|
||||
defer p.limitReaderPool.Put(brl)
|
||||
|
||||
brl.R = br
|
||||
brl.N = int64(res.Header.ContentLength())
|
||||
|
||||
b := p.bufferPool.Get()
|
||||
if b == nil {
|
||||
b = make([]byte, bufferSize)
|
||||
}
|
||||
defer p.bufferPool.Put(b)
|
||||
|
||||
if _, err := io.CopyBuffer(rw, brl, b); err != nil {
|
||||
co.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
p.connPool.ReleaseConn(co)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) writeRequest(co net.Conn, outReq *fasthttp.Request) error {
|
||||
bw := p.writerPool.Get()
|
||||
if bw == nil {
|
||||
bw = bufio.NewWriterSize(co, bufioSize)
|
||||
}
|
||||
defer p.writerPool.Put(bw)
|
||||
|
||||
bw.Reset(co)
|
||||
|
||||
if err := outReq.Write(bw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bw.Flush()
|
||||
}
|
||||
|
||||
// isReplayable returns whether the request is replayable.
|
||||
func isReplayable(req *http.Request) bool {
|
||||
if req.Body == nil || req.Body == http.NoBody {
|
||||
switch req.Method {
|
||||
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
|
||||
return true
|
||||
}
|
||||
|
||||
// The Idempotency-Key, while non-standard, is widely used to
|
||||
// mean a POST or other request is idempotent. See
|
||||