From e3ed52ba7caa723bc34a7618652c2b7c5c4bcfef Mon Sep 17 00:00:00 2001 From: Kevin Pollet Date: Fri, 25 Oct 2024 14:26:04 +0200 Subject: [PATCH] Detect and drop broken conns in the fastproxy pool Co-authored-by: Romain --- pkg/proxy/fast/builder.go | 11 +- pkg/proxy/fast/connpool.go | 341 +++++++++++++++++++++++-- pkg/proxy/fast/connpool_test.go | 22 +- pkg/proxy/fast/proxy.go | 209 ++------------- pkg/proxy/fast/proxy_websocket_test.go | 8 +- pkg/proxy/fast/upgrade.go | 108 ++++---- pkg/proxy/httputil/proxy.go | 7 +- 7 files changed, 426 insertions(+), 280 deletions(-) diff --git a/pkg/proxy/fast/builder.go b/pkg/proxy/fast/builder.go index c53363dad..b41323dda 100644 --- a/pkg/proxy/fast/builder.go +++ b/pkg/proxy/fast/builder.go @@ -79,18 +79,13 @@ func (r *ProxyBuilder) Build(cfgName string, targetURL *url.URL, passHostHeader, return nil, fmt.Errorf("getting ServersTransport: %w", err) } - var responseHeaderTimeout time.Duration - if cfg.ForwardingTimeouts != nil { - responseHeaderTimeout = time.Duration(cfg.ForwardingTimeouts.ResponseHeaderTimeout) - } - tlsConfig, err := r.transportManager.GetTLSConfig(cfgName) if err != nil { return nil, fmt.Errorf("getting TLS config: %w", err) } pool := r.getPool(cfgName, cfg, tlsConfig, targetURL, proxyURL) - return NewReverseProxy(targetURL, proxyURL, r.debug, passHostHeader, preservePath, responseHeaderTimeout, pool) + return NewReverseProxy(targetURL, proxyURL, r.debug, passHostHeader, preservePath, pool) } func (r *ProxyBuilder) getPool(cfgName string, config *dynamic.ServersTransport, tlsConfig *tls.Config, targetURL *url.URL, proxyURL *url.URL) *connPool { @@ -106,9 +101,11 @@ func (r *ProxyBuilder) getPool(cfgName string, config *dynamic.ServersTransport, idleConnTimeout := 90 * time.Second dialTimeout := 30 * time.Second + var responseHeaderTimeout time.Duration if config.ForwardingTimeouts != nil { idleConnTimeout = time.Duration(config.ForwardingTimeouts.IdleConnTimeout) dialTimeout = time.Duration(config.ForwardingTimeouts.DialTimeout) + responseHeaderTimeout = time.Duration(config.ForwardingTimeouts.ResponseHeaderTimeout) } proxyDialer := newDialer(dialerConfig{ @@ -119,7 +116,7 @@ func (r *ProxyBuilder) getPool(cfgName string, config *dynamic.ServersTransport, ProxyURL: proxyURL, }, tlsConfig) - connPool := newConnPool(config.MaxIdleConnsPerHost, idleConnTimeout, func() (net.Conn, error) { + connPool := newConnPool(config.MaxIdleConnsPerHost, idleConnTimeout, responseHeaderTimeout, func() (net.Conn, error) { return proxyDialer.Dial("tcp", addrFromURL(targetURL)) }) diff --git a/pkg/proxy/fast/connpool.go b/pkg/proxy/fast/connpool.go index e0d2c4e7f..375fac2fb 100644 --- a/pkg/proxy/fast/connpool.go +++ b/pkg/proxy/fast/connpool.go @@ -1,42 +1,309 @@ package fast import ( + "bufio" + "errors" "fmt" + "io" "net" + "net/http" + "net/http/httputil" + "strings" + "sync" + "sync/atomic" "time" "github.com/rs/zerolog/log" + "github.com/valyala/fasthttp" ) +// rwWithUpgrade contains a ResponseWriter and an upgradeHandler, +// used to upgrade the connection (e.g. Websockets). +type rwWithUpgrade struct { + RW http.ResponseWriter + Upgrade upgradeHandler +} + // conn is an enriched net.Conn. type conn struct { net.Conn + RWCh chan rwWithUpgrade + ErrCh chan error + + br *bufio.Reader + idleAt time.Time // the last time it was marked as idle. idleTimeout time.Duration + + responseHeaderTimeout time.Duration + + expectedResponse atomic.Bool + broken atomic.Bool + upgraded atomic.Bool + + closeMu sync.Mutex + closed bool + closeErr error + + bufferPool *pool[[]byte] + limitedReaderPool *pool[*io.LimitedReader] } -func (c *conn) isExpired() bool { +// Read reads data from the connection. +// Overrides conn Read to use the buffered reader. +func (c *conn) Read(b []byte) (n int, err error) { + return c.br.Read(b) +} + +// Close closes the connection. +// Ensures that connection is closed only once, +// to avoid duplicate close error. +func (c *conn) Close() error { + c.closeMu.Lock() + defer c.closeMu.Unlock() + + if c.closed { + return c.closeErr + } + + c.closed = true + c.closeErr = c.Conn.Close() + + return c.closeErr +} + +// isStale returns whether the connection is in an invalid state (i.e. expired/broken). +func (c *conn) isStale() bool { expTime := c.idleAt.Add(c.idleTimeout) - return c.idleTimeout > 0 && time.Now().After(expTime) + return c.idleTimeout > 0 && time.Now().After(expTime) || c.broken.Load() +} + +// isUpgraded returns whether this connection has been upgraded (e.g. Websocket). +// An upgraded connection should not be reused and putted back in the connection pool. +func (c *conn) isUpgraded() bool { + return c.upgraded.Load() +} + +// readLoop handles the successive HTTP response read operations on the connection, +// and watches for unsolicited bytes or connection errors when idle. +func (c *conn) readLoop() { + defer c.Close() + + for { + _, err := c.br.Peek(1) + if err != nil { + select { + // An error occurred while a response was expected to be handled. + case <-c.RWCh: + c.ErrCh <- err + // An error occurred on an idle connection. + default: + c.broken.Store(true) + } + return + } + + // Unsolicited response received on an idle connection. + if !c.expectedResponse.Load() { + c.broken.Store(true) + return + } + + r := <-c.RWCh + if err = c.handleResponse(r); err != nil { + c.ErrCh <- err + return + } + + c.expectedResponse.Store(false) + c.ErrCh <- nil + } +} + +func (c *conn) handleResponse(r rwWithUpgrade) error { + res := fasthttp.AcquireResponse() + defer fasthttp.ReleaseResponse(res) + + res.Header.SetNoDefaultContentType(true) + + for { + var ( + timer *time.Timer + errTimeout atomic.Pointer[timeoutError] + ) + if c.responseHeaderTimeout > 0 { + timer = time.AfterFunc(c.responseHeaderTimeout, func() { + errTimeout.Store(&timeoutError{errors.New("timeout awaiting response headers")}) + c.Close() // This close call is needed to interrupt the read operation below when the timeout is over. + }) + } + + res.Header.SetNoDefaultContentType(true) + if err := res.Header.Read(c.br); err != nil { + if c.responseHeaderTimeout > 0 { + if errT := errTimeout.Load(); errT != nil { + return errT + } + } + return err + } + + if timer != nil { + timer.Stop() + } + + fixPragmaCacheControl(&res.Header) + + resCode := res.StatusCode() + is1xx := 100 <= resCode && resCode <= 199 + // treat 101 as a terminal status, see issue 26161 + is1xxNonTerminal := is1xx && resCode != http.StatusSwitchingProtocols + if is1xxNonTerminal { + removeConnectionHeaders(&res.Header) + h := r.RW.Header() + + for _, header := range hopHeaders { + res.Header.Del(header) + } + + res.Header.VisitAll(func(key, value []byte) { + r.RW.Header().Add(string(key), string(value)) + }) + + r.RW.WriteHeader(res.StatusCode()) + // Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses + for k := range h { + delete(h, k) + } + + res.Reset() + res.Header.Reset() + res.Header.SetNoDefaultContentType(true) + + continue + } + break + } + + announcedTrailers := res.Header.Peek("Trailer") + + // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) + if res.StatusCode() == http.StatusSwitchingProtocols { + r.Upgrade(r.RW, res, c) + c.upgraded.Store(true) // As the connection has been upgraded, it cannot be added back to the pool. + return nil + } + + removeConnectionHeaders(&res.Header) + + for _, header := range hopHeaders { + res.Header.Del(header) + } + + if len(announcedTrailers) > 0 { + res.Header.Add("Trailer", string(announcedTrailers)) + } + + res.Header.VisitAll(func(key, value []byte) { + r.RW.Header().Add(string(key), string(value)) + }) + + r.RW.WriteHeader(res.StatusCode()) + + if res.Header.ContentLength() == 0 { + return nil + } + + // When a body is not allowed for a given status code the body is ignored. + // The connection will be marked as broken by the next Peek in the readloop. + if !isBodyAllowedForStatus(res.StatusCode()) { + return nil + } + + // Chunked response, Content-Length is set to -1 by FastProxy when "Transfer-Encoding: chunked" header is received. + if res.Header.ContentLength() == -1 { + cbr := httputil.NewChunkedReader(c.br) + + b := c.bufferPool.Get() + if b == nil { + b = make([]byte, bufferSize) + } + defer c.bufferPool.Put(b) + + if _, err := io.CopyBuffer(&writeFlusher{r.RW}, cbr, b); err != nil { + return err + } + + res.Header.Reset() + res.Header.SetNoDefaultContentType(true) + if err := res.Header.ReadTrailer(c.br); err != nil { + return err + } + + if res.Header.Len() > 0 { + var announcedTrailersKey []string + if len(announcedTrailers) > 0 { + announcedTrailersKey = strings.Split(string(announcedTrailers), ",") + } + + res.Header.VisitAll(func(key, value []byte) { + for _, s := range announcedTrailersKey { + if strings.EqualFold(s, strings.TrimSpace(string(key))) { + r.RW.Header().Add(string(key), string(value)) + return + } + } + + r.RW.Header().Add(http.TrailerPrefix+string(key), string(value)) + }) + } + + return nil + } + + brl := c.limitedReaderPool.Get() + if brl == nil { + brl = &io.LimitedReader{} + } + defer c.limitedReaderPool.Put(brl) + + brl.R = c.br + brl.N = int64(res.Header.ContentLength()) + + b := c.bufferPool.Get() + if b == nil { + b = make([]byte, bufferSize) + } + defer c.bufferPool.Put(b) + + if _, err := io.CopyBuffer(r.RW, brl, b); err != nil { + return err + } + + return nil } // connPool is a net.Conn pool implementation using channels. type connPool struct { - dialer func() (net.Conn, error) - idleConns chan *conn - idleConnTimeout time.Duration - ticker *time.Ticker - doneCh chan struct{} + dialer func() (net.Conn, error) + idleConns chan *conn + idleConnTimeout time.Duration + responseHeaderTimeout time.Duration + ticker *time.Ticker + bufferPool pool[[]byte] + limitedReaderPool pool[*io.LimitedReader] + doneCh chan struct{} } // newConnPool creates a new connPool. -func newConnPool(maxIdleConn int, idleConnTimeout time.Duration, dialer func() (net.Conn, error)) *connPool { +func newConnPool(maxIdleConn int, idleConnTimeout, responseHeaderTimeout time.Duration, dialer func() (net.Conn, error)) *connPool { c := &connPool{ - dialer: dialer, - idleConns: make(chan *conn, maxIdleConn), - idleConnTimeout: idleConnTimeout, - doneCh: make(chan struct{}), + dialer: dialer, + idleConns: make(chan *conn, maxIdleConn), + idleConnTimeout: idleConnTimeout, + responseHeaderTimeout: responseHeaderTimeout, + doneCh: make(chan struct{}), } if idleConnTimeout > 0 { @@ -72,22 +339,28 @@ func (c *connPool) AcquireConn() (*conn, error) { return nil, err } - if !co.isExpired() { + if !co.isStale() { return co, nil } - // As the acquired conn is expired we can close it + // As the acquired conn is stale we can close it // without putting it again into the pool. if err := co.Close(); err != nil { log.Debug(). Err(err). - Msg("Unexpected error while releasing the connection") + Msg("Unexpected error while closing the connection") } } } // ReleaseConn releases the given net.Conn to the pool. func (c *connPool) ReleaseConn(co *conn) { + // An upgraded connection cannot be safely reused for another roundTrip, + // thus we are not putting it back to the pool. + if co.isUpgraded() { + return + } + co.idleAt = time.Now() c.releaseConn(co) } @@ -97,7 +370,7 @@ func (c *connPool) cleanIdleConns() { for { select { case co := <-c.idleConns: - if !co.isExpired() { + if !co.isStale() { c.releaseConn(co) return } @@ -105,7 +378,7 @@ func (c *connPool) cleanIdleConns() { if err := co.Close(); err != nil { log.Debug(). Err(err). - Msg("Unexpected error while releasing the connection") + Msg("Unexpected error while closing the connection") } default: @@ -155,9 +428,33 @@ func (c *connPool) askForNewConn(errCh chan<- error) { return } - c.releaseConn(&conn{ - Conn: co, - idleAt: time.Now(), - idleTimeout: c.idleConnTimeout, - }) + newConn := &conn{ + Conn: co, + br: bufio.NewReaderSize(co, bufioSize), + idleAt: time.Now(), + idleTimeout: c.idleConnTimeout, + responseHeaderTimeout: c.responseHeaderTimeout, + RWCh: make(chan rwWithUpgrade), + ErrCh: make(chan error), + bufferPool: &c.bufferPool, + limitedReaderPool: &c.limitedReaderPool, + } + go newConn.readLoop() + + c.releaseConn(newConn) +} + +// isBodyAllowedForStatus reports whether a given response status code +// permits a body. See RFC 7230, section 3.3. +// From https://github.com/golang/go/blame/master/src/net/http/transfer.go#L459 +func isBodyAllowedForStatus(status int) bool { + switch { + case status >= 100 && status <= 199: + return false + case status == 204: + return false + case status == 304: + return false + } + return true } diff --git a/pkg/proxy/fast/connpool_test.go b/pkg/proxy/fast/connpool_test.go index 696e62a36..7532a3e3b 100644 --- a/pkg/proxy/fast/connpool_test.go +++ b/pkg/proxy/fast/connpool_test.go @@ -58,7 +58,7 @@ func TestConnPool_ConnReuse(t *testing.T) { return &net.TCPConn{}, nil } - pool := newConnPool(2, 0, dialer) + pool := newConnPool(2, 0, 0, dialer) test.poolFn(pool) assert.Equal(t, test.expected, connAlloc) @@ -102,13 +102,16 @@ func TestConnPool_MaxIdleConn(t *testing.T) { var keepOpenedConn int dialer := func() (net.Conn, error) { keepOpenedConn++ - return &mockConn{closeFn: func() error { - keepOpenedConn-- - return nil - }}, nil + return &mockConn{ + doneCh: make(chan struct{}), + closeFn: func() error { + keepOpenedConn-- + return nil + }, + }, nil } - pool := newConnPool(test.maxIdleConn, 0, dialer) + pool := newConnPool(test.maxIdleConn, 0, 0, dialer) test.poolFn(pool) assert.Equal(t, test.expected, keepOpenedConn) @@ -129,7 +132,7 @@ func TestGC(t *testing.T) { return c, nil } - pools["test"] = newConnPool(10, 1*time.Second, dialer) + pools["test"] = newConnPool(10, 1*time.Second, 0, dialer) runtime.SetFinalizer(pools["test"], func(p *connPool) { isDestroyed = true }) @@ -149,10 +152,12 @@ func TestGC(t *testing.T) { type mockConn struct { closeFn func() error + doneCh chan struct{} // makes sure that the readLoop is blocking avoiding close. } func (m *mockConn) Read(_ []byte) (n int, err error) { - panic("implement me") + <-m.doneCh + return 0, nil } func (m *mockConn) Write(_ []byte) (n int, err error) { @@ -160,6 +165,7 @@ func (m *mockConn) Write(_ []byte) (n int, err error) { } func (m *mockConn) Close() error { + defer close(m.doneCh) if m.closeFn != nil { return m.closeFn() } diff --git a/pkg/proxy/fast/proxy.go b/pkg/proxy/fast/proxy.go index 388a4d73c..a400ce646 100644 --- a/pkg/proxy/fast/proxy.go +++ b/pkg/proxy/fast/proxy.go @@ -4,18 +4,14 @@ import ( "bufio" "bytes" "encoding/base64" - "errors" "fmt" "io" "net" "net/http" "net/http/httptrace" - "net/http/httputil" "net/url" "strings" "sync" - "sync/atomic" - "time" "github.com/rs/zerolog/log" proxyhttputil "github.com/traefik/traefik/v3/pkg/proxy/httputil" @@ -57,15 +53,6 @@ func (p *pool[T]) Put(x T) { p.pool.Put(x) } -type buffConn struct { - *bufio.Reader - net.Conn -} - -func (b buffConn) Read(p []byte) (int, error) { - return b.Reader.Read(p) -} - type writeDetector struct { net.Conn @@ -112,21 +99,17 @@ type ReverseProxy struct { connPool *connPool - bufferPool pool[[]byte] - readerPool pool[*bufio.Reader] - writerPool pool[*bufio.Writer] - limitReaderPool pool[*io.LimitedReader] + writerPool pool[*bufio.Writer] proxyAuth string - targetURL *url.URL - passHostHeader bool - preservePath bool - responseHeaderTimeout time.Duration + targetURL *url.URL + passHostHeader bool + preservePath bool } // NewReverseProxy creates a new ReverseProxy. -func NewReverseProxy(targetURL, proxyURL *url.URL, debug, passHostHeader, preservePath bool, responseHeaderTimeout time.Duration, connPool *connPool) (*ReverseProxy, error) { +func NewReverseProxy(targetURL, proxyURL *url.URL, debug, passHostHeader, preservePath bool, connPool *connPool) (*ReverseProxy, error) { var proxyAuth string if proxyURL != nil && proxyURL.User != nil && targetURL.Scheme == "http" { username := proxyURL.User.Username() @@ -135,13 +118,12 @@ func NewReverseProxy(targetURL, proxyURL *url.URL, debug, passHostHeader, preser } return &ReverseProxy{ - debug: debug, - passHostHeader: passHostHeader, - preservePath: preservePath, - targetURL: targetURL, - proxyAuth: proxyAuth, - connPool: connPool, - responseHeaderTimeout: responseHeaderTimeout, + debug: debug, + passHostHeader: passHostHeader, + preservePath: preservePath, + targetURL: targetURL, + proxyAuth: proxyAuth, + connPool: connPool, }, nil } @@ -273,8 +255,15 @@ func (p *ReverseProxy) roundTrip(rw http.ResponseWriter, req *http.Request, outR return fmt.Errorf("acquire connection: %w", err) } + // Before writing the request, + // we mark the conn as expecting to handle a response. + co.expectedResponse.Store(true) + wd := &writeDetector{Conn: co} + // TODO: do not wait to write the full request before reading the response (to handle "100 Continue"). + // TODO: this is currently impossible with fasthttp to write the request partially (headers only). + // Currently, writing the request fully is a mandatory step before handling the response. err = p.writeRequest(wd, outReq) if wd.written && trace != nil && trace.WroteRequest != nil { // WroteRequest hook is used by the tracing middleware to detect if the request has been written. @@ -293,169 +282,17 @@ func (p *ReverseProxy) roundTrip(rw http.ResponseWriter, req *http.Request, outR } } - br := p.readerPool.Get() - if br == nil { - br = bufio.NewReaderSize(co, bufioSize) - } - defer p.readerPool.Put(br) - - br.Reset(co) - - res := fasthttp.AcquireResponse() - defer fasthttp.ReleaseResponse(res) - - res.Header.SetNoDefaultContentType(true) - - for { - var timer *time.Timer - errTimeout := atomic.Pointer[timeoutError]{} - if p.responseHeaderTimeout > 0 { - timer = time.AfterFunc(p.responseHeaderTimeout, func() { - errTimeout.Store(&timeoutError{errors.New("timeout awaiting response headers")}) - co.Close() - }) - } - - res.Header.SetNoDefaultContentType(true) - if err := res.Header.Read(br); err != nil { - if p.responseHeaderTimeout > 0 { - if errT := errTimeout.Load(); errT != nil { - return errT - } - } - co.Close() - return err - } - - if timer != nil { - timer.Stop() - } - - fixPragmaCacheControl(&res.Header) - - resCode := res.StatusCode() - is1xx := 100 <= resCode && resCode <= 199 - // treat 101 as a terminal status, see issue 26161 - is1xxNonTerminal := is1xx && resCode != http.StatusSwitchingProtocols - if is1xxNonTerminal { - removeConnectionHeaders(&res.Header) - h := rw.Header() - - for _, header := range hopHeaders { - res.Header.Del(header) - } - - res.Header.VisitAll(func(key, value []byte) { - rw.Header().Add(string(key), string(value)) - }) - - rw.WriteHeader(res.StatusCode()) - // Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses - for k := range h { - delete(h, k) - } - - res.Reset() - res.Header.Reset() - res.Header.SetNoDefaultContentType(true) - - continue - } - break + // Sending the responseWriter unlocks the connection readLoop, to handle the response. + co.RWCh <- rwWithUpgrade{ + RW: rw, + Upgrade: upgradeResponseHandler(req.Context(), reqUpType), } - announcedTrailers := res.Header.Peek("Trailer") - - // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) - if res.StatusCode() == http.StatusSwitchingProtocols { - // As the connection has been hijacked, it cannot be added back to the pool. - handleUpgradeResponse(rw, req, reqUpType, res, buffConn{Conn: co, Reader: br}) - return nil - } - - removeConnectionHeaders(&res.Header) - - for _, header := range hopHeaders { - res.Header.Del(header) - } - - if len(announcedTrailers) > 0 { - res.Header.Add("Trailer", string(announcedTrailers)) - } - - res.Header.VisitAll(func(key, value []byte) { - rw.Header().Add(string(key), string(value)) - }) - - rw.WriteHeader(res.StatusCode()) - - // Chunked response, Content-Length is set to -1 by FastProxy when "Transfer-Encoding: chunked" header is received. - if res.Header.ContentLength() == -1 { - cbr := httputil.NewChunkedReader(br) - - b := p.bufferPool.Get() - if b == nil { - b = make([]byte, bufferSize) - } - defer p.bufferPool.Put(b) - - if _, err := io.CopyBuffer(&writeFlusher{rw}, cbr, b); err != nil { - co.Close() - return err - } - - res.Header.Reset() - res.Header.SetNoDefaultContentType(true) - if err := res.Header.ReadTrailer(br); err != nil { - co.Close() - return err - } - - if res.Header.Len() > 0 { - var announcedTrailersKey []string - if len(announcedTrailers) > 0 { - announcedTrailersKey = strings.Split(string(announcedTrailers), ",") - } - - res.Header.VisitAll(func(key, value []byte) { - for _, s := range announcedTrailersKey { - if strings.EqualFold(s, strings.TrimSpace(string(key))) { - rw.Header().Add(string(key), string(value)) - return - } - } - - rw.Header().Add(http.TrailerPrefix+string(key), string(value)) - }) - } - - p.connPool.ReleaseConn(co) - - return nil - } - - brl := p.limitReaderPool.Get() - if brl == nil { - brl = &io.LimitedReader{} - } - defer p.limitReaderPool.Put(brl) - - brl.R = br - brl.N = int64(res.Header.ContentLength()) - - b := p.bufferPool.Get() - if b == nil { - b = make([]byte, bufferSize) - } - defer p.bufferPool.Put(b) - - if _, err := io.CopyBuffer(rw, brl, b); err != nil { - co.Close() + if err := <-co.ErrCh; err != nil { return err } p.connPool.ReleaseConn(co) - return nil } diff --git a/pkg/proxy/fast/proxy_websocket_test.go b/pkg/proxy/fast/proxy_websocket_test.go index 8297ac485..b057f8b58 100644 --- a/pkg/proxy/fast/proxy_websocket_test.go +++ b/pkg/proxy/fast/proxy_websocket_test.go @@ -362,7 +362,7 @@ func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) { u := parseURI(t, srv.URL) - f, err := NewReverseProxy(u, nil, true, false, false, 0, newConnPool(1, 0, func() (net.Conn, error) { + f, err := NewReverseProxy(u, nil, true, false, false, newConnPool(1, 0, 0, func() (net.Conn, error) { return net.Dial("tcp", u.Host) })) require.NoError(t, err) @@ -434,7 +434,7 @@ func TestWebSocketUpgradeFailed(t *testing.T) { defer srv.Close() u := parseURI(t, srv.URL) - f, err := NewReverseProxy(u, nil, true, false, false, 0, newConnPool(1, 0, func() (net.Conn, error) { + f, err := NewReverseProxy(u, nil, true, false, false, newConnPool(1, 0, 0, func() (net.Conn, error) { return net.Dial("tcp", u.Host) })) require.NoError(t, err) @@ -663,7 +663,7 @@ func parseURI(t *testing.T, uri string) *url.URL { func createConnectionPool(target string, tlsConfig *tls.Config) *connPool { u := testhelpers.MustParseURL(target) - return newConnPool(200, 0, func() (net.Conn, error) { + return newConnPool(200, 0, 0, func() (net.Conn, error) { if tlsConfig != nil { return tls.Dial("tcp", u.Host, tlsConfig) } @@ -676,7 +676,7 @@ func createProxyWithForwarder(t *testing.T, uri string, pool *connPool) *httptes t.Helper() u := parseURI(t, uri) - proxy, err := NewReverseProxy(u, nil, false, true, false, 0, pool) + proxy, err := NewReverseProxy(u, nil, false, true, false, pool) require.NoError(t, err) srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { diff --git a/pkg/proxy/fast/upgrade.go b/pkg/proxy/fast/upgrade.go index a42fc97d8..7bec09e49 100644 --- a/pkg/proxy/fast/upgrade.go +++ b/pkg/proxy/fast/upgrade.go @@ -2,6 +2,7 @@ package fast import ( "bytes" + "context" "fmt" "io" "net" @@ -19,72 +20,75 @@ type switchProtocolCopier struct { user, backend io.ReadWriter } -func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { +func (c switchProtocolCopier) copyFromBackend(errCh chan<- error) { _, err := io.Copy(c.user, c.backend) - errc <- err + errCh <- err } -func (c switchProtocolCopier) copyToBackend(errc chan<- error) { +func (c switchProtocolCopier) copyToBackend(errCh chan<- error) { _, err := io.Copy(c.backend, c.user) - errc <- err + errCh <- err } -func handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, reqUpType string, res *fasthttp.Response, backConn net.Conn) { - defer backConn.Close() +type upgradeHandler func(rw http.ResponseWriter, res *fasthttp.Response, backConn net.Conn) - resUpType := upgradeTypeFastHTTP(&res.Header) +func upgradeResponseHandler(ctx context.Context, reqUpType string) upgradeHandler { + return func(rw http.ResponseWriter, res *fasthttp.Response, backConn net.Conn) { + resUpType := upgradeTypeFastHTTP(&res.Header) - if !strings.EqualFold(reqUpType, resUpType) { - httputil.ErrorHandler(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType)) - return - } - - hj, ok := rw.(http.Hijacker) - if !ok { - httputil.ErrorHandler(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) - return - } - backConnCloseCh := make(chan bool) - go func() { - // Ensure that the cancellation of a request closes the backend. - // See issue https://golang.org/issue/35559. - select { - case <-req.Context().Done(): - case <-backConnCloseCh: + if !strings.EqualFold(reqUpType, resUpType) { + httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType)) + backConn.Close() + return } - _ = backConn.Close() - }() - defer close(backConnCloseCh) - - conn, brw, err := hj.Hijack() - if err != nil { - httputil.ErrorHandler(rw, req, fmt.Errorf("hijack failed on protocol switch: %w", err)) - return - } - defer conn.Close() - - for k, values := range rw.Header() { - for _, v := range values { - res.Header.Add(k, v) + hj, ok := rw.(http.Hijacker) + if !ok { + httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) + backConn.Close() + return } - } + backConnCloseCh := make(chan bool) + go func() { + // Ensure that the cancellation of a request closes the backend. + // See issue https://golang.org/issue/35559. + select { + case <-ctx.Done(): + case <-backConnCloseCh: + } + _ = backConn.Close() + }() + defer close(backConnCloseCh) - if err := res.Header.Write(brw.Writer); err != nil { - httputil.ErrorHandler(rw, req, fmt.Errorf("response write: %w", err)) - return - } + conn, brw, err := hj.Hijack() + if err != nil { + httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("hijack failed on protocol switch: %w", err)) + return + } + defer conn.Close() - if err := brw.Flush(); err != nil { - httputil.ErrorHandler(rw, req, fmt.Errorf("response flush: %w", err)) - return - } + for k, values := range rw.Header() { + for _, v := range values { + res.Header.Add(k, v) + } + } - errc := make(chan error, 1) - spc := switchProtocolCopier{user: conn, backend: backConn} - go spc.copyToBackend(errc) - go spc.copyFromBackend(errc) - <-errc + if err := res.Header.Write(brw.Writer); err != nil { + httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("response write: %w", err)) + return + } + + if err := brw.Flush(); err != nil { + httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("response flush: %w", err)) + return + } + + errCh := make(chan error, 1) + spc := switchProtocolCopier{user: conn, backend: backConn} + go spc.copyToBackend(errCh) + go spc.copyFromBackend(errCh) + <-errCh + } } func upgradeType(h http.Header) string { diff --git a/pkg/proxy/httputil/proxy.go b/pkg/proxy/httputil/proxy.go index 8ed35c23a..78c09f9a4 100644 --- a/pkg/proxy/httputil/proxy.go +++ b/pkg/proxy/httputil/proxy.go @@ -102,9 +102,14 @@ func isWebSocketUpgrade(req *http.Request) bool { // ErrorHandler is the http.Handler called when something goes wrong when forwarding the request. func ErrorHandler(w http.ResponseWriter, req *http.Request, err error) { + ErrorHandlerWithContext(req.Context(), w, err) +} + +// ErrorHandlerWithContext is the http.Handler called when something goes wrong when forwarding the request. +func ErrorHandlerWithContext(ctx context.Context, w http.ResponseWriter, err error) { statusCode := ComputeStatusCode(err) - logger := log.Ctx(req.Context()) + logger := log.Ctx(ctx) logger.Debug().Err(err).Msgf("%d %s", statusCode, statusText(statusCode)) w.WriteHeader(statusCode)