traefik/pkg/server/server_entrypoint_tcp.go
2024-11-18 09:56:04 +01:00

715 lines
20 KiB
Go

package server
import (
"context"
"errors"
"expvar"
"fmt"
stdlog "log"
"net"
"net/http"
"net/url"
"os"
"strings"
"sync"
"syscall"
"time"
"github.com/containous/alice"
"github.com/pires/go-proxyproto"
"github.com/sirupsen/logrus"
"github.com/traefik/traefik/v2/pkg/config/static"
"github.com/traefik/traefik/v2/pkg/ip"
"github.com/traefik/traefik/v2/pkg/log"
"github.com/traefik/traefik/v2/pkg/middlewares"
"github.com/traefik/traefik/v2/pkg/middlewares/forwardedheaders"
"github.com/traefik/traefik/v2/pkg/middlewares/requestdecorator"
"github.com/traefik/traefik/v2/pkg/safe"
"github.com/traefik/traefik/v2/pkg/server/router"
tcprouter "github.com/traefik/traefik/v2/pkg/server/router/tcp"
"github.com/traefik/traefik/v2/pkg/server/service"
"github.com/traefik/traefik/v2/pkg/tcp"
"github.com/traefik/traefik/v2/pkg/types"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)
var httpServerLogger = stdlog.New(log.WithoutContext().WriterLevel(logrus.DebugLevel), "", 0)
type key string
const (
connStateKey key = "connState"
debugConnectionEnv string = "DEBUG_CONNECTION"
)
var (
clientConnectionStates = map[string]*connState{}
clientConnectionStatesMu = sync.RWMutex{}
)
type connState struct {
State string
KeepAliveState string
Start time.Time
HTTPRequestCount int
}
type httpForwarder struct {
net.Listener
connChan chan net.Conn
errChan chan error
}
func newHTTPForwarder(ln net.Listener) *httpForwarder {
return &httpForwarder{
Listener: ln,
connChan: make(chan net.Conn),
errChan: make(chan error),
}
}
// ServeTCP uses the connection to serve it later in "Accept".
func (h *httpForwarder) ServeTCP(conn tcp.WriteCloser) {
h.connChan <- conn
}
// Accept retrieves a served connection in ServeTCP.
func (h *httpForwarder) Accept() (net.Conn, error) {
select {
case conn := <-h.connChan:
return conn, nil
case err := <-h.errChan:
return nil, err
}
}
// TCPEntryPoints holds a map of TCPEntryPoint (the entrypoint names being the keys).
type TCPEntryPoints map[string]*TCPEntryPoint
// NewTCPEntryPoints creates a new TCPEntryPoints.
func NewTCPEntryPoints(entryPointsConfig static.EntryPoints, hostResolverConfig *types.HostResolverConfig) (TCPEntryPoints, error) {
if os.Getenv(debugConnectionEnv) != "" {
expvar.Publish("clientConnectionStates", expvar.Func(func() any {
return clientConnectionStates
}))
}
serverEntryPointsTCP := make(TCPEntryPoints)
for entryPointName, config := range entryPointsConfig {
protocol, err := config.GetProtocol()
if err != nil {
return nil, fmt.Errorf("error while building entryPoint %s: %w", entryPointName, err)
}
if protocol != "tcp" {
continue
}
ctx := log.With(context.Background(), log.Str(log.EntryPointName, entryPointName))
serverEntryPointsTCP[entryPointName], err = NewTCPEntryPoint(ctx, config, hostResolverConfig)
if err != nil {
return nil, fmt.Errorf("error while building entryPoint %s: %w", entryPointName, err)
}
}
return serverEntryPointsTCP, nil
}
// Start the server entry points.
func (eps TCPEntryPoints) Start() {
for entryPointName, serverEntryPoint := range eps {
ctx := log.With(context.Background(), log.Str(log.EntryPointName, entryPointName))
go serverEntryPoint.Start(ctx)
}
}
// Stop the server entry points.
func (eps TCPEntryPoints) Stop() {
var wg sync.WaitGroup
for epn, ep := range eps {
wg.Add(1)
go func(entryPointName string, entryPoint *TCPEntryPoint) {
defer wg.Done()
ctx := log.With(context.Background(), log.Str(log.EntryPointName, entryPointName))
entryPoint.Shutdown(ctx)
log.FromContext(ctx).Debugf("Entry point %s closed", entryPointName)
}(epn, ep)
}
wg.Wait()
}
// Switch the TCP routers.
func (eps TCPEntryPoints) Switch(routersTCP map[string]*tcprouter.Router) {
for entryPointName, rt := range routersTCP {
eps[entryPointName].SwitchRouter(rt)
}
}
// TCPEntryPoint is the TCP server.
type TCPEntryPoint struct {
listener net.Listener
switcher *tcp.HandlerSwitcher
transportConfiguration *static.EntryPointsTransport
tracker *connectionTracker
httpServer *httpServer
httpsServer *httpServer
http3Server *http3server
}
// NewTCPEntryPoint creates a new TCPEntryPoint.
func NewTCPEntryPoint(ctx context.Context, configuration *static.EntryPoint, hostResolverConfig *types.HostResolverConfig) (*TCPEntryPoint, error) {
tracker := newConnectionTracker()
listener, err := buildListener(ctx, configuration)
if err != nil {
return nil, fmt.Errorf("error preparing server: %w", err)
}
rt, err := tcprouter.NewRouter()
if err != nil {
return nil, fmt.Errorf("error preparing tcp router: %w", err)
}
reqDecorator := requestdecorator.New(hostResolverConfig)
httpServer, err := createHTTPServer(ctx, listener, configuration, true, reqDecorator)
if err != nil {
return nil, fmt.Errorf("error preparing http server: %w", err)
}
rt.SetHTTPForwarder(httpServer.Forwarder)
httpsServer, err := createHTTPServer(ctx, listener, configuration, false, reqDecorator)
if err != nil {
return nil, fmt.Errorf("error preparing https server: %w", err)
}
h3Server, err := newHTTP3Server(ctx, configuration, httpsServer)
if err != nil {
return nil, fmt.Errorf("error preparing http3 server: %w", err)
}
rt.SetHTTPSForwarder(httpsServer.Forwarder)
tcpSwitcher := &tcp.HandlerSwitcher{}
tcpSwitcher.Switch(rt)
return &TCPEntryPoint{
listener: listener,
switcher: tcpSwitcher,
transportConfiguration: configuration.Transport,
tracker: tracker,
httpServer: httpServer,
httpsServer: httpsServer,
http3Server: h3Server,
}, nil
}
// Start starts the TCP server.
func (e *TCPEntryPoint) Start(ctx context.Context) {
logger := log.FromContext(ctx)
logger.Debug("Starting TCP Server")
if e.http3Server != nil {
go func() { _ = e.http3Server.Start() }()
}
for {
conn, err := e.listener.Accept()
if err != nil {
logger.Error(err)
var opErr *net.OpError
if errors.As(err, &opErr) && opErr.Temporary() {
continue
}
var urlErr *url.Error
if errors.As(err, &urlErr) && urlErr.Temporary() {
continue
}
e.httpServer.Forwarder.errChan <- err
e.httpsServer.Forwarder.errChan <- err
return
}
writeCloser, err := writeCloser(conn)
if err != nil {
panic(err)
}
safe.Go(func() {
// Enforce read/write deadlines at the connection level,
// because when we're peeking the first byte to determine whether we are doing TLS,
// the deadlines at the server level are not taken into account.
if e.transportConfiguration.RespondingTimeouts.ReadTimeout > 0 {
err := writeCloser.SetReadDeadline(time.Now().Add(time.Duration(e.transportConfiguration.RespondingTimeouts.ReadTimeout)))
if err != nil {
logger.Errorf("Error while setting read deadline: %v", err)
}
}
if e.transportConfiguration.RespondingTimeouts.WriteTimeout > 0 {
err = writeCloser.SetWriteDeadline(time.Now().Add(time.Duration(e.transportConfiguration.RespondingTimeouts.WriteTimeout)))
if err != nil {
logger.Errorf("Error while setting write deadline: %v", err)
}
}
e.switcher.ServeTCP(newTrackedConnection(writeCloser, e.tracker))
})
}
}
// Shutdown stops the TCP connections.
func (e *TCPEntryPoint) Shutdown(ctx context.Context) {
logger := log.FromContext(ctx)
reqAcceptGraceTimeOut := time.Duration(e.transportConfiguration.LifeCycle.RequestAcceptGraceTimeout)
if reqAcceptGraceTimeOut > 0 {
logger.Infof("Waiting %s for incoming requests to cease", reqAcceptGraceTimeOut)
time.Sleep(reqAcceptGraceTimeOut)
}
graceTimeOut := time.Duration(e.transportConfiguration.LifeCycle.GraceTimeOut)
ctx, cancel := context.WithTimeout(ctx, graceTimeOut)
logger.Debugf("Waiting %s seconds before killing connections.", graceTimeOut)
var wg sync.WaitGroup
shutdownServer := func(server stoppable) {
defer wg.Done()
err := server.Shutdown(ctx)
if err == nil {
return
}
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
logger.Debugf("Server failed to shutdown within deadline because: %s", err)
if err = server.Close(); err != nil {
logger.Error(err)
}
return
}
logger.Error(err)
// We expect Close to fail again because Shutdown most likely failed when trying to close a listener.
// We still call it however, to make sure that all connections get closed as well.
server.Close()
}
if e.httpServer.Server != nil {
wg.Add(1)
go shutdownServer(e.httpServer.Server)
}
if e.httpsServer.Server != nil {
wg.Add(1)
go shutdownServer(e.httpsServer.Server)
if e.http3Server != nil {
wg.Add(1)
go shutdownServer(e.http3Server)
}
}
if e.tracker != nil {
wg.Add(1)
go func() {
defer wg.Done()
err := e.tracker.Shutdown(ctx)
if err == nil {
return
}
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
logger.Debugf("Server failed to shutdown before deadline because: %s", err)
}
e.tracker.Close()
}()
}
wg.Wait()
cancel()
}
// SwitchRouter switches the TCP router handler.
func (e *TCPEntryPoint) SwitchRouter(rt *tcprouter.Router) {
rt.SetHTTPForwarder(e.httpServer.Forwarder)
httpHandler := rt.GetHTTPHandler()
if httpHandler == nil {
httpHandler = router.BuildDefaultHTTPRouter()
}
e.httpServer.Switcher.UpdateHandler(httpHandler)
rt.SetHTTPSForwarder(e.httpsServer.Forwarder)
httpsHandler := rt.GetHTTPSHandler()
if httpsHandler == nil {
httpsHandler = router.BuildDefaultHTTPRouter()
}
e.httpsServer.Switcher.UpdateHandler(httpsHandler)
e.switcher.Switch(rt)
if e.http3Server != nil {
e.http3Server.Switch(rt)
}
}
// writeCloserWrapper wraps together a connection, and the concrete underlying
// connection type that was found to satisfy WriteCloser.
type writeCloserWrapper struct {
net.Conn
writeCloser tcp.WriteCloser
}
func (c *writeCloserWrapper) CloseWrite() error {
return c.writeCloser.CloseWrite()
}
// writeCloser returns the given connection, augmented with the WriteCloser
// implementation, if any was found within the underlying conn.
func writeCloser(conn net.Conn) (tcp.WriteCloser, error) {
switch typedConn := conn.(type) {
case *proxyproto.Conn:
underlying, ok := typedConn.TCPConn()
if !ok {
return nil, errors.New("underlying connection is not a tcp connection")
}
return &writeCloserWrapper{writeCloser: underlying, Conn: typedConn}, nil
case *net.TCPConn:
return typedConn, nil
default:
return nil, fmt.Errorf("unknown connection type %T", typedConn)
}
}
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections.
type tcpKeepAliveListener struct {
*net.TCPListener
}
func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
tc, err := ln.AcceptTCP()
if err != nil {
return nil, err
}
if err := tc.SetKeepAlive(true); err != nil {
return nil, err
}
if err := tc.SetKeepAlivePeriod(3 * time.Minute); err != nil {
// Some systems, such as OpenBSD, have no user-settable per-socket TCP
// keepalive options.
if !errors.Is(err, syscall.ENOPROTOOPT) {
return nil, err
}
}
return tc, nil
}
func buildProxyProtocolListener(ctx context.Context, entryPoint *static.EntryPoint, listener net.Listener) (net.Listener, error) {
timeout := entryPoint.Transport.RespondingTimeouts.ReadTimeout
// proxyproto use 200ms if ReadHeaderTimeout is set to 0 and not no timeout
if timeout == 0 {
timeout = -1
}
proxyListener := &proxyproto.Listener{Listener: listener, ReadHeaderTimeout: time.Duration(timeout)}
if entryPoint.ProxyProtocol.Insecure {
log.FromContext(ctx).Infof("Enabling ProxyProtocol without trusted IPs: Insecure")
return proxyListener, nil
}
checker, err := ip.NewChecker(entryPoint.ProxyProtocol.TrustedIPs)
if err != nil {
return nil, err
}
proxyListener.Policy = func(upstream net.Addr) (proxyproto.Policy, error) {
ipAddr, ok := upstream.(*net.TCPAddr)
if !ok {
return proxyproto.REJECT, fmt.Errorf("type error %v", upstream)
}
if !checker.ContainsIP(ipAddr.IP) {
log.FromContext(ctx).Debugf("IP %s is not in trusted IPs list, ignoring ProxyProtocol Headers and bypass connection", ipAddr.IP)
return proxyproto.IGNORE, nil
}
return proxyproto.USE, nil
}
log.FromContext(ctx).Infof("Enabling ProxyProtocol for trusted IPs %v", entryPoint.ProxyProtocol.TrustedIPs)
return proxyListener, nil
}
func buildListener(ctx context.Context, entryPoint *static.EntryPoint) (net.Listener, error) {
listener, err := net.Listen("tcp", entryPoint.GetAddress())
if err != nil {
return nil, fmt.Errorf("error opening listener: %w", err)
}
listener = tcpKeepAliveListener{listener.(*net.TCPListener)}
if entryPoint.ProxyProtocol != nil {
listener, err = buildProxyProtocolListener(ctx, entryPoint, listener)
if err != nil {
return nil, fmt.Errorf("error creating proxy protocol listener: %w", err)
}
}
return listener, nil
}
func newConnectionTracker() *connectionTracker {
return &connectionTracker{
conns: make(map[net.Conn]struct{}),
}
}
type connectionTracker struct {
conns map[net.Conn]struct{}
lock sync.RWMutex
}
// AddConnection add a connection in the tracked connections list.
func (c *connectionTracker) AddConnection(conn net.Conn) {
c.lock.Lock()
defer c.lock.Unlock()
c.conns[conn] = struct{}{}
}
// RemoveConnection remove a connection from the tracked connections list.
func (c *connectionTracker) RemoveConnection(conn net.Conn) {
c.lock.Lock()
defer c.lock.Unlock()
delete(c.conns, conn)
}
func (c *connectionTracker) isEmpty() bool {
c.lock.RLock()
defer c.lock.RUnlock()
return len(c.conns) == 0
}
// Shutdown wait for the connection closing.
func (c *connectionTracker) Shutdown(ctx context.Context) error {
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
for {
if c.isEmpty() {
return nil
}
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
}
}
}
// Close close all the connections in the tracked connections list.
func (c *connectionTracker) Close() {
c.lock.Lock()
defer c.lock.Unlock()
for conn := range c.conns {
if err := conn.Close(); err != nil {
log.WithoutContext().Errorf("Error while closing connection: %v", err)
}
delete(c.conns, conn)
}
}
type stoppable interface {
Shutdown(ctx context.Context) error
Close() error
}
type stoppableServer interface {
stoppable
Serve(listener net.Listener) error
}
type httpServer struct {
Server stoppableServer
Forwarder *httpForwarder
Switcher *middlewares.HTTPHandlerSwitcher
}
func createHTTPServer(ctx context.Context, ln net.Listener, configuration *static.EntryPoint, withH2c bool, reqDecorator *requestdecorator.RequestDecorator) (*httpServer, error) {
if configuration.HTTP2.MaxConcurrentStreams < 0 {
return nil, errors.New("max concurrent streams value must be greater than or equal to zero")
}
httpSwitcher := middlewares.NewHandlerSwitcher(router.BuildDefaultHTTPRouter())
next, err := alice.New(requestdecorator.WrapHandler(reqDecorator)).Then(httpSwitcher)
if err != nil {
return nil, err
}
var handler http.Handler
handler, err = forwardedheaders.NewXForwarded(
configuration.ForwardedHeaders.Insecure,
configuration.ForwardedHeaders.TrustedIPs,
configuration.ForwardedHeaders.Connection,
next)
if err != nil {
return nil, err
}
handler = denyFragment(handler)
if configuration.HTTP.EncodeQuerySemicolons {
handler = encodeQuerySemicolons(handler)
} else {
handler = http.AllowQuerySemicolons(handler)
}
debugConnection := os.Getenv(debugConnectionEnv) != ""
if debugConnection || (configuration.Transport != nil && (configuration.Transport.KeepAliveMaxTime > 0 || configuration.Transport.KeepAliveMaxRequests > 0)) {
handler = newKeepAliveMiddleware(handler, configuration.Transport.KeepAliveMaxRequests, configuration.Transport.KeepAliveMaxTime)
}
if withH2c {
handler = h2c.NewHandler(handler, &http2.Server{
MaxConcurrentStreams: uint32(configuration.HTTP2.MaxConcurrentStreams),
})
}
serverHTTP := &http.Server{
Handler: handler,
ErrorLog: httpServerLogger,
ReadTimeout: time.Duration(configuration.Transport.RespondingTimeouts.ReadTimeout),
WriteTimeout: time.Duration(configuration.Transport.RespondingTimeouts.WriteTimeout),
IdleTimeout: time.Duration(configuration.Transport.RespondingTimeouts.IdleTimeout),
}
if debugConnection || (configuration.Transport != nil && (configuration.Transport.KeepAliveMaxTime > 0 || configuration.Transport.KeepAliveMaxRequests > 0)) {
serverHTTP.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
cState := &connState{Start: time.Now()}
if debugConnection {
clientConnectionStatesMu.Lock()
clientConnectionStates[getConnKey(c)] = cState
clientConnectionStatesMu.Unlock()
}
return context.WithValue(ctx, connStateKey, cState)
}
if debugConnection {
serverHTTP.ConnState = func(c net.Conn, state http.ConnState) {
clientConnectionStatesMu.Lock()
if clientConnectionStates[getConnKey(c)] != nil {
clientConnectionStates[getConnKey(c)].State = state.String()
}
clientConnectionStatesMu.Unlock()
}
}
}
prevConnContext := serverHTTP.ConnContext
serverHTTP.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
// This adds an empty struct in order to store a RoundTripper in the ConnContext in case of Kerberos or NTLM.
ctx = service.AddTransportOnContext(ctx)
if prevConnContext != nil {
return prevConnContext(ctx, c)
}
return ctx
}
// ConfigureServer configures HTTP/2 with the MaxConcurrentStreams option for the given server.
// Also keeping behavior the same as
// https://cs.opensource.google/go/go/+/refs/tags/go1.17.7:src/net/http/server.go;l=3262
if !strings.Contains(os.Getenv("GODEBUG"), "http2server=0") {
err = http2.ConfigureServer(serverHTTP, &http2.Server{
MaxConcurrentStreams: uint32(configuration.HTTP2.MaxConcurrentStreams),
NewWriteScheduler: func() http2.WriteScheduler { return http2.NewPriorityWriteScheduler(nil) },
})
if err != nil {
return nil, fmt.Errorf("configure HTTP/2 server: %w", err)
}
}
listener := newHTTPForwarder(ln)
go func() {
err := serverHTTP.Serve(listener)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
log.FromContext(ctx).Errorf("Error while starting server: %v", err)
}
}()
return &httpServer{
Server: serverHTTP,
Forwarder: listener,
Switcher: httpSwitcher,
}, nil
}
func getConnKey(conn net.Conn) string {
return fmt.Sprintf("%s => %s", conn.RemoteAddr(), conn.LocalAddr())
}
func newTrackedConnection(conn tcp.WriteCloser, tracker *connectionTracker) *trackedConnection {
tracker.AddConnection(conn)
return &trackedConnection{
WriteCloser: conn,
tracker: tracker,
}
}
type trackedConnection struct {
tracker *connectionTracker
tcp.WriteCloser
}
func (t *trackedConnection) Close() error {
t.tracker.RemoveConnection(t.WriteCloser)
return t.WriteCloser.Close()
}
// This function is inspired by http.AllowQuerySemicolons.
func encodeQuerySemicolons(h http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if strings.Contains(req.URL.RawQuery, ";") {
r2 := new(http.Request)
*r2 = *req
r2.URL = new(url.URL)
*r2.URL = *req.URL
r2.URL.RawQuery = strings.ReplaceAll(req.URL.RawQuery, ";", "%3B")
// Because the reverse proxy director is building query params from requestURI it needs to be updated as well.
r2.RequestURI = r2.URL.RequestURI()
h.ServeHTTP(rw, r2)
} else {
h.ServeHTTP(rw, req)
}
})
}
// When go receives an HTTP request, it assumes the absence of fragment URL.
// However, it is still possible to send a fragment in the request.
// In this case, Traefik will encode the '#' character, altering the request's intended meaning.
// To avoid this behavior, the following function rejects requests that include a fragment in the URL.
func denyFragment(h http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if strings.Contains(req.URL.RawPath, "#") {
log.WithoutContext().Debugf("Rejecting request because it contains a fragment in the URL path: %s", req.URL.RawPath)
rw.WriteHeader(http.StatusBadRequest)
return
}
h.ServeHTTP(rw, req)
})
}