Detect and drop broken conns in the fastproxy pool
Co-authored-by: Romain <rtribotte@users.noreply.github.com>
This commit is contained in:
parent
b22e081c7c
commit
e3ed52ba7c
7 changed files with 426 additions and 280 deletions
|
@ -79,18 +79,13 @@ func (r *ProxyBuilder) Build(cfgName string, targetURL *url.URL, passHostHeader,
|
||||||
return nil, fmt.Errorf("getting ServersTransport: %w", err)
|
return nil, fmt.Errorf("getting ServersTransport: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var responseHeaderTimeout time.Duration
|
|
||||||
if cfg.ForwardingTimeouts != nil {
|
|
||||||
responseHeaderTimeout = time.Duration(cfg.ForwardingTimeouts.ResponseHeaderTimeout)
|
|
||||||
}
|
|
||||||
|
|
||||||
tlsConfig, err := r.transportManager.GetTLSConfig(cfgName)
|
tlsConfig, err := r.transportManager.GetTLSConfig(cfgName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("getting TLS config: %w", err)
|
return nil, fmt.Errorf("getting TLS config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
pool := r.getPool(cfgName, cfg, tlsConfig, targetURL, proxyURL)
|
pool := r.getPool(cfgName, cfg, tlsConfig, targetURL, proxyURL)
|
||||||
return NewReverseProxy(targetURL, proxyURL, r.debug, passHostHeader, preservePath, responseHeaderTimeout, pool)
|
return NewReverseProxy(targetURL, proxyURL, r.debug, passHostHeader, preservePath, pool)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *ProxyBuilder) getPool(cfgName string, config *dynamic.ServersTransport, tlsConfig *tls.Config, targetURL *url.URL, proxyURL *url.URL) *connPool {
|
func (r *ProxyBuilder) getPool(cfgName string, config *dynamic.ServersTransport, tlsConfig *tls.Config, targetURL *url.URL, proxyURL *url.URL) *connPool {
|
||||||
|
@ -106,9 +101,11 @@ func (r *ProxyBuilder) getPool(cfgName string, config *dynamic.ServersTransport,
|
||||||
|
|
||||||
idleConnTimeout := 90 * time.Second
|
idleConnTimeout := 90 * time.Second
|
||||||
dialTimeout := 30 * time.Second
|
dialTimeout := 30 * time.Second
|
||||||
|
var responseHeaderTimeout time.Duration
|
||||||
if config.ForwardingTimeouts != nil {
|
if config.ForwardingTimeouts != nil {
|
||||||
idleConnTimeout = time.Duration(config.ForwardingTimeouts.IdleConnTimeout)
|
idleConnTimeout = time.Duration(config.ForwardingTimeouts.IdleConnTimeout)
|
||||||
dialTimeout = time.Duration(config.ForwardingTimeouts.DialTimeout)
|
dialTimeout = time.Duration(config.ForwardingTimeouts.DialTimeout)
|
||||||
|
responseHeaderTimeout = time.Duration(config.ForwardingTimeouts.ResponseHeaderTimeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyDialer := newDialer(dialerConfig{
|
proxyDialer := newDialer(dialerConfig{
|
||||||
|
@ -119,7 +116,7 @@ func (r *ProxyBuilder) getPool(cfgName string, config *dynamic.ServersTransport,
|
||||||
ProxyURL: proxyURL,
|
ProxyURL: proxyURL,
|
||||||
}, tlsConfig)
|
}, tlsConfig)
|
||||||
|
|
||||||
connPool := newConnPool(config.MaxIdleConnsPerHost, idleConnTimeout, func() (net.Conn, error) {
|
connPool := newConnPool(config.MaxIdleConnsPerHost, idleConnTimeout, responseHeaderTimeout, func() (net.Conn, error) {
|
||||||
return proxyDialer.Dial("tcp", addrFromURL(targetURL))
|
return proxyDialer.Dial("tcp", addrFromURL(targetURL))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -1,42 +1,309 @@
|
||||||
package fast
|
package fast
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// rwWithUpgrade contains a ResponseWriter and an upgradeHandler,
|
||||||
|
// used to upgrade the connection (e.g. Websockets).
|
||||||
|
type rwWithUpgrade struct {
|
||||||
|
RW http.ResponseWriter
|
||||||
|
Upgrade upgradeHandler
|
||||||
|
}
|
||||||
|
|
||||||
// conn is an enriched net.Conn.
|
// conn is an enriched net.Conn.
|
||||||
type conn struct {
|
type conn struct {
|
||||||
net.Conn
|
net.Conn
|
||||||
|
|
||||||
|
RWCh chan rwWithUpgrade
|
||||||
|
ErrCh chan error
|
||||||
|
|
||||||
|
br *bufio.Reader
|
||||||
|
|
||||||
idleAt time.Time // the last time it was marked as idle.
|
idleAt time.Time // the last time it was marked as idle.
|
||||||
idleTimeout time.Duration
|
idleTimeout time.Duration
|
||||||
|
|
||||||
|
responseHeaderTimeout time.Duration
|
||||||
|
|
||||||
|
expectedResponse atomic.Bool
|
||||||
|
broken atomic.Bool
|
||||||
|
upgraded atomic.Bool
|
||||||
|
|
||||||
|
closeMu sync.Mutex
|
||||||
|
closed bool
|
||||||
|
closeErr error
|
||||||
|
|
||||||
|
bufferPool *pool[[]byte]
|
||||||
|
limitedReaderPool *pool[*io.LimitedReader]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) isExpired() bool {
|
// Read reads data from the connection.
|
||||||
|
// Overrides conn Read to use the buffered reader.
|
||||||
|
func (c *conn) Read(b []byte) (n int, err error) {
|
||||||
|
return c.br.Read(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the connection.
|
||||||
|
// Ensures that connection is closed only once,
|
||||||
|
// to avoid duplicate close error.
|
||||||
|
func (c *conn) Close() error {
|
||||||
|
c.closeMu.Lock()
|
||||||
|
defer c.closeMu.Unlock()
|
||||||
|
|
||||||
|
if c.closed {
|
||||||
|
return c.closeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
c.closed = true
|
||||||
|
c.closeErr = c.Conn.Close()
|
||||||
|
|
||||||
|
return c.closeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// isStale returns whether the connection is in an invalid state (i.e. expired/broken).
|
||||||
|
func (c *conn) isStale() bool {
|
||||||
expTime := c.idleAt.Add(c.idleTimeout)
|
expTime := c.idleAt.Add(c.idleTimeout)
|
||||||
return c.idleTimeout > 0 && time.Now().After(expTime)
|
return c.idleTimeout > 0 && time.Now().After(expTime) || c.broken.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// isUpgraded returns whether this connection has been upgraded (e.g. Websocket).
|
||||||
|
// An upgraded connection should not be reused and putted back in the connection pool.
|
||||||
|
func (c *conn) isUpgraded() bool {
|
||||||
|
return c.upgraded.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// readLoop handles the successive HTTP response read operations on the connection,
|
||||||
|
// and watches for unsolicited bytes or connection errors when idle.
|
||||||
|
func (c *conn) readLoop() {
|
||||||
|
defer c.Close()
|
||||||
|
|
||||||
|
for {
|
||||||
|
_, err := c.br.Peek(1)
|
||||||
|
if err != nil {
|
||||||
|
select {
|
||||||
|
// An error occurred while a response was expected to be handled.
|
||||||
|
case <-c.RWCh:
|
||||||
|
c.ErrCh <- err
|
||||||
|
// An error occurred on an idle connection.
|
||||||
|
default:
|
||||||
|
c.broken.Store(true)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsolicited response received on an idle connection.
|
||||||
|
if !c.expectedResponse.Load() {
|
||||||
|
c.broken.Store(true)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r := <-c.RWCh
|
||||||
|
if err = c.handleResponse(r); err != nil {
|
||||||
|
c.ErrCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.expectedResponse.Store(false)
|
||||||
|
c.ErrCh <- nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *conn) handleResponse(r rwWithUpgrade) error {
|
||||||
|
res := fasthttp.AcquireResponse()
|
||||||
|
defer fasthttp.ReleaseResponse(res)
|
||||||
|
|
||||||
|
res.Header.SetNoDefaultContentType(true)
|
||||||
|
|
||||||
|
for {
|
||||||
|
var (
|
||||||
|
timer *time.Timer
|
||||||
|
errTimeout atomic.Pointer[timeoutError]
|
||||||
|
)
|
||||||
|
if c.responseHeaderTimeout > 0 {
|
||||||
|
timer = time.AfterFunc(c.responseHeaderTimeout, func() {
|
||||||
|
errTimeout.Store(&timeoutError{errors.New("timeout awaiting response headers")})
|
||||||
|
c.Close() // This close call is needed to interrupt the read operation below when the timeout is over.
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
res.Header.SetNoDefaultContentType(true)
|
||||||
|
if err := res.Header.Read(c.br); err != nil {
|
||||||
|
if c.responseHeaderTimeout > 0 {
|
||||||
|
if errT := errTimeout.Load(); errT != nil {
|
||||||
|
return errT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if timer != nil {
|
||||||
|
timer.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
fixPragmaCacheControl(&res.Header)
|
||||||
|
|
||||||
|
resCode := res.StatusCode()
|
||||||
|
is1xx := 100 <= resCode && resCode <= 199
|
||||||
|
// treat 101 as a terminal status, see issue 26161
|
||||||
|
is1xxNonTerminal := is1xx && resCode != http.StatusSwitchingProtocols
|
||||||
|
if is1xxNonTerminal {
|
||||||
|
removeConnectionHeaders(&res.Header)
|
||||||
|
h := r.RW.Header()
|
||||||
|
|
||||||
|
for _, header := range hopHeaders {
|
||||||
|
res.Header.Del(header)
|
||||||
|
}
|
||||||
|
|
||||||
|
res.Header.VisitAll(func(key, value []byte) {
|
||||||
|
r.RW.Header().Add(string(key), string(value))
|
||||||
|
})
|
||||||
|
|
||||||
|
r.RW.WriteHeader(res.StatusCode())
|
||||||
|
// Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
|
||||||
|
for k := range h {
|
||||||
|
delete(h, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
res.Reset()
|
||||||
|
res.Header.Reset()
|
||||||
|
res.Header.SetNoDefaultContentType(true)
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
announcedTrailers := res.Header.Peek("Trailer")
|
||||||
|
|
||||||
|
// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
|
||||||
|
if res.StatusCode() == http.StatusSwitchingProtocols {
|
||||||
|
r.Upgrade(r.RW, res, c)
|
||||||
|
c.upgraded.Store(true) // As the connection has been upgraded, it cannot be added back to the pool.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
removeConnectionHeaders(&res.Header)
|
||||||
|
|
||||||
|
for _, header := range hopHeaders {
|
||||||
|
res.Header.Del(header)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(announcedTrailers) > 0 {
|
||||||
|
res.Header.Add("Trailer", string(announcedTrailers))
|
||||||
|
}
|
||||||
|
|
||||||
|
res.Header.VisitAll(func(key, value []byte) {
|
||||||
|
r.RW.Header().Add(string(key), string(value))
|
||||||
|
})
|
||||||
|
|
||||||
|
r.RW.WriteHeader(res.StatusCode())
|
||||||
|
|
||||||
|
if res.Header.ContentLength() == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// When a body is not allowed for a given status code the body is ignored.
|
||||||
|
// The connection will be marked as broken by the next Peek in the readloop.
|
||||||
|
if !isBodyAllowedForStatus(res.StatusCode()) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Chunked response, Content-Length is set to -1 by FastProxy when "Transfer-Encoding: chunked" header is received.
|
||||||
|
if res.Header.ContentLength() == -1 {
|
||||||
|
cbr := httputil.NewChunkedReader(c.br)
|
||||||
|
|
||||||
|
b := c.bufferPool.Get()
|
||||||
|
if b == nil {
|
||||||
|
b = make([]byte, bufferSize)
|
||||||
|
}
|
||||||
|
defer c.bufferPool.Put(b)
|
||||||
|
|
||||||
|
if _, err := io.CopyBuffer(&writeFlusher{r.RW}, cbr, b); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
res.Header.Reset()
|
||||||
|
res.Header.SetNoDefaultContentType(true)
|
||||||
|
if err := res.Header.ReadTrailer(c.br); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Header.Len() > 0 {
|
||||||
|
var announcedTrailersKey []string
|
||||||
|
if len(announcedTrailers) > 0 {
|
||||||
|
announcedTrailersKey = strings.Split(string(announcedTrailers), ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
res.Header.VisitAll(func(key, value []byte) {
|
||||||
|
for _, s := range announcedTrailersKey {
|
||||||
|
if strings.EqualFold(s, strings.TrimSpace(string(key))) {
|
||||||
|
r.RW.Header().Add(string(key), string(value))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
r.RW.Header().Add(http.TrailerPrefix+string(key), string(value))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
brl := c.limitedReaderPool.Get()
|
||||||
|
if brl == nil {
|
||||||
|
brl = &io.LimitedReader{}
|
||||||
|
}
|
||||||
|
defer c.limitedReaderPool.Put(brl)
|
||||||
|
|
||||||
|
brl.R = c.br
|
||||||
|
brl.N = int64(res.Header.ContentLength())
|
||||||
|
|
||||||
|
b := c.bufferPool.Get()
|
||||||
|
if b == nil {
|
||||||
|
b = make([]byte, bufferSize)
|
||||||
|
}
|
||||||
|
defer c.bufferPool.Put(b)
|
||||||
|
|
||||||
|
if _, err := io.CopyBuffer(r.RW, brl, b); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// connPool is a net.Conn pool implementation using channels.
|
// connPool is a net.Conn pool implementation using channels.
|
||||||
type connPool struct {
|
type connPool struct {
|
||||||
dialer func() (net.Conn, error)
|
dialer func() (net.Conn, error)
|
||||||
idleConns chan *conn
|
idleConns chan *conn
|
||||||
idleConnTimeout time.Duration
|
idleConnTimeout time.Duration
|
||||||
ticker *time.Ticker
|
responseHeaderTimeout time.Duration
|
||||||
doneCh chan struct{}
|
ticker *time.Ticker
|
||||||
|
bufferPool pool[[]byte]
|
||||||
|
limitedReaderPool pool[*io.LimitedReader]
|
||||||
|
doneCh chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// newConnPool creates a new connPool.
|
// newConnPool creates a new connPool.
|
||||||
func newConnPool(maxIdleConn int, idleConnTimeout time.Duration, dialer func() (net.Conn, error)) *connPool {
|
func newConnPool(maxIdleConn int, idleConnTimeout, responseHeaderTimeout time.Duration, dialer func() (net.Conn, error)) *connPool {
|
||||||
c := &connPool{
|
c := &connPool{
|
||||||
dialer: dialer,
|
dialer: dialer,
|
||||||
idleConns: make(chan *conn, maxIdleConn),
|
idleConns: make(chan *conn, maxIdleConn),
|
||||||
idleConnTimeout: idleConnTimeout,
|
idleConnTimeout: idleConnTimeout,
|
||||||
doneCh: make(chan struct{}),
|
responseHeaderTimeout: responseHeaderTimeout,
|
||||||
|
doneCh: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
if idleConnTimeout > 0 {
|
if idleConnTimeout > 0 {
|
||||||
|
@ -72,22 +339,28 @@ func (c *connPool) AcquireConn() (*conn, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !co.isExpired() {
|
if !co.isStale() {
|
||||||
return co, nil
|
return co, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// As the acquired conn is expired we can close it
|
// As the acquired conn is stale we can close it
|
||||||
// without putting it again into the pool.
|
// without putting it again into the pool.
|
||||||
if err := co.Close(); err != nil {
|
if err := co.Close(); err != nil {
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Unexpected error while releasing the connection")
|
Msg("Unexpected error while closing the connection")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReleaseConn releases the given net.Conn to the pool.
|
// ReleaseConn releases the given net.Conn to the pool.
|
||||||
func (c *connPool) ReleaseConn(co *conn) {
|
func (c *connPool) ReleaseConn(co *conn) {
|
||||||
|
// An upgraded connection cannot be safely reused for another roundTrip,
|
||||||
|
// thus we are not putting it back to the pool.
|
||||||
|
if co.isUpgraded() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
co.idleAt = time.Now()
|
co.idleAt = time.Now()
|
||||||
c.releaseConn(co)
|
c.releaseConn(co)
|
||||||
}
|
}
|
||||||
|
@ -97,7 +370,7 @@ func (c *connPool) cleanIdleConns() {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case co := <-c.idleConns:
|
case co := <-c.idleConns:
|
||||||
if !co.isExpired() {
|
if !co.isStale() {
|
||||||
c.releaseConn(co)
|
c.releaseConn(co)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -105,7 +378,7 @@ func (c *connPool) cleanIdleConns() {
|
||||||
if err := co.Close(); err != nil {
|
if err := co.Close(); err != nil {
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Unexpected error while releasing the connection")
|
Msg("Unexpected error while closing the connection")
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
|
@ -155,9 +428,33 @@ func (c *connPool) askForNewConn(errCh chan<- error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.releaseConn(&conn{
|
newConn := &conn{
|
||||||
Conn: co,
|
Conn: co,
|
||||||
idleAt: time.Now(),
|
br: bufio.NewReaderSize(co, bufioSize),
|
||||||
idleTimeout: c.idleConnTimeout,
|
idleAt: time.Now(),
|
||||||
})
|
idleTimeout: c.idleConnTimeout,
|
||||||
|
responseHeaderTimeout: c.responseHeaderTimeout,
|
||||||
|
RWCh: make(chan rwWithUpgrade),
|
||||||
|
ErrCh: make(chan error),
|
||||||
|
bufferPool: &c.bufferPool,
|
||||||
|
limitedReaderPool: &c.limitedReaderPool,
|
||||||
|
}
|
||||||
|
go newConn.readLoop()
|
||||||
|
|
||||||
|
c.releaseConn(newConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isBodyAllowedForStatus reports whether a given response status code
|
||||||
|
// permits a body. See RFC 7230, section 3.3.
|
||||||
|
// From https://github.com/golang/go/blame/master/src/net/http/transfer.go#L459
|
||||||
|
func isBodyAllowedForStatus(status int) bool {
|
||||||
|
switch {
|
||||||
|
case status >= 100 && status <= 199:
|
||||||
|
return false
|
||||||
|
case status == 204:
|
||||||
|
return false
|
||||||
|
case status == 304:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
|
@ -58,7 +58,7 @@ func TestConnPool_ConnReuse(t *testing.T) {
|
||||||
return &net.TCPConn{}, nil
|
return &net.TCPConn{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
pool := newConnPool(2, 0, dialer)
|
pool := newConnPool(2, 0, 0, dialer)
|
||||||
test.poolFn(pool)
|
test.poolFn(pool)
|
||||||
|
|
||||||
assert.Equal(t, test.expected, connAlloc)
|
assert.Equal(t, test.expected, connAlloc)
|
||||||
|
@ -102,13 +102,16 @@ func TestConnPool_MaxIdleConn(t *testing.T) {
|
||||||
var keepOpenedConn int
|
var keepOpenedConn int
|
||||||
dialer := func() (net.Conn, error) {
|
dialer := func() (net.Conn, error) {
|
||||||
keepOpenedConn++
|
keepOpenedConn++
|
||||||
return &mockConn{closeFn: func() error {
|
return &mockConn{
|
||||||
keepOpenedConn--
|
doneCh: make(chan struct{}),
|
||||||
return nil
|
closeFn: func() error {
|
||||||
}}, nil
|
keepOpenedConn--
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
pool := newConnPool(test.maxIdleConn, 0, dialer)
|
pool := newConnPool(test.maxIdleConn, 0, 0, dialer)
|
||||||
test.poolFn(pool)
|
test.poolFn(pool)
|
||||||
|
|
||||||
assert.Equal(t, test.expected, keepOpenedConn)
|
assert.Equal(t, test.expected, keepOpenedConn)
|
||||||
|
@ -129,7 +132,7 @@ func TestGC(t *testing.T) {
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
pools["test"] = newConnPool(10, 1*time.Second, dialer)
|
pools["test"] = newConnPool(10, 1*time.Second, 0, dialer)
|
||||||
runtime.SetFinalizer(pools["test"], func(p *connPool) {
|
runtime.SetFinalizer(pools["test"], func(p *connPool) {
|
||||||
isDestroyed = true
|
isDestroyed = true
|
||||||
})
|
})
|
||||||
|
@ -149,10 +152,12 @@ func TestGC(t *testing.T) {
|
||||||
|
|
||||||
type mockConn struct {
|
type mockConn struct {
|
||||||
closeFn func() error
|
closeFn func() error
|
||||||
|
doneCh chan struct{} // makes sure that the readLoop is blocking avoiding close.
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockConn) Read(_ []byte) (n int, err error) {
|
func (m *mockConn) Read(_ []byte) (n int, err error) {
|
||||||
panic("implement me")
|
<-m.doneCh
|
||||||
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockConn) Write(_ []byte) (n int, err error) {
|
func (m *mockConn) Write(_ []byte) (n int, err error) {
|
||||||
|
@ -160,6 +165,7 @@ func (m *mockConn) Write(_ []byte) (n int, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockConn) Close() error {
|
func (m *mockConn) Close() error {
|
||||||
|
defer close(m.doneCh)
|
||||||
if m.closeFn != nil {
|
if m.closeFn != nil {
|
||||||
return m.closeFn()
|
return m.closeFn()
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,18 +4,14 @@ import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptrace"
|
"net/http/httptrace"
|
||||||
"net/http/httputil"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
proxyhttputil "github.com/traefik/traefik/v3/pkg/proxy/httputil"
|
proxyhttputil "github.com/traefik/traefik/v3/pkg/proxy/httputil"
|
||||||
|
@ -57,15 +53,6 @@ func (p *pool[T]) Put(x T) {
|
||||||
p.pool.Put(x)
|
p.pool.Put(x)
|
||||||
}
|
}
|
||||||
|
|
||||||
type buffConn struct {
|
|
||||||
*bufio.Reader
|
|
||||||
net.Conn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b buffConn) Read(p []byte) (int, error) {
|
|
||||||
return b.Reader.Read(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
type writeDetector struct {
|
type writeDetector struct {
|
||||||
net.Conn
|
net.Conn
|
||||||
|
|
||||||
|
@ -112,21 +99,17 @@ type ReverseProxy struct {
|
||||||
|
|
||||||
connPool *connPool
|
connPool *connPool
|
||||||
|
|
||||||
bufferPool pool[[]byte]
|
writerPool pool[*bufio.Writer]
|
||||||
readerPool pool[*bufio.Reader]
|
|
||||||
writerPool pool[*bufio.Writer]
|
|
||||||
limitReaderPool pool[*io.LimitedReader]
|
|
||||||
|
|
||||||
proxyAuth string
|
proxyAuth string
|
||||||
|
|
||||||
targetURL *url.URL
|
targetURL *url.URL
|
||||||
passHostHeader bool
|
passHostHeader bool
|
||||||
preservePath bool
|
preservePath bool
|
||||||
responseHeaderTimeout time.Duration
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewReverseProxy creates a new ReverseProxy.
|
// NewReverseProxy creates a new ReverseProxy.
|
||||||
func NewReverseProxy(targetURL, proxyURL *url.URL, debug, passHostHeader, preservePath bool, responseHeaderTimeout time.Duration, connPool *connPool) (*ReverseProxy, error) {
|
func NewReverseProxy(targetURL, proxyURL *url.URL, debug, passHostHeader, preservePath bool, connPool *connPool) (*ReverseProxy, error) {
|
||||||
var proxyAuth string
|
var proxyAuth string
|
||||||
if proxyURL != nil && proxyURL.User != nil && targetURL.Scheme == "http" {
|
if proxyURL != nil && proxyURL.User != nil && targetURL.Scheme == "http" {
|
||||||
username := proxyURL.User.Username()
|
username := proxyURL.User.Username()
|
||||||
|
@ -135,13 +118,12 @@ func NewReverseProxy(targetURL, proxyURL *url.URL, debug, passHostHeader, preser
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ReverseProxy{
|
return &ReverseProxy{
|
||||||
debug: debug,
|
debug: debug,
|
||||||
passHostHeader: passHostHeader,
|
passHostHeader: passHostHeader,
|
||||||
preservePath: preservePath,
|
preservePath: preservePath,
|
||||||
targetURL: targetURL,
|
targetURL: targetURL,
|
||||||
proxyAuth: proxyAuth,
|
proxyAuth: proxyAuth,
|
||||||
connPool: connPool,
|
connPool: connPool,
|
||||||
responseHeaderTimeout: responseHeaderTimeout,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -273,8 +255,15 @@ func (p *ReverseProxy) roundTrip(rw http.ResponseWriter, req *http.Request, outR
|
||||||
return fmt.Errorf("acquire connection: %w", err)
|
return fmt.Errorf("acquire connection: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Before writing the request,
|
||||||
|
// we mark the conn as expecting to handle a response.
|
||||||
|
co.expectedResponse.Store(true)
|
||||||
|
|
||||||
wd := &writeDetector{Conn: co}
|
wd := &writeDetector{Conn: co}
|
||||||
|
|
||||||
|
// TODO: do not wait to write the full request before reading the response (to handle "100 Continue").
|
||||||
|
// TODO: this is currently impossible with fasthttp to write the request partially (headers only).
|
||||||
|
// Currently, writing the request fully is a mandatory step before handling the response.
|
||||||
err = p.writeRequest(wd, outReq)
|
err = p.writeRequest(wd, outReq)
|
||||||
if wd.written && trace != nil && trace.WroteRequest != nil {
|
if wd.written && trace != nil && trace.WroteRequest != nil {
|
||||||
// WroteRequest hook is used by the tracing middleware to detect if the request has been written.
|
// WroteRequest hook is used by the tracing middleware to detect if the request has been written.
|
||||||
|
@ -293,169 +282,17 @@ func (p *ReverseProxy) roundTrip(rw http.ResponseWriter, req *http.Request, outR
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
br := p.readerPool.Get()
|
// Sending the responseWriter unlocks the connection readLoop, to handle the response.
|
||||||
if br == nil {
|
co.RWCh <- rwWithUpgrade{
|
||||||
br = bufio.NewReaderSize(co, bufioSize)
|
RW: rw,
|
||||||
}
|
Upgrade: upgradeResponseHandler(req.Context(), reqUpType),
|
||||||
defer p.readerPool.Put(br)
|
|
||||||
|
|
||||||
br.Reset(co)
|
|
||||||
|
|
||||||
res := fasthttp.AcquireResponse()
|
|
||||||
defer fasthttp.ReleaseResponse(res)
|
|
||||||
|
|
||||||
res.Header.SetNoDefaultContentType(true)
|
|
||||||
|
|
||||||
for {
|
|
||||||
var timer *time.Timer
|
|
||||||
errTimeout := atomic.Pointer[timeoutError]{}
|
|
||||||
if p.responseHeaderTimeout > 0 {
|
|
||||||
timer = time.AfterFunc(p.responseHeaderTimeout, func() {
|
|
||||||
errTimeout.Store(&timeoutError{errors.New("timeout awaiting response headers")})
|
|
||||||
co.Close()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
res.Header.SetNoDefaultContentType(true)
|
|
||||||
if err := res.Header.Read(br); err != nil {
|
|
||||||
if p.responseHeaderTimeout > 0 {
|
|
||||||
if errT := errTimeout.Load(); errT != nil {
|
|
||||||
return errT
|
|
||||||
}
|
|
||||||
}
|
|
||||||
co.Close()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if timer != nil {
|
|
||||||
timer.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
fixPragmaCacheControl(&res.Header)
|
|
||||||
|
|
||||||
resCode := res.StatusCode()
|
|
||||||
is1xx := 100 <= resCode && resCode <= 199
|
|
||||||
// treat 101 as a terminal status, see issue 26161
|
|
||||||
is1xxNonTerminal := is1xx && resCode != http.StatusSwitchingProtocols
|
|
||||||
if is1xxNonTerminal {
|
|
||||||
removeConnectionHeaders(&res.Header)
|
|
||||||
h := rw.Header()
|
|
||||||
|
|
||||||
for _, header := range hopHeaders {
|
|
||||||
res.Header.Del(header)
|
|
||||||
}
|
|
||||||
|
|
||||||
res.Header.VisitAll(func(key, value []byte) {
|
|
||||||
rw.Header().Add(string(key), string(value))
|
|
||||||
})
|
|
||||||
|
|
||||||
rw.WriteHeader(res.StatusCode())
|
|
||||||
// Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
|
|
||||||
for k := range h {
|
|
||||||
delete(h, k)
|
|
||||||
}
|
|
||||||
|
|
||||||
res.Reset()
|
|
||||||
res.Header.Reset()
|
|
||||||
res.Header.SetNoDefaultContentType(true)
|
|
||||||
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
|
|
||||||
announcedTrailers := res.Header.Peek("Trailer")
|
if err := <-co.ErrCh; err != nil {
|
||||||
|
|
||||||
// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
|
|
||||||
if res.StatusCode() == http.StatusSwitchingProtocols {
|
|
||||||
// As the connection has been hijacked, it cannot be added back to the pool.
|
|
||||||
handleUpgradeResponse(rw, req, reqUpType, res, buffConn{Conn: co, Reader: br})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
removeConnectionHeaders(&res.Header)
|
|
||||||
|
|
||||||
for _, header := range hopHeaders {
|
|
||||||
res.Header.Del(header)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(announcedTrailers) > 0 {
|
|
||||||
res.Header.Add("Trailer", string(announcedTrailers))
|
|
||||||
}
|
|
||||||
|
|
||||||
res.Header.VisitAll(func(key, value []byte) {
|
|
||||||
rw.Header().Add(string(key), string(value))
|
|
||||||
})
|
|
||||||
|
|
||||||
rw.WriteHeader(res.StatusCode())
|
|
||||||
|
|
||||||
// Chunked response, Content-Length is set to -1 by FastProxy when "Transfer-Encoding: chunked" header is received.
|
|
||||||
if res.Header.ContentLength() == -1 {
|
|
||||||
cbr := httputil.NewChunkedReader(br)
|
|
||||||
|
|
||||||
b := p.bufferPool.Get()
|
|
||||||
if b == nil {
|
|
||||||
b = make([]byte, bufferSize)
|
|
||||||
}
|
|
||||||
defer p.bufferPool.Put(b)
|
|
||||||
|
|
||||||
if _, err := io.CopyBuffer(&writeFlusher{rw}, cbr, b); err != nil {
|
|
||||||
co.Close()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
res.Header.Reset()
|
|
||||||
res.Header.SetNoDefaultContentType(true)
|
|
||||||
if err := res.Header.ReadTrailer(br); err != nil {
|
|
||||||
co.Close()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if res.Header.Len() > 0 {
|
|
||||||
var announcedTrailersKey []string
|
|
||||||
if len(announcedTrailers) > 0 {
|
|
||||||
announcedTrailersKey = strings.Split(string(announcedTrailers), ",")
|
|
||||||
}
|
|
||||||
|
|
||||||
res.Header.VisitAll(func(key, value []byte) {
|
|
||||||
for _, s := range announcedTrailersKey {
|
|
||||||
if strings.EqualFold(s, strings.TrimSpace(string(key))) {
|
|
||||||
rw.Header().Add(string(key), string(value))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
rw.Header().Add(http.TrailerPrefix+string(key), string(value))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
p.connPool.ReleaseConn(co)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
brl := p.limitReaderPool.Get()
|
|
||||||
if brl == nil {
|
|
||||||
brl = &io.LimitedReader{}
|
|
||||||
}
|
|
||||||
defer p.limitReaderPool.Put(brl)
|
|
||||||
|
|
||||||
brl.R = br
|
|
||||||
brl.N = int64(res.Header.ContentLength())
|
|
||||||
|
|
||||||
b := p.bufferPool.Get()
|
|
||||||
if b == nil {
|
|
||||||
b = make([]byte, bufferSize)
|
|
||||||
}
|
|
||||||
defer p.bufferPool.Put(b)
|
|
||||||
|
|
||||||
if _, err := io.CopyBuffer(rw, brl, b); err != nil {
|
|
||||||
co.Close()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
p.connPool.ReleaseConn(co)
|
p.connPool.ReleaseConn(co)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -362,7 +362,7 @@ func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) {
|
||||||
|
|
||||||
u := parseURI(t, srv.URL)
|
u := parseURI(t, srv.URL)
|
||||||
|
|
||||||
f, err := NewReverseProxy(u, nil, true, false, false, 0, newConnPool(1, 0, func() (net.Conn, error) {
|
f, err := NewReverseProxy(u, nil, true, false, false, newConnPool(1, 0, 0, func() (net.Conn, error) {
|
||||||
return net.Dial("tcp", u.Host)
|
return net.Dial("tcp", u.Host)
|
||||||
}))
|
}))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -434,7 +434,7 @@ func TestWebSocketUpgradeFailed(t *testing.T) {
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
u := parseURI(t, srv.URL)
|
u := parseURI(t, srv.URL)
|
||||||
f, err := NewReverseProxy(u, nil, true, false, false, 0, newConnPool(1, 0, func() (net.Conn, error) {
|
f, err := NewReverseProxy(u, nil, true, false, false, newConnPool(1, 0, 0, func() (net.Conn, error) {
|
||||||
return net.Dial("tcp", u.Host)
|
return net.Dial("tcp", u.Host)
|
||||||
}))
|
}))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -663,7 +663,7 @@ func parseURI(t *testing.T, uri string) *url.URL {
|
||||||
|
|
||||||
func createConnectionPool(target string, tlsConfig *tls.Config) *connPool {
|
func createConnectionPool(target string, tlsConfig *tls.Config) *connPool {
|
||||||
u := testhelpers.MustParseURL(target)
|
u := testhelpers.MustParseURL(target)
|
||||||
return newConnPool(200, 0, func() (net.Conn, error) {
|
return newConnPool(200, 0, 0, func() (net.Conn, error) {
|
||||||
if tlsConfig != nil {
|
if tlsConfig != nil {
|
||||||
return tls.Dial("tcp", u.Host, tlsConfig)
|
return tls.Dial("tcp", u.Host, tlsConfig)
|
||||||
}
|
}
|
||||||
|
@ -676,7 +676,7 @@ func createProxyWithForwarder(t *testing.T, uri string, pool *connPool) *httptes
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
u := parseURI(t, uri)
|
u := parseURI(t, uri)
|
||||||
proxy, err := NewReverseProxy(u, nil, false, true, false, 0, pool)
|
proxy, err := NewReverseProxy(u, nil, false, true, false, pool)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
|
|
@ -2,6 +2,7 @@ package fast
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
@ -19,72 +20,75 @@ type switchProtocolCopier struct {
|
||||||
user, backend io.ReadWriter
|
user, backend io.ReadWriter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
|
func (c switchProtocolCopier) copyFromBackend(errCh chan<- error) {
|
||||||
_, err := io.Copy(c.user, c.backend)
|
_, err := io.Copy(c.user, c.backend)
|
||||||
errc <- err
|
errCh <- err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
|
func (c switchProtocolCopier) copyToBackend(errCh chan<- error) {
|
||||||
_, err := io.Copy(c.backend, c.user)
|
_, err := io.Copy(c.backend, c.user)
|
||||||
errc <- err
|
errCh <- err
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, reqUpType string, res *fasthttp.Response, backConn net.Conn) {
|
type upgradeHandler func(rw http.ResponseWriter, res *fasthttp.Response, backConn net.Conn)
|
||||||
defer backConn.Close()
|
|
||||||
|
|
||||||
resUpType := upgradeTypeFastHTTP(&res.Header)
|
func upgradeResponseHandler(ctx context.Context, reqUpType string) upgradeHandler {
|
||||||
|
return func(rw http.ResponseWriter, res *fasthttp.Response, backConn net.Conn) {
|
||||||
|
resUpType := upgradeTypeFastHTTP(&res.Header)
|
||||||
|
|
||||||
if !strings.EqualFold(reqUpType, resUpType) {
|
if !strings.EqualFold(reqUpType, resUpType) {
|
||||||
httputil.ErrorHandler(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
|
httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
|
||||||
return
|
backConn.Close()
|
||||||
}
|
return
|
||||||
|
|
||||||
hj, ok := rw.(http.Hijacker)
|
|
||||||
if !ok {
|
|
||||||
httputil.ErrorHandler(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
backConnCloseCh := make(chan bool)
|
|
||||||
go func() {
|
|
||||||
// Ensure that the cancellation of a request closes the backend.
|
|
||||||
// See issue https://golang.org/issue/35559.
|
|
||||||
select {
|
|
||||||
case <-req.Context().Done():
|
|
||||||
case <-backConnCloseCh:
|
|
||||||
}
|
}
|
||||||
_ = backConn.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
defer close(backConnCloseCh)
|
hj, ok := rw.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
conn, brw, err := hj.Hijack()
|
httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
|
||||||
if err != nil {
|
backConn.Close()
|
||||||
httputil.ErrorHandler(rw, req, fmt.Errorf("hijack failed on protocol switch: %w", err))
|
return
|
||||||
return
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
for k, values := range rw.Header() {
|
|
||||||
for _, v := range values {
|
|
||||||
res.Header.Add(k, v)
|
|
||||||
}
|
}
|
||||||
}
|
backConnCloseCh := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
// Ensure that the cancellation of a request closes the backend.
|
||||||
|
// See issue https://golang.org/issue/35559.
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case <-backConnCloseCh:
|
||||||
|
}
|
||||||
|
_ = backConn.Close()
|
||||||
|
}()
|
||||||
|
defer close(backConnCloseCh)
|
||||||
|
|
||||||
if err := res.Header.Write(brw.Writer); err != nil {
|
conn, brw, err := hj.Hijack()
|
||||||
httputil.ErrorHandler(rw, req, fmt.Errorf("response write: %w", err))
|
if err != nil {
|
||||||
return
|
httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("hijack failed on protocol switch: %w", err))
|
||||||
}
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
if err := brw.Flush(); err != nil {
|
for k, values := range rw.Header() {
|
||||||
httputil.ErrorHandler(rw, req, fmt.Errorf("response flush: %w", err))
|
for _, v := range values {
|
||||||
return
|
res.Header.Add(k, v)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
errc := make(chan error, 1)
|
if err := res.Header.Write(brw.Writer); err != nil {
|
||||||
spc := switchProtocolCopier{user: conn, backend: backConn}
|
httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("response write: %w", err))
|
||||||
go spc.copyToBackend(errc)
|
return
|
||||||
go spc.copyFromBackend(errc)
|
}
|
||||||
<-errc
|
|
||||||
|
if err := brw.Flush(); err != nil {
|
||||||
|
httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("response flush: %w", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
spc := switchProtocolCopier{user: conn, backend: backConn}
|
||||||
|
go spc.copyToBackend(errCh)
|
||||||
|
go spc.copyFromBackend(errCh)
|
||||||
|
<-errCh
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func upgradeType(h http.Header) string {
|
func upgradeType(h http.Header) string {
|
||||||
|
|
|
@ -102,9 +102,14 @@ func isWebSocketUpgrade(req *http.Request) bool {
|
||||||
|
|
||||||
// ErrorHandler is the http.Handler called when something goes wrong when forwarding the request.
|
// ErrorHandler is the http.Handler called when something goes wrong when forwarding the request.
|
||||||
func ErrorHandler(w http.ResponseWriter, req *http.Request, err error) {
|
func ErrorHandler(w http.ResponseWriter, req *http.Request, err error) {
|
||||||
|
ErrorHandlerWithContext(req.Context(), w, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorHandlerWithContext is the http.Handler called when something goes wrong when forwarding the request.
|
||||||
|
func ErrorHandlerWithContext(ctx context.Context, w http.ResponseWriter, err error) {
|
||||||
statusCode := ComputeStatusCode(err)
|
statusCode := ComputeStatusCode(err)
|
||||||
|
|
||||||
logger := log.Ctx(req.Context())
|
logger := log.Ctx(ctx)
|
||||||
logger.Debug().Err(err).Msgf("%d %s", statusCode, statusText(statusCode))
|
logger.Debug().Err(err).Msgf("%d %s", statusCode, statusText(statusCode))
|
||||||
|
|
||||||
w.WriteHeader(statusCode)
|
w.WriteHeader(statusCode)
|
||||||
|
|
Loading…
Reference in a new issue