Detect and drop broken conns in the fastproxy pool

Co-authored-by: Romain <rtribotte@users.noreply.github.com>
Kevin Pollet 2024-10-25 14:26:04 +02:00 committed by GitHub
parent b22e081c7c
commit e3ed52ba7c
7 changed files with 426 additions and 280 deletions
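
At its core, the change gives every pooled connection its own read loop that watches the buffered reader while the connection sits idle: a read error, or unsolicited bytes arriving when no response is expected, marks the connection as broken, and AcquireConn then closes it via the new isStale check instead of handing it out again. A minimal, self-contained sketch of that detection idea, with simplified names rather than the actual Traefik types:

package main

import (
    "bufio"
    "fmt"
    "net"
    "sync/atomic"
    "time"
)

// idleConn wraps a net.Conn with a buffered reader and a broken flag,
// mirroring the detection idea used by the fastproxy pool.
type idleConn struct {
    net.Conn

    br     *bufio.Reader
    broken atomic.Bool
}

// readLoop peeks at the buffered reader while the connection is idle:
// any read error (e.g. the backend closed the socket) marks it as broken,
// so the pool can drop it instead of reusing it. The real pool also treats
// unsolicited bytes on an idle connection as broken.
func (c *idleConn) readLoop() {
    if _, err := c.br.Peek(1); err != nil {
        c.broken.Store(true)
    }
}

func main() {
    client, server := net.Pipe()

    c := &idleConn{Conn: client, br: bufio.NewReader(client)}
    go c.readLoop()

    // Simulate the backend dropping the idle connection.
    server.Close()

    time.Sleep(100 * time.Millisecond)
    fmt.Println("broken:", c.broken.Load()) // broken: true
}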


@@ -79,18 +79,13 @@ func (r *ProxyBuilder) Build(cfgName string, targetURL *url.URL, passHostHeader,
 		return nil, fmt.Errorf("getting ServersTransport: %w", err)
 	}
 
-	var responseHeaderTimeout time.Duration
-	if cfg.ForwardingTimeouts != nil {
-		responseHeaderTimeout = time.Duration(cfg.ForwardingTimeouts.ResponseHeaderTimeout)
-	}
-
 	tlsConfig, err := r.transportManager.GetTLSConfig(cfgName)
 	if err != nil {
 		return nil, fmt.Errorf("getting TLS config: %w", err)
 	}
 
 	pool := r.getPool(cfgName, cfg, tlsConfig, targetURL, proxyURL)
-	return NewReverseProxy(targetURL, proxyURL, r.debug, passHostHeader, preservePath, responseHeaderTimeout, pool)
+	return NewReverseProxy(targetURL, proxyURL, r.debug, passHostHeader, preservePath, pool)
 }
 
 func (r *ProxyBuilder) getPool(cfgName string, config *dynamic.ServersTransport, tlsConfig *tls.Config, targetURL *url.URL, proxyURL *url.URL) *connPool {

@@ -106,9 +101,11 @@ func (r *ProxyBuilder) getPool(cfgName string, config *dynamic.ServersTransport,
 	idleConnTimeout := 90 * time.Second
 	dialTimeout := 30 * time.Second
+	var responseHeaderTimeout time.Duration
 	if config.ForwardingTimeouts != nil {
 		idleConnTimeout = time.Duration(config.ForwardingTimeouts.IdleConnTimeout)
 		dialTimeout = time.Duration(config.ForwardingTimeouts.DialTimeout)
+		responseHeaderTimeout = time.Duration(config.ForwardingTimeouts.ResponseHeaderTimeout)
 	}
 
 	proxyDialer := newDialer(dialerConfig{

@@ -119,7 +116,7 @@ func (r *ProxyBuilder) getPool(cfgName string, config *dynamic.ServersTransport,
 		ProxyURL: proxyURL,
 	}, tlsConfig)
 
-	connPool := newConnPool(config.MaxIdleConnsPerHost, idleConnTimeout, func() (net.Conn, error) {
+	connPool := newConnPool(config.MaxIdleConnsPerHost, idleConnTimeout, responseHeaderTimeout, func() (net.Conn, error) {
 		return proxyDialer.Dial("tcp", addrFromURL(targetURL))
 	})


@@ -1,42 +1,309 @@
 package fast
 
 import (
+	"bufio"
+	"errors"
 	"fmt"
+	"io"
 	"net"
+	"net/http"
+	"net/http/httputil"
+	"strings"
+	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/rs/zerolog/log"
+	"github.com/valyala/fasthttp"
 )
 
+// rwWithUpgrade contains a ResponseWriter and an upgradeHandler,
+// used to upgrade the connection (e.g. Websockets).
+type rwWithUpgrade struct {
+	RW      http.ResponseWriter
+	Upgrade upgradeHandler
+}
+
 // conn is an enriched net.Conn.
 type conn struct {
 	net.Conn
 
+	RWCh  chan rwWithUpgrade
+	ErrCh chan error
+
+	br          *bufio.Reader
 	idleAt      time.Time // the last time it was marked as idle.
 	idleTimeout time.Duration
+
+	responseHeaderTimeout time.Duration
+
+	expectedResponse atomic.Bool
+	broken           atomic.Bool
+	upgraded         atomic.Bool
+
+	closeMu  sync.Mutex
+	closed   bool
+	closeErr error
+
+	bufferPool        *pool[[]byte]
+	limitedReaderPool *pool[*io.LimitedReader]
 }
 
-func (c *conn) isExpired() bool {
+// Read reads data from the connection.
+// Overrides conn Read to use the buffered reader.
+func (c *conn) Read(b []byte) (n int, err error) {
+	return c.br.Read(b)
+}
+
+// Close closes the connection.
+// Ensures that connection is closed only once,
+// to avoid duplicate close error.
+func (c *conn) Close() error {
+	c.closeMu.Lock()
+	defer c.closeMu.Unlock()
+
+	if c.closed {
+		return c.closeErr
+	}
+
+	c.closed = true
+	c.closeErr = c.Conn.Close()
+
+	return c.closeErr
+}
+
+// isStale returns whether the connection is in an invalid state (i.e. expired/broken).
+func (c *conn) isStale() bool {
 	expTime := c.idleAt.Add(c.idleTimeout)
-	return c.idleTimeout > 0 && time.Now().After(expTime)
+	return c.idleTimeout > 0 && time.Now().After(expTime) || c.broken.Load()
 }
+
+// isUpgraded returns whether this connection has been upgraded (e.g. Websocket).
+// An upgraded connection should not be reused and putted back in the connection pool.
+func (c *conn) isUpgraded() bool {
+	return c.upgraded.Load()
+}
+
+// readLoop handles the successive HTTP response read operations on the connection,
+// and watches for unsolicited bytes or connection errors when idle.
+func (c *conn) readLoop() {
+	defer c.Close()
+
+	for {
+		_, err := c.br.Peek(1)
+		if err != nil {
+			select {
+			// An error occurred while a response was expected to be handled.
+			case <-c.RWCh:
+				c.ErrCh <- err
+			// An error occurred on an idle connection.
+			default:
+				c.broken.Store(true)
+			}
+			return
+		}
+
+		// Unsolicited response received on an idle connection.
+		if !c.expectedResponse.Load() {
+			c.broken.Store(true)
+			return
+		}
+
+		r := <-c.RWCh
+		if err = c.handleResponse(r); err != nil {
+			c.ErrCh <- err
+			return
+		}
+
+		c.expectedResponse.Store(false)
+		c.ErrCh <- nil
+	}
+}
+
+func (c *conn) handleResponse(r rwWithUpgrade) error {
+	res := fasthttp.AcquireResponse()
+	defer fasthttp.ReleaseResponse(res)
+
+	res.Header.SetNoDefaultContentType(true)
+
+	for {
+		var (
+			timer      *time.Timer
+			errTimeout atomic.Pointer[timeoutError]
+		)
+		if c.responseHeaderTimeout > 0 {
+			timer = time.AfterFunc(c.responseHeaderTimeout, func() {
+				errTimeout.Store(&timeoutError{errors.New("timeout awaiting response headers")})
+				c.Close() // This close call is needed to interrupt the read operation below when the timeout is over.
+			})
+		}
+
+		res.Header.SetNoDefaultContentType(true)
+		if err := res.Header.Read(c.br); err != nil {
+			if c.responseHeaderTimeout > 0 {
+				if errT := errTimeout.Load(); errT != nil {
+					return errT
+				}
+			}
+			return err
+		}
+
+		if timer != nil {
+			timer.Stop()
+		}
+
+		fixPragmaCacheControl(&res.Header)
+
+		resCode := res.StatusCode()
+		is1xx := 100 <= resCode && resCode <= 199
+		// treat 101 as a terminal status, see issue 26161
+		is1xxNonTerminal := is1xx && resCode != http.StatusSwitchingProtocols
+		if is1xxNonTerminal {
+			removeConnectionHeaders(&res.Header)
+			h := r.RW.Header()
+
+			for _, header := range hopHeaders {
+				res.Header.Del(header)
+			}
+
+			res.Header.VisitAll(func(key, value []byte) {
+				r.RW.Header().Add(string(key), string(value))
+			})
+
+			r.RW.WriteHeader(res.StatusCode())
+			// Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
+			for k := range h {
+				delete(h, k)
+			}
+
+			res.Reset()
+			res.Header.Reset()
+			res.Header.SetNoDefaultContentType(true)
+
+			continue
+		}
+		break
+	}
+
+	announcedTrailers := res.Header.Peek("Trailer")
+
+	// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
+	if res.StatusCode() == http.StatusSwitchingProtocols {
+		r.Upgrade(r.RW, res, c)
+		c.upgraded.Store(true) // As the connection has been upgraded, it cannot be added back to the pool.
+		return nil
+	}
+
+	removeConnectionHeaders(&res.Header)
+
+	for _, header := range hopHeaders {
+		res.Header.Del(header)
+	}
+
+	if len(announcedTrailers) > 0 {
+		res.Header.Add("Trailer", string(announcedTrailers))
+	}
+
+	res.Header.VisitAll(func(key, value []byte) {
+		r.RW.Header().Add(string(key), string(value))
+	})
+
+	r.RW.WriteHeader(res.StatusCode())
+
+	if res.Header.ContentLength() == 0 {
+		return nil
+	}
+
+	// When a body is not allowed for a given status code the body is ignored.
+	// The connection will be marked as broken by the next Peek in the readloop.
+	if !isBodyAllowedForStatus(res.StatusCode()) {
+		return nil
+	}
+
+	// Chunked response, Content-Length is set to -1 by FastProxy when "Transfer-Encoding: chunked" header is received.
+	if res.Header.ContentLength() == -1 {
+		cbr := httputil.NewChunkedReader(c.br)
+
+		b := c.bufferPool.Get()
+		if b == nil {
+			b = make([]byte, bufferSize)
+		}
+		defer c.bufferPool.Put(b)
+
+		if _, err := io.CopyBuffer(&writeFlusher{r.RW}, cbr, b); err != nil {
+			return err
+		}
+
+		res.Header.Reset()
+		res.Header.SetNoDefaultContentType(true)
+		if err := res.Header.ReadTrailer(c.br); err != nil {
+			return err
+		}
+
+		if res.Header.Len() > 0 {
+			var announcedTrailersKey []string
+			if len(announcedTrailers) > 0 {
+				announcedTrailersKey = strings.Split(string(announcedTrailers), ",")
+			}
+
+			res.Header.VisitAll(func(key, value []byte) {
+				for _, s := range announcedTrailersKey {
+					if strings.EqualFold(s, strings.TrimSpace(string(key))) {
+						r.RW.Header().Add(string(key), string(value))
+						return
+					}
+				}
+
+				r.RW.Header().Add(http.TrailerPrefix+string(key), string(value))
+			})
+		}
+
+		return nil
+	}
+
+	brl := c.limitedReaderPool.Get()
+	if brl == nil {
+		brl = &io.LimitedReader{}
+	}
+	defer c.limitedReaderPool.Put(brl)
+
+	brl.R = c.br
+	brl.N = int64(res.Header.ContentLength())
+
+	b := c.bufferPool.Get()
+	if b == nil {
+		b = make([]byte, bufferSize)
+	}
+	defer c.bufferPool.Put(b)
+
+	if _, err := io.CopyBuffer(r.RW, brl, b); err != nil {
+		return err
+	}
+
+	return nil
+}
 
 // connPool is a net.Conn pool implementation using channels.
 type connPool struct {
 	dialer                func() (net.Conn, error)
 	idleConns             chan *conn
 	idleConnTimeout       time.Duration
-	ticker                *time.Ticker
-	doneCh                chan struct{}
+	responseHeaderTimeout time.Duration
+	ticker                *time.Ticker
+	bufferPool            pool[[]byte]
+	limitedReaderPool     pool[*io.LimitedReader]
+	doneCh                chan struct{}
 }
 
 // newConnPool creates a new connPool.
-func newConnPool(maxIdleConn int, idleConnTimeout time.Duration, dialer func() (net.Conn, error)) *connPool {
+func newConnPool(maxIdleConn int, idleConnTimeout, responseHeaderTimeout time.Duration, dialer func() (net.Conn, error)) *connPool {
 	c := &connPool{
 		dialer:          dialer,
 		idleConns:       make(chan *conn, maxIdleConn),
 		idleConnTimeout: idleConnTimeout,
-		doneCh:          make(chan struct{}),
+		responseHeaderTimeout: responseHeaderTimeout,
+		doneCh:                make(chan struct{}),
 	}
 
 	if idleConnTimeout > 0 {

@@ -72,22 +339,28 @@ func (c *connPool) AcquireConn() (*conn, error) {
 			return nil, err
 		}
 
-		if !co.isExpired() {
+		if !co.isStale() {
 			return co, nil
 		}
 
-		// As the acquired conn is expired we can close it
+		// As the acquired conn is stale we can close it
 		// without putting it again into the pool.
 		if err := co.Close(); err != nil {
 			log.Debug().
 				Err(err).
-				Msg("Unexpected error while releasing the connection")
+				Msg("Unexpected error while closing the connection")
 		}
 	}
 }
 
 // ReleaseConn releases the given net.Conn to the pool.
 func (c *connPool) ReleaseConn(co *conn) {
+	// An upgraded connection cannot be safely reused for another roundTrip,
+	// thus we are not putting it back to the pool.
+	if co.isUpgraded() {
+		return
+	}
+
 	co.idleAt = time.Now()
 	c.releaseConn(co)
 }

@@ -97,7 +370,7 @@ func (c *connPool) cleanIdleConns() {
 	for {
 		select {
 		case co := <-c.idleConns:
-			if !co.isExpired() {
+			if !co.isStale() {
 				c.releaseConn(co)
 				return
 			}

@@ -105,7 +378,7 @@ func (c *connPool) cleanIdleConns() {
 			if err := co.Close(); err != nil {
 				log.Debug().
 					Err(err).
-					Msg("Unexpected error while releasing the connection")
+					Msg("Unexpected error while closing the connection")
 			}
 
 		default:

@@ -155,9 +428,33 @@ func (c *connPool) askForNewConn(errCh chan<- error) {
 		return
 	}
 
-	c.releaseConn(&conn{
-		Conn:        co,
-		idleAt:      time.Now(),
-		idleTimeout: c.idleConnTimeout,
-	})
+	newConn := &conn{
+		Conn:                  co,
+		br:                    bufio.NewReaderSize(co, bufioSize),
+		idleAt:                time.Now(),
+		idleTimeout:           c.idleConnTimeout,
+		responseHeaderTimeout: c.responseHeaderTimeout,
+		RWCh:                  make(chan rwWithUpgrade),
+		ErrCh:                 make(chan error),
+		bufferPool:            &c.bufferPool,
+		limitedReaderPool:     &c.limitedReaderPool,
+	}
+
+	go newConn.readLoop()
+
+	c.releaseConn(newConn)
+}
+
+// isBodyAllowedForStatus reports whether a given response status code
+// permits a body. See RFC 7230, section 3.3.
+// From https://github.com/golang/go/blame/master/src/net/http/transfer.go#L459
+func isBodyAllowedForStatus(status int) bool {
+	switch {
+	case status >= 100 && status <= 199:
+		return false
+	case status == 204:
+		return false
+	case status == 304:
+		return false
+	}
+	return true
 }
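
The other half of the design is the handshake between the proxy round trip and that read loop: the caller writes the request, hands its ResponseWriter to the read loop over RWCh, and blocks on ErrCh until the response has been handled (see the proxy.go hunks further down). A compressed, self-contained illustration of that channel handshake, using toy types and net/http parsing instead of fasthttp:

package main

import (
	"bufio"
	"fmt"
	"net"
	"net/http"
)

// pooledConn mimics the RWCh/ErrCh handshake: one goroutine owns the buffered
// reader and only parses a response when the caller signals that one is expected.
type pooledConn struct {
	net.Conn

	br    *bufio.Reader
	reqCh chan struct{} // caller signals that a response is now expected
	errCh chan error    // read loop reports the outcome of handling it
}

func (c *pooledConn) readLoop() {
	for range c.reqCh {
		// The real code parses the response with fasthttp and streams it to the
		// http.ResponseWriter; this toy only parses the response head.
		res, err := http.ReadResponse(c.br, nil)
		if err == nil {
			res.Body.Close()
		}
		c.errCh <- err
	}
}

func main() {
	client, server := net.Pipe()

	// Fake backend: consume the request, then answer with an empty 200.
	go func() {
		if _, err := http.ReadRequest(bufio.NewReader(server)); err != nil {
			return
		}
		server.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"))
	}()

	c := &pooledConn{
		Conn:  client,
		br:    bufio.NewReader(client),
		reqCh: make(chan struct{}),
		errCh: make(chan error),
	}
	go c.readLoop()

	// Caller side: write the request first, then hand control to the read loop
	// and block until it reports how the response handling went.
	fmt.Fprint(client, "GET / HTTP/1.1\r\nHost: example\r\n\r\n")
	c.reqCh <- struct{}{}
	fmt.Println("round trip error:", <-c.errCh) // round trip error: <nil>
}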


@@ -58,7 +58,7 @@ func TestConnPool_ConnReuse(t *testing.T) {
 				return &net.TCPConn{}, nil
 			}
 
-			pool := newConnPool(2, 0, dialer)
+			pool := newConnPool(2, 0, 0, dialer)
 			test.poolFn(pool)
 
 			assert.Equal(t, test.expected, connAlloc)

@@ -102,13 +102,16 @@ func TestConnPool_MaxIdleConn(t *testing.T) {
 			var keepOpenedConn int
 			dialer := func() (net.Conn, error) {
 				keepOpenedConn++
-				return &mockConn{closeFn: func() error {
-					keepOpenedConn--
-					return nil
-				}}, nil
+				return &mockConn{
+					doneCh: make(chan struct{}),
+					closeFn: func() error {
+						keepOpenedConn--
+						return nil
+					},
+				}, nil
 			}
 
-			pool := newConnPool(test.maxIdleConn, 0, dialer)
+			pool := newConnPool(test.maxIdleConn, 0, 0, dialer)
 			test.poolFn(pool)
 
 			assert.Equal(t, test.expected, keepOpenedConn)

@@ -129,7 +132,7 @@ func TestGC(t *testing.T) {
 		return c, nil
 	}
 
-	pools["test"] = newConnPool(10, 1*time.Second, dialer)
+	pools["test"] = newConnPool(10, 1*time.Second, 0, dialer)
 	runtime.SetFinalizer(pools["test"], func(p *connPool) {
 		isDestroyed = true
 	})

@@ -149,10 +152,12 @@ func TestGC(t *testing.T) {
 
 type mockConn struct {
 	closeFn func() error
+	doneCh  chan struct{} // makes sure that the readLoop is blocking avoiding close.
 }
 
 func (m *mockConn) Read(_ []byte) (n int, err error) {
-	panic("implement me")
+	<-m.doneCh
+	return 0, nil
 }
 
 func (m *mockConn) Write(_ []byte) (n int, err error) {

@@ -160,6 +165,7 @@ func (m *mockConn) Write(_ []byte) (n int, err error) {
 }
 
 func (m *mockConn) Close() error {
+	defer close(m.doneCh)
 	if m.closeFn != nil {
 		return m.closeFn()
 	}


@@ -4,18 +4,14 @@ import (
 	"bufio"
 	"bytes"
 	"encoding/base64"
-	"errors"
 	"fmt"
 	"io"
 	"net"
 	"net/http"
 	"net/http/httptrace"
-	"net/http/httputil"
 	"net/url"
 	"strings"
 	"sync"
-	"sync/atomic"
-	"time"
 
 	"github.com/rs/zerolog/log"
 	proxyhttputil "github.com/traefik/traefik/v3/pkg/proxy/httputil"

@@ -57,15 +53,6 @@ func (p *pool[T]) Put(x T) {
 	p.pool.Put(x)
 }
 
-type buffConn struct {
-	*bufio.Reader
-	net.Conn
-}
-
-func (b buffConn) Read(p []byte) (int, error) {
-	return b.Reader.Read(p)
-}
-
 type writeDetector struct {
 	net.Conn

@@ -112,21 +99,17 @@ type ReverseProxy struct {
 	connPool *connPool
 
-	bufferPool      pool[[]byte]
-	readerPool      pool[*bufio.Reader]
-	writerPool      pool[*bufio.Writer]
-	limitReaderPool pool[*io.LimitedReader]
+	writerPool pool[*bufio.Writer]
 
 	proxyAuth string
 
 	targetURL      *url.URL
 	passHostHeader bool
 	preservePath   bool
-
-	responseHeaderTimeout time.Duration
 }
 
 // NewReverseProxy creates a new ReverseProxy.
-func NewReverseProxy(targetURL, proxyURL *url.URL, debug, passHostHeader, preservePath bool, responseHeaderTimeout time.Duration, connPool *connPool) (*ReverseProxy, error) {
+func NewReverseProxy(targetURL, proxyURL *url.URL, debug, passHostHeader, preservePath bool, connPool *connPool) (*ReverseProxy, error) {
 	var proxyAuth string
 	if proxyURL != nil && proxyURL.User != nil && targetURL.Scheme == "http" {
 		username := proxyURL.User.Username()

@@ -135,13 +118,12 @@ func NewReverseProxy(targetURL, proxyURL *url.URL, debug, passHostHeader, preser
 	}
 
 	return &ReverseProxy{
 		debug:          debug,
 		passHostHeader: passHostHeader,
 		preservePath:   preservePath,
 		targetURL:      targetURL,
 		proxyAuth:      proxyAuth,
 		connPool:       connPool,
-		responseHeaderTimeout: responseHeaderTimeout,
 	}, nil
 }

@@ -273,8 +255,15 @@ func (p *ReverseProxy) roundTrip(rw http.ResponseWriter, req *http.Request, outR
 		return fmt.Errorf("acquire connection: %w", err)
 	}
 
+	// Before writing the request,
+	// we mark the conn as expecting to handle a response.
+	co.expectedResponse.Store(true)
+
 	wd := &writeDetector{Conn: co}
 
+	// TODO: do not wait to write the full request before reading the response (to handle "100 Continue").
+	// TODO: this is currently impossible with fasthttp to write the request partially (headers only).
+	// Currently, writing the request fully is a mandatory step before handling the response.
 	err = p.writeRequest(wd, outReq)
 	if wd.written && trace != nil && trace.WroteRequest != nil {
 		// WroteRequest hook is used by the tracing middleware to detect if the request has been written.

@@ -293,169 +282,17 @@ func (p *ReverseProxy) roundTrip(rw http.ResponseWriter, req *http.Request, outR
 		}
 	}
 
-	br := p.readerPool.Get()
-	if br == nil {
-		br = bufio.NewReaderSize(co, bufioSize)
-	}
-	defer p.readerPool.Put(br)
-
-	br.Reset(co)
-
-	res := fasthttp.AcquireResponse()
-	defer fasthttp.ReleaseResponse(res)
-	res.Header.SetNoDefaultContentType(true)
-
-	for {
-		var timer *time.Timer
-		errTimeout := atomic.Pointer[timeoutError]{}
-		if p.responseHeaderTimeout > 0 {
-			timer = time.AfterFunc(p.responseHeaderTimeout, func() {
-				errTimeout.Store(&timeoutError{errors.New("timeout awaiting response headers")})
-				co.Close()
-			})
-		}
-
-		res.Header.SetNoDefaultContentType(true)
-		if err := res.Header.Read(br); err != nil {
-			if p.responseHeaderTimeout > 0 {
-				if errT := errTimeout.Load(); errT != nil {
-					return errT
-				}
-			}
-
-			co.Close()
-			return err
-		}
-
-		if timer != nil {
-			timer.Stop()
-		}
-
-		fixPragmaCacheControl(&res.Header)
-
-		resCode := res.StatusCode()
-		is1xx := 100 <= resCode && resCode <= 199
-		// treat 101 as a terminal status, see issue 26161
-		is1xxNonTerminal := is1xx && resCode != http.StatusSwitchingProtocols
-		if is1xxNonTerminal {
-			removeConnectionHeaders(&res.Header)
-			h := rw.Header()
-
-			for _, header := range hopHeaders {
-				res.Header.Del(header)
-			}
-
-			res.Header.VisitAll(func(key, value []byte) {
-				rw.Header().Add(string(key), string(value))
-			})
-
-			rw.WriteHeader(res.StatusCode())
-			// Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
-			for k := range h {
-				delete(h, k)
-			}
-
-			res.Reset()
-			res.Header.Reset()
-			res.Header.SetNoDefaultContentType(true)
-
-			continue
-		}
-		break
-	}
-
-	announcedTrailers := res.Header.Peek("Trailer")
-
-	// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
-	if res.StatusCode() == http.StatusSwitchingProtocols {
-		// As the connection has been hijacked, it cannot be added back to the pool.
-		handleUpgradeResponse(rw, req, reqUpType, res, buffConn{Conn: co, Reader: br})
-		return nil
-	}
-
-	removeConnectionHeaders(&res.Header)
-
-	for _, header := range hopHeaders {
-		res.Header.Del(header)
-	}
-
-	if len(announcedTrailers) > 0 {
-		res.Header.Add("Trailer", string(announcedTrailers))
-	}
-
-	res.Header.VisitAll(func(key, value []byte) {
-		rw.Header().Add(string(key), string(value))
-	})
-
-	rw.WriteHeader(res.StatusCode())
-
-	// Chunked response, Content-Length is set to -1 by FastProxy when "Transfer-Encoding: chunked" header is received.
-	if res.Header.ContentLength() == -1 {
-		cbr := httputil.NewChunkedReader(br)
-
-		b := p.bufferPool.Get()
-		if b == nil {
-			b = make([]byte, bufferSize)
-		}
-		defer p.bufferPool.Put(b)
-
-		if _, err := io.CopyBuffer(&writeFlusher{rw}, cbr, b); err != nil {
-			co.Close()
-			return err
-		}
-
-		res.Header.Reset()
-		res.Header.SetNoDefaultContentType(true)
-		if err := res.Header.ReadTrailer(br); err != nil {
-			co.Close()
-			return err
-		}
-
-		if res.Header.Len() > 0 {
-			var announcedTrailersKey []string
-			if len(announcedTrailers) > 0 {
-				announcedTrailersKey = strings.Split(string(announcedTrailers), ",")
-			}
-
-			res.Header.VisitAll(func(key, value []byte) {
-				for _, s := range announcedTrailersKey {
-					if strings.EqualFold(s, strings.TrimSpace(string(key))) {
-						rw.Header().Add(string(key), string(value))
-						return
-					}
-				}
-
-				rw.Header().Add(http.TrailerPrefix+string(key), string(value))
-			})
-		}
-
-		p.connPool.ReleaseConn(co)
-
-		return nil
-	}
-
-	brl := p.limitReaderPool.Get()
-	if brl == nil {
-		brl = &io.LimitedReader{}
-	}
-	defer p.limitReaderPool.Put(brl)
-
-	brl.R = br
-	brl.N = int64(res.Header.ContentLength())
-
-	b := p.bufferPool.Get()
-	if b == nil {
-		b = make([]byte, bufferSize)
-	}
-	defer p.bufferPool.Put(b)
-
-	if _, err := io.CopyBuffer(rw, brl, b); err != nil {
-		co.Close()
+	// Sending the responseWriter unlocks the connection readLoop, to handle the response.
+	co.RWCh <- rwWithUpgrade{
+		RW:      rw,
+		Upgrade: upgradeResponseHandler(req.Context(), reqUpType),
+	}
+
+	if err := <-co.ErrCh; err != nil {
 		return err
 	}
 
 	p.connPool.ReleaseConn(co)
 
 	return nil
 }


@@ -362,7 +362,7 @@ func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) {
 
 	u := parseURI(t, srv.URL)
-	f, err := NewReverseProxy(u, nil, true, false, false, 0, newConnPool(1, 0, func() (net.Conn, error) {
+	f, err := NewReverseProxy(u, nil, true, false, false, newConnPool(1, 0, 0, func() (net.Conn, error) {
 		return net.Dial("tcp", u.Host)
 	}))
 	require.NoError(t, err)

@@ -434,7 +434,7 @@ func TestWebSocketUpgradeFailed(t *testing.T) {
 	defer srv.Close()
 
 	u := parseURI(t, srv.URL)
-	f, err := NewReverseProxy(u, nil, true, false, false, 0, newConnPool(1, 0, func() (net.Conn, error) {
+	f, err := NewReverseProxy(u, nil, true, false, false, newConnPool(1, 0, 0, func() (net.Conn, error) {
 		return net.Dial("tcp", u.Host)
 	}))
 	require.NoError(t, err)

@@ -663,7 +663,7 @@ func parseURI(t *testing.T, uri string) *url.URL {
 func createConnectionPool(target string, tlsConfig *tls.Config) *connPool {
 	u := testhelpers.MustParseURL(target)
-	return newConnPool(200, 0, func() (net.Conn, error) {
+	return newConnPool(200, 0, 0, func() (net.Conn, error) {
 		if tlsConfig != nil {
 			return tls.Dial("tcp", u.Host, tlsConfig)
 		}

@@ -676,7 +676,7 @@ func createProxyWithForwarder(t *testing.T, uri string, pool *connPool) *httptes
 	t.Helper()
 
 	u := parseURI(t, uri)
-	proxy, err := NewReverseProxy(u, nil, false, true, false, 0, pool)
+	proxy, err := NewReverseProxy(u, nil, false, true, false, pool)
 	require.NoError(t, err)
 
 	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {


@@ -2,6 +2,7 @@ package fast
 
 import (
 	"bytes"
+	"context"
 	"fmt"
 	"io"
 	"net"

@@ -19,72 +20,75 @@ type switchProtocolCopier struct {
 	user, backend io.ReadWriter
 }
 
-func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
+func (c switchProtocolCopier) copyFromBackend(errCh chan<- error) {
 	_, err := io.Copy(c.user, c.backend)
-	errc <- err
+	errCh <- err
 }
 
-func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
+func (c switchProtocolCopier) copyToBackend(errCh chan<- error) {
 	_, err := io.Copy(c.backend, c.user)
-	errc <- err
+	errCh <- err
 }
 
-func handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, reqUpType string, res *fasthttp.Response, backConn net.Conn) {
-	defer backConn.Close()
+type upgradeHandler func(rw http.ResponseWriter, res *fasthttp.Response, backConn net.Conn)
 
-	resUpType := upgradeTypeFastHTTP(&res.Header)
+func upgradeResponseHandler(ctx context.Context, reqUpType string) upgradeHandler {
+	return func(rw http.ResponseWriter, res *fasthttp.Response, backConn net.Conn) {
+		resUpType := upgradeTypeFastHTTP(&res.Header)
 
 		if !strings.EqualFold(reqUpType, resUpType) {
-			httputil.ErrorHandler(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
+			httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
+			backConn.Close()
 			return
 		}
 
 		hj, ok := rw.(http.Hijacker)
 		if !ok {
-			httputil.ErrorHandler(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
+			httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
+			backConn.Close()
 			return
 		}
 
 		backConnCloseCh := make(chan bool)
 		go func() {
 			// Ensure that the cancellation of a request closes the backend.
 			// See issue https://golang.org/issue/35559.
 			select {
-			case <-req.Context().Done():
+			case <-ctx.Done():
 			case <-backConnCloseCh:
 			}
 			_ = backConn.Close()
 		}()
 
 		defer close(backConnCloseCh)
 
 		conn, brw, err := hj.Hijack()
 		if err != nil {
-			httputil.ErrorHandler(rw, req, fmt.Errorf("hijack failed on protocol switch: %w", err))
+			httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("hijack failed on protocol switch: %w", err))
 			return
 		}
 		defer conn.Close()
 
 		for k, values := range rw.Header() {
 			for _, v := range values {
 				res.Header.Add(k, v)
 			}
 		}
 
 		if err := res.Header.Write(brw.Writer); err != nil {
-			httputil.ErrorHandler(rw, req, fmt.Errorf("response write: %w", err))
+			httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("response write: %w", err))
 			return
 		}
 
 		if err := brw.Flush(); err != nil {
-			httputil.ErrorHandler(rw, req, fmt.Errorf("response flush: %w", err))
+			httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("response flush: %w", err))
 			return
 		}
 
-		errc := make(chan error, 1)
+		errCh := make(chan error, 1)
 		spc := switchProtocolCopier{user: conn, backend: backConn}
-		go spc.copyToBackend(errc)
-		go spc.copyFromBackend(errc)
-		<-errc
+		go spc.copyToBackend(errCh)
+		go spc.copyFromBackend(errCh)
+		<-errCh
+	}
 }
 
 func upgradeType(h http.Header) string {
func upgradeType(h http.Header) string { func upgradeType(h http.Header) string {


@@ -102,9 +102,14 @@ func isWebSocketUpgrade(req *http.Request) bool {
 
 // ErrorHandler is the http.Handler called when something goes wrong when forwarding the request.
 func ErrorHandler(w http.ResponseWriter, req *http.Request, err error) {
+	ErrorHandlerWithContext(req.Context(), w, err)
+}
+
+// ErrorHandlerWithContext is the http.Handler called when something goes wrong when forwarding the request.
+func ErrorHandlerWithContext(ctx context.Context, w http.ResponseWriter, err error) {
 	statusCode := ComputeStatusCode(err)
 
-	logger := log.Ctx(req.Context())
+	logger := log.Ctx(ctx)
 	logger.Debug().Err(err).Msgf("%d %s", statusCode, statusText(statusCode))
 
 	w.WriteHeader(statusCode)