diff --git a/docs/content/reference/dynamic-configuration/docker-labels.yml b/docs/content/reference/dynamic-configuration/docker-labels.yml index 757e525cf..722a4fdf4 100644 --- a/docs/content/reference/dynamic-configuration/docker-labels.yml +++ b/docs/content/reference/dynamic-configuration/docker-labels.yml @@ -184,4 +184,6 @@ - "traefik.tcp.routers.tcprouter1.tls.options=foobar" - "traefik.tcp.routers.tcprouter1.tls.passthrough=true" - "traefik.tcp.services.tcpservice0.loadbalancer.server.port=foobar" +- "traefik.tcp.services.tcpservice0.loadbalancer.terminationdelay=100" - "traefik.tcp.services.tcpservice1.loadbalancer.server.port=foobar" +- "traefik.tcp.services.tcpservice1.loadbalancer.terminationdelay=100" diff --git a/docs/content/reference/dynamic-configuration/file.toml b/docs/content/reference/dynamic-configuration/file.toml index 88fc62144..8027a082b 100644 --- a/docs/content/reference/dynamic-configuration/file.toml +++ b/docs/content/reference/dynamic-configuration/file.toml @@ -286,14 +286,17 @@ [tcp.services] [tcp.services.TCPService0] [tcp.services.TCPService0.loadBalancer] + terminationDelay = 100 [[tcp.services.TCPService0.loadBalancer.servers]] address = "foobar" [[tcp.services.TCPService0.loadBalancer.servers]] address = "foobar" + [tcp.services.TCPService1] [tcp.services.TCPService1.loadBalancer] + terminationDelay = 100 [[tcp.services.TCPService1.loadBalancer.servers]] address = "foobar" diff --git a/docs/content/reference/dynamic-configuration/file.yaml b/docs/content/reference/dynamic-configuration/file.yaml index 46a54373f..3c2eb98a4 100644 --- a/docs/content/reference/dynamic-configuration/file.yaml +++ b/docs/content/reference/dynamic-configuration/file.yaml @@ -324,11 +324,13 @@ tcp: services: TCPService0: loadBalancer: + terminationDelay: 100 servers: - address: foobar - address: foobar TCPService1: loadBalancer: + terminationDelay: 100 servers: - address: foobar - address: foobar diff --git a/docs/content/reference/dynamic-configuration/marathon-labels.json b/docs/content/reference/dynamic-configuration/marathon-labels.json index 5b1360841..695e8c6bd 100644 --- a/docs/content/reference/dynamic-configuration/marathon-labels.json +++ b/docs/content/reference/dynamic-configuration/marathon-labels.json @@ -184,4 +184,6 @@ "traefik.tcp.routers.tcprouter1.tls.options": "foobar", "traefik.tcp.routers.tcprouter1.tls.passthrough": "true", "traefik.tcp.services.tcpservice0.loadbalancer.server.port": "foobar", +"traefik.tcp.services.tcpservice0.loadbalancer.terminationDelay": "100", "traefik.tcp.services.tcpservice1.loadbalancer.server.port": "foobar" +"traefik.tcp.services.tcpservice1.loadbalancer.terminationDelay": "100", diff --git a/docs/content/routing/services/index.md b/docs/content/routing/services/index.md index f452a6827..0e9e109df 100644 --- a/docs/content/routing/services/index.md +++ b/docs/content/routing/services/index.md @@ -455,3 +455,34 @@ The `address` option (IP:Port) point to a specific instance. servers: address: "xx.xx.xx.xx:xx" ``` + +#### Termination Delay + +As a proxy between a client and a server, it can happen that either side (e.g. client side) decides to terminate its writing capability on the connection (i.e. issuance of a FIN packet). +The proxy needs to propagate that intent to the other side, and so when that happens, it also does the same on its connection with the other side (e.g. backend side). + +However, if for some reason (bad implementation, or malicious intent) the other side does not eventually do the same as well, +the connection would stay half-open, which would lock resources for however long. + +To that end, as soon as the proxy enters this termination sequence, it sets a deadline on fully terminating the connections on both sides. + +The termination delay controls that deadline. +It is a duration in milliseconds, defaulting to 100. +A negative value means an infinite deadline (i.e. the connection is never fully terminated by the proxy itself). + +??? example "A Service with a termination delay -- Using the [File Provider](../../providers/file.md)" + + ```toml tab="TOML" + [tcp.services] + [tcp.services.my-service.loadBalancer] + [[tcp.services.my-service.loadBalancer]] + terminationDelay = 200 + ``` + + ```yaml tab="YAML" + tcp: + services: + my-service: + loadBalancer: + terminationDelay: 200 + ``` diff --git a/integration/testdata/rawdata-crd.json b/integration/testdata/rawdata-crd.json index 2cd67af47..98ccb952e 100644 --- a/integration/testdata/rawdata-crd.json +++ b/integration/testdata/rawdata-crd.json @@ -115,6 +115,7 @@ "tcpServices": { "default/test3.route-673acf455cb2dab0b43a@kubernetescrd": { "loadBalancer": { + "terminationDelay": 100, "servers": [ { "address": "10.42.0.4:8080" diff --git a/pkg/config/dynamic/tcp_config.go b/pkg/config/dynamic/tcp_config.go index e3e42ab68..d8ff19cb6 100644 --- a/pkg/config/dynamic/tcp_config.go +++ b/pkg/config/dynamic/tcp_config.go @@ -45,7 +45,19 @@ type RouterTCPTLSConfig struct { // TCPLoadBalancerService holds the LoadBalancerService configuration. type TCPLoadBalancerService struct { - Servers []TCPServer `json:"servers,omitempty" toml:"servers,omitempty" yaml:"servers,omitempty" label-slice-as-struct:"server"` + // TerminationDelay, corresponds to the deadline that the proxy sets, after one + // of its connected peers indicates it has closed the writing capability of its + // connection, to close the reading capability as well, hence fully terminating the + // connection. It is a duration in milliseconds, defaulting to 100. A negative value + // means an infinite deadline (i.e. the reading capability is never closed). + TerminationDelay *int `json:"terminationDelay,omitempty" toml:"terminationDelay,omitempty" yaml:"terminationDelay,omitempty"` + Servers []TCPServer `json:"servers,omitempty" toml:"servers,omitempty" yaml:"servers,omitempty" label-slice-as-struct:"server"` +} + +// SetDefaults Default values for a TCPLoadBalancerService +func (l *TCPLoadBalancerService) SetDefaults() { + defaultTerminationDelay := 100 // in milliseconds + l.TerminationDelay = &defaultTerminationDelay } // Mergeable tells if the given service is mergeable. diff --git a/pkg/config/dynamic/zz_generated.deepcopy.go b/pkg/config/dynamic/zz_generated.deepcopy.go index fa391b000..25d9d4807 100644 --- a/pkg/config/dynamic/zz_generated.deepcopy.go +++ b/pkg/config/dynamic/zz_generated.deepcopy.go @@ -1155,6 +1155,11 @@ func (in *TCPConfiguration) DeepCopy() *TCPConfiguration { // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *TCPLoadBalancerService) DeepCopyInto(out *TCPLoadBalancerService) { *out = *in + if in.TerminationDelay != nil { + in, out := &in.TerminationDelay, &out.TerminationDelay + *out = new(int) + **out = **in + } if in.Servers != nil { in, out := &in.Servers, &out.Servers *out = make([]TCPServer, len(*in)) diff --git a/pkg/config/label/label_test.go b/pkg/config/label/label_test.go index 1fedf3046..989eb6be6 100644 --- a/pkg/config/label/label_test.go +++ b/pkg/config/label/label_test.go @@ -170,7 +170,9 @@ func TestDecodeConfiguration(t *testing.T) { "traefik.tcp.routers.Router1.tls.options": "foo", "traefik.tcp.routers.Router1.tls.passthrough": "false", "traefik.tcp.services.Service0.loadbalancer.server.Port": "42", + "traefik.tcp.services.Service0.loadbalancer.TerminationDelay": "42", "traefik.tcp.services.Service1.loadbalancer.server.Port": "42", + "traefik.tcp.services.Service1.loadbalancer.TerminationDelay": "42", } configuration, err := DecodeConfiguration(labels) @@ -212,6 +214,7 @@ func TestDecodeConfiguration(t *testing.T) { Port: "42", }, }, + TerminationDelay: func(i int) *int { return &i }(42), }, }, "Service1": { @@ -221,6 +224,7 @@ func TestDecodeConfiguration(t *testing.T) { Port: "42", }, }, + TerminationDelay: func(i int) *int { return &i }(42), }, }, }, diff --git a/pkg/provider/docker/config.go b/pkg/provider/docker/config.go index aa5aaf376..86f097665 100644 --- a/pkg/provider/docker/config.go +++ b/pkg/provider/docker/config.go @@ -79,6 +79,7 @@ func (p *Provider) buildTCPServiceConfiguration(ctx context.Context, container d if len(configuration.Services) == 0 { configuration.Services = make(map[string]*dynamic.TCPService) lb := &dynamic.TCPLoadBalancerService{} + lb.SetDefaults() configuration.Services[serviceName] = &dynamic.TCPService{ LoadBalancer: lb, } diff --git a/pkg/provider/docker/config_test.go b/pkg/provider/docker/config_test.go index c7c9bbdb5..2d927c0d7 100644 --- a/pkg/provider/docker/config_test.go +++ b/pkg/provider/docker/config_test.go @@ -13,6 +13,8 @@ import ( "github.com/stretchr/testify/require" ) +func Int(v int) *int { return &v } + func TestDefaultRule(t *testing.T) { testCases := []struct { desc string @@ -2092,6 +2094,7 @@ func Test_buildConfiguration(t *testing.T) { Address: "127.0.0.1:80", }, }, + TerminationDelay: Int(100), }, }, }, @@ -2136,6 +2139,7 @@ func Test_buildConfiguration(t *testing.T) { Address: "127.0.0.1:80", }, }, + TerminationDelay: Int(100), }, }, }, @@ -2190,6 +2194,7 @@ func Test_buildConfiguration(t *testing.T) { Address: "127.0.0.1:8080", }, }, + TerminationDelay: Int(100), }, }, }, @@ -2268,6 +2273,7 @@ func Test_buildConfiguration(t *testing.T) { Address: "127.0.0.2:8080", }, }, + TerminationDelay: Int(100), }, }, }, @@ -2331,6 +2337,53 @@ func Test_buildConfiguration(t *testing.T) { Address: "127.0.0.1:8080", }, }, + TerminationDelay: Int(100), + }, + }, + }, + }, + HTTP: &dynamic.HTTPConfiguration{ + Routers: map[string]*dynamic.Router{}, + Middlewares: map[string]*dynamic.Middleware{}, + Services: map[string]*dynamic.Service{}, + }, + }, + }, + { + desc: "tcp with label for tcp service, with termination delay", + containers: []dockerData{ + { + ServiceName: "Test", + Name: "Test", + Labels: map[string]string{ + "traefik.tcp.services.foo.loadbalancer.server.port": "8080", + "traefik.tcp.services.foo.loadbalancer.terminationdelay": "200", + }, + NetworkSettings: networkSettings{ + Ports: nat.PortMap{ + nat.Port("80/tcp"): []nat.PortBinding{}, + }, + Networks: map[string]*networkData{ + "bridge": { + Name: "bridge", + Addr: "127.0.0.1", + }, + }, + }, + }, + }, + expected: &dynamic.Configuration{ + TCP: &dynamic.TCPConfiguration{ + Routers: map[string]*dynamic.TCPRouter{}, + Services: map[string]*dynamic.TCPService{ + "foo": { + LoadBalancer: &dynamic.TCPLoadBalancerService{ + Servers: []dynamic.TCPServer{ + { + Address: "127.0.0.1:8080", + }, + }, + TerminationDelay: Int(200), }, }, }, diff --git a/pkg/provider/marathon/config.go b/pkg/provider/marathon/config.go index 3f9adfebe..08a767af4 100644 --- a/pkg/provider/marathon/config.go +++ b/pkg/provider/marathon/config.go @@ -141,6 +141,7 @@ func (p *Provider) buildTCPServiceConfiguration(ctx context.Context, app maratho if len(conf.Services) == 0 { conf.Services = make(map[string]*dynamic.TCPService) lb := &dynamic.TCPLoadBalancerService{} + lb.SetDefaults() conf.Services[appName] = &dynamic.TCPService{ LoadBalancer: lb, } diff --git a/pkg/provider/marathon/config_test.go b/pkg/provider/marathon/config_test.go index 296490f08..c2dc92b30 100644 --- a/pkg/provider/marathon/config_test.go +++ b/pkg/provider/marathon/config_test.go @@ -11,6 +11,8 @@ import ( "github.com/stretchr/testify/require" ) +func Int(v int) *int { return &v } + func TestGetConfigurationAPIErrors(t *testing.T) { fakeClient := newFakeClient(true, marathon.Applications{}) @@ -1240,6 +1242,7 @@ func TestBuildConfiguration(t *testing.T) { Address: "localhost:80", }, }, + TerminationDelay: Int(100), }, }, }, @@ -1271,6 +1274,7 @@ func TestBuildConfiguration(t *testing.T) { Address: "localhost:80", }, }, + TerminationDelay: Int(100), }, }, }, @@ -1310,6 +1314,48 @@ func TestBuildConfiguration(t *testing.T) { Address: "localhost:8080", }, }, + TerminationDelay: Int(100), + }, + }, + }, + }, + HTTP: &dynamic.HTTPConfiguration{ + Routers: map[string]*dynamic.Router{}, + Middlewares: map[string]*dynamic.Middleware{}, + Services: map[string]*dynamic.Service{}, + }, + }, + }, + { + desc: "one app with tcp labels with port, with termination delay", + applications: withApplications( + application( + appID("/app"), + appPorts(80, 81), + withTasks(localhostTask(taskPorts(80, 81))), + withLabel("traefik.tcp.routers.foo.rule", "HostSNI(`foo.bar`)"), + withLabel("traefik.tcp.routers.foo.tls", "true"), + withLabel("traefik.tcp.services.foo.loadbalancer.server.port", "8080"), + withLabel("traefik.tcp.services.foo.loadbalancer.terminationdelay", "200"), + )), + expected: &dynamic.Configuration{ + TCP: &dynamic.TCPConfiguration{ + Routers: map[string]*dynamic.TCPRouter{ + "foo": { + Service: "foo", + Rule: "HostSNI(`foo.bar`)", + TLS: &dynamic.RouterTCPTLSConfig{}, + }, + }, + Services: map[string]*dynamic.TCPService{ + "foo": { + LoadBalancer: &dynamic.TCPLoadBalancerService{ + Servers: []dynamic.TCPServer{ + { + Address: "localhost:8080", + }, + }, + TerminationDelay: Int(200), }, }, }, @@ -1350,6 +1396,7 @@ func TestBuildConfiguration(t *testing.T) { Address: "localhost:8080", }, }, + TerminationDelay: Int(100), }, }, }, diff --git a/pkg/provider/rancher/config.go b/pkg/provider/rancher/config.go index 4d6f63a13..e2787d96d 100644 --- a/pkg/provider/rancher/config.go +++ b/pkg/provider/rancher/config.go @@ -75,6 +75,7 @@ func (p *Provider) buildTCPServiceConfiguration(ctx context.Context, service ran if len(configuration.Services) == 0 { configuration.Services = make(map[string]*dynamic.TCPService) lb := &dynamic.TCPLoadBalancerService{} + lb.SetDefaults() configuration.Services[serviceName] = &dynamic.TCPService{ LoadBalancer: lb, } diff --git a/pkg/provider/rancher/config_test.go b/pkg/provider/rancher/config_test.go index 31fb29ddc..14d172f35 100644 --- a/pkg/provider/rancher/config_test.go +++ b/pkg/provider/rancher/config_test.go @@ -9,6 +9,8 @@ import ( "github.com/stretchr/testify/require" ) +func Int(v int) *int { return &v } + func Test_buildConfiguration(t *testing.T) { testCases := []struct { desc string @@ -512,6 +514,7 @@ func Test_buildConfiguration(t *testing.T) { Address: "127.0.0.1:80", }, }, + TerminationDelay: Int(100), }, }, }, @@ -548,6 +551,7 @@ func Test_buildConfiguration(t *testing.T) { Address: "127.0.0.1:80", }, }, + TerminationDelay: Int(100), }, }, }, @@ -590,6 +594,7 @@ func Test_buildConfiguration(t *testing.T) { Address: "127.0.0.1:8080", }, }, + TerminationDelay: Int(100), }, }, }, @@ -638,6 +643,7 @@ func Test_buildConfiguration(t *testing.T) { Address: "127.0.0.2:8080", }, }, + TerminationDelay: Int(100), }, }, }, @@ -693,6 +699,45 @@ func Test_buildConfiguration(t *testing.T) { Address: "127.0.0.1:8080", }, }, + TerminationDelay: Int(100), + }, + }, + }, + }, + HTTP: &dynamic.HTTPConfiguration{ + Routers: map[string]*dynamic.Router{}, + Middlewares: map[string]*dynamic.Middleware{}, + Services: map[string]*dynamic.Service{}, + }, + }, + }, + { + desc: "tcp with label for tcp service, with termination delay", + containers: []rancherData{ + { + Name: "Test", + Labels: map[string]string{ + "traefik.tcp.services.foo.loadbalancer.server.port": "8080", + "traefik.tcp.services.foo.loadbalancer.terminationdelay": "200", + }, + Port: "80/tcp", + Containers: []string{"127.0.0.1"}, + Health: "", + State: "", + }, + }, + expected: &dynamic.Configuration{ + TCP: &dynamic.TCPConfiguration{ + Routers: map[string]*dynamic.TCPRouter{}, + Services: map[string]*dynamic.TCPService{ + "foo": { + LoadBalancer: &dynamic.TCPLoadBalancerService{ + Servers: []dynamic.TCPServer{ + { + Address: "127.0.0.1:8080", + }, + }, + TerminationDelay: Int(200), }, }, }, diff --git a/pkg/server/server_entrypoint_tcp.go b/pkg/server/server_entrypoint_tcp.go index c3c4c43b8..cad56b3ab 100644 --- a/pkg/server/server_entrypoint_tcp.go +++ b/pkg/server/server_entrypoint_tcp.go @@ -37,7 +37,7 @@ func newHTTPForwarder(ln net.Listener) *httpForwarder { } // ServeTCP uses the connection to serve it later in "Accept" -func (h *httpForwarder) ServeTCP(conn net.Conn) { +func (h *httpForwarder) ServeTCP(conn tcp.WriteCloser) { h.connChan <- conn } @@ -99,7 +99,36 @@ func NewTCPEntryPoint(ctx context.Context, configuration *static.EntryPoint) (*T }, nil } +// writeCloserWrapper wraps together a connection, and the concrete underlying +// connection type that was found to satisfy WriteCloser. +type writeCloserWrapper struct { + net.Conn + writeCloser tcp.WriteCloser +} + +func (c *writeCloserWrapper) CloseWrite() error { + return c.writeCloser.CloseWrite() +} + +// writeCloser returns the given connection, augmented with the WriteCloser +// implementation, if any was found within the underlying conn. +func writeCloser(conn net.Conn) (tcp.WriteCloser, error) { + switch typedConn := conn.(type) { + case *proxyprotocol.Conn: + underlying, err := writeCloser(typedConn.Conn) + if err != nil { + return nil, err + } + return &writeCloserWrapper{writeCloser: underlying, Conn: typedConn}, nil + case *net.TCPConn: + return typedConn, nil + default: + return nil, fmt.Errorf("unknown connection type %T", typedConn) + } +} + func (e *TCPEntryPoint) startTCP(ctx context.Context) { + log.FromContext(ctx).Debugf("Start TCP Server") for { @@ -109,8 +138,13 @@ func (e *TCPEntryPoint) startTCP(ctx context.Context) { return } + writeCloser, err := writeCloser(conn) + if err != nil { + panic(err) + } + safe.Go(func() { - e.switcher.ServeTCP(newTrackedConnection(conn, e.tracker)) + e.switcher.ServeTCP(newTrackedConnection(writeCloser, e.tracker)) }) } } @@ -374,20 +408,20 @@ func createHTTPServer(ln net.Listener, configuration *static.EntryPoint, withH2c }, nil } -func newTrackedConnection(conn net.Conn, tracker *connectionTracker) *trackedConnection { +func newTrackedConnection(conn tcp.WriteCloser, tracker *connectionTracker) *trackedConnection { tracker.AddConnection(conn) return &trackedConnection{ - Conn: conn, - tracker: tracker, + WriteCloser: conn, + tracker: tracker, } } type trackedConnection struct { tracker *connectionTracker - net.Conn + tcp.WriteCloser } func (t *trackedConnection) Close() error { - t.tracker.RemoveConnection(t.Conn) - return t.Conn.Close() + t.tracker.RemoveConnection(t.WriteCloser) + return t.WriteCloser.Close() } diff --git a/pkg/server/server_entrypoint_tcp_test.go b/pkg/server/server_entrypoint_tcp_test.go index 525cea4f9..6bdce2a55 100644 --- a/pkg/server/server_entrypoint_tcp_test.go +++ b/pkg/server/server_entrypoint_tcp_test.go @@ -113,7 +113,7 @@ func TestShutdownTCPConn(t *testing.T) { go entryPoint.startTCP(context.Background()) router := &tcp.Router{} - router.AddCatchAllNoTLS(tcp.HandlerFunc(func(conn net.Conn) { + router.AddCatchAllNoTLS(tcp.HandlerFunc(func(conn tcp.WriteCloser) { _, err := http.ReadRequest(bufio.NewReader(conn)) require.NoError(t, err) time.Sleep(1 * time.Second) diff --git a/pkg/server/service/tcp/service.go b/pkg/server/service/tcp/service.go index e5311a7eb..2495e2551 100644 --- a/pkg/server/service/tcp/service.go +++ b/pkg/server/service/tcp/service.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "time" "github.com/containous/traefik/v2/pkg/config/runtime" "github.com/containous/traefik/v2/pkg/log" @@ -44,13 +45,19 @@ func (m *Manager) BuildTCP(rootCtx context.Context, serviceName string) (tcp.Han loadBalancer := tcp.NewRRLoadBalancer() + if conf.LoadBalancer.TerminationDelay == nil { + defaultTerminationDelay := 100 + conf.LoadBalancer.TerminationDelay = &defaultTerminationDelay + } + duration := time.Millisecond * time.Duration(*conf.LoadBalancer.TerminationDelay) + for name, server := range conf.LoadBalancer.Servers { if _, _, err := net.SplitHostPort(server.Address); err != nil { logger.Errorf("In service %q: %v", serviceQualifiedName, err) continue } - handler, err := tcp.NewProxy(server.Address) + handler, err := tcp.NewProxy(server.Address, duration) if err != nil { logger.Errorf("In service %q server %q: %v", serviceQualifiedName, server.Address, err) continue diff --git a/pkg/tcp/handler.go b/pkg/tcp/handler.go index 63158baee..88aefe8de 100644 --- a/pkg/tcp/handler.go +++ b/pkg/tcp/handler.go @@ -6,14 +6,23 @@ import ( // Handler is the TCP Handlers interface type Handler interface { - ServeTCP(conn net.Conn) + ServeTCP(conn WriteCloser) } // The HandlerFunc type is an adapter to allow the use of // ordinary functions as handlers. -type HandlerFunc func(conn net.Conn) +type HandlerFunc func(conn WriteCloser) // ServeTCP serves tcp -func (f HandlerFunc) ServeTCP(conn net.Conn) { +func (f HandlerFunc) ServeTCP(conn WriteCloser) { f(conn) } + +// WriteCloser describes a net.Conn with a CloseWrite method. +type WriteCloser interface { + net.Conn + // CloseWrite on a network connection, indicates that the issuer of the call + // has terminated sending on that connection. + // It corresponds to sending a FIN packet. + CloseWrite() error +} diff --git a/pkg/tcp/proxy.go b/pkg/tcp/proxy.go index 76a20aa91..2e7934415 100644 --- a/pkg/tcp/proxy.go +++ b/pkg/tcp/proxy.go @@ -3,28 +3,32 @@ package tcp import ( "io" "net" + "time" "github.com/containous/traefik/v2/pkg/log" ) // Proxy forwards a TCP request to a TCP service type Proxy struct { - target *net.TCPAddr + target *net.TCPAddr + terminationDelay time.Duration } // NewProxy creates a new Proxy -func NewProxy(address string) (*Proxy, error) { +func NewProxy(address string, terminationDelay time.Duration) (*Proxy, error) { tcpAddr, err := net.ResolveTCPAddr("tcp", address) if err != nil { return nil, err } - return &Proxy{target: tcpAddr}, nil + return &Proxy{target: tcpAddr, terminationDelay: terminationDelay}, nil } // ServeTCP forwards the connection to a service -func (p *Proxy) ServeTCP(conn net.Conn) { +func (p *Proxy) ServeTCP(conn WriteCloser) { log.Debugf("Handling connection from %s", conn.RemoteAddr()) + + // needed because of e.g. server.trackedConnection defer conn.Close() connBackend, err := net.DialTCP("tcp", nil, p.target) @@ -32,19 +36,35 @@ func (p *Proxy) ServeTCP(conn net.Conn) { log.Errorf("Error while connection to backend: %v", err) return } + + // maybe not needed, but just in case defer connBackend.Close() - errChan := make(chan error, 1) - go connCopy(conn, connBackend, errChan) - go connCopy(connBackend, conn, errChan) + errChan := make(chan error) + go p.connCopy(conn, connBackend, errChan) + go p.connCopy(connBackend, conn, errChan) err = <-errChan if err != nil { - log.Errorf("Error during connection: %v", err) + log.WithoutContext().Errorf("Error during connection: %v", err) } + + <-errChan } -func connCopy(dst, src net.Conn, errCh chan error) { +func (p Proxy) connCopy(dst, src WriteCloser, errCh chan error) { _, err := io.Copy(dst, src) errCh <- err + + errClose := dst.CloseWrite() + if errClose != nil { + log.WithoutContext().Errorf("Error while terminating connection: %v", errClose) + } + + if p.terminationDelay >= 0 { + err := dst.SetReadDeadline(time.Now().Add(p.terminationDelay)) + if err != nil { + log.WithoutContext().Errorf("Error while setting deadline: %v", err) + } + } } diff --git a/pkg/tcp/proxy_test.go b/pkg/tcp/proxy_test.go new file mode 100644 index 000000000..323a395ff --- /dev/null +++ b/pkg/tcp/proxy_test.go @@ -0,0 +1,81 @@ +package tcp + +import ( + "bytes" + "fmt" + "io" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func fakeRedis(t *testing.T, listener net.Listener) { + for { + conn, err := listener.Accept() + fmt.Println("Accept on server") + require.NoError(t, err) + for { + withErr := false + buf := make([]byte, 64) + if _, err := conn.Read(buf); err != nil { + withErr = true + } + + if string(buf[:4]) == "ping" { + time.Sleep(time.Millisecond * 1) + if _, err := conn.Write([]byte("PONG")); err != nil { + conn.Close() + return + } + } + if withErr { + conn.Close() + return + } + } + } +} + +func TestCloseWrite(t *testing.T) { + backendListener, err := net.Listen("tcp", ":0") + require.NoError(t, err) + + go fakeRedis(t, backendListener) + _, port, err := net.SplitHostPort(backendListener.Addr().String()) + require.NoError(t, err) + + proxy, err := NewProxy(":"+port, 10*time.Millisecond) + require.NoError(t, err) + + proxyListener, err := net.Listen("tcp", ":0") + require.NoError(t, err) + + go func() { + for { + conn, err := proxyListener.Accept() + require.NoError(t, err) + proxy.ServeTCP(conn.(*net.TCPConn)) + } + }() + + _, port, err = net.SplitHostPort(proxyListener.Addr().String()) + require.NoError(t, err) + + conn, err := net.Dial("tcp", ":"+port) + require.NoError(t, err) + + _, err = conn.Write([]byte("ping\n")) + require.NoError(t, err) + + err = conn.(*net.TCPConn).CloseWrite() + require.NoError(t, err) + + var buf []byte + buffer := bytes.NewBuffer(buf) + n, err := io.Copy(buffer, conn) + require.NoError(t, err) + require.Equal(t, int64(4), n) + require.Equal(t, "PONG", buffer.String()) +} diff --git a/pkg/tcp/router.go b/pkg/tcp/router.go index 7dc379373..f8a7348b0 100644 --- a/pkg/tcp/router.go +++ b/pkg/tcp/router.go @@ -25,7 +25,7 @@ type Router struct { } // ServeTCP forwards the connection to the right TCP/HTTP handler -func (r *Router) ServeTCP(conn net.Conn) { +func (r *Router) ServeTCP(conn WriteCloser) { // FIXME -- Check if ProxyProtocol changes the first bytes of the request if r.catchAllNoTLS != nil && len(r.routingTable) == 0 && r.httpsHandler == nil { @@ -99,11 +99,11 @@ func (r *Router) AddCatchAllNoTLS(handler Handler) { } // GetConn creates a connection proxy with a peeked string -func (r *Router) GetConn(conn net.Conn, peeked string) net.Conn { +func (r *Router) GetConn(conn WriteCloser, peeked string) WriteCloser { // FIXME should it really be on Router ? conn = &Conn{ - Peeked: []byte(peeked), - Conn: conn, + Peeked: []byte(peeked), + WriteCloser: conn, } return conn } @@ -157,7 +157,7 @@ type Conn struct { // It can be type asserted against *net.TCPConn or other types // as needed. It should not be read from directly unless // Peeked is nil. - net.Conn + WriteCloser } // Read reads bytes from the connection (using the buffer prior to actually reading) @@ -170,7 +170,7 @@ func (c *Conn) Read(p []byte) (n int, err error) { } return n, nil } - return c.Conn.Read(p) + return c.WriteCloser.Read(p) } // clientHelloServerName returns the SNI server name inside the TLS ClientHello, diff --git a/pkg/tcp/rr_load_balancer.go b/pkg/tcp/rr_load_balancer.go index 6187d5417..92be94f54 100644 --- a/pkg/tcp/rr_load_balancer.go +++ b/pkg/tcp/rr_load_balancer.go @@ -1,7 +1,6 @@ package tcp import ( - "net" "sync" "github.com/containous/traefik/v2/pkg/log" @@ -20,7 +19,7 @@ func NewRRLoadBalancer() *RRLoadBalancer { } // ServeTCP forwards the connection to the right service -func (r *RRLoadBalancer) ServeTCP(conn net.Conn) { +func (r *RRLoadBalancer) ServeTCP(conn WriteCloser) { if len(r.servers) == 0 { log.WithoutContext().Error("no available server") return diff --git a/pkg/tcp/switcher.go b/pkg/tcp/switcher.go index 330fc993f..16c4c0cbf 100644 --- a/pkg/tcp/switcher.go +++ b/pkg/tcp/switcher.go @@ -1,8 +1,6 @@ package tcp import ( - "net" - "github.com/containous/traefik/v2/pkg/safe" ) @@ -12,7 +10,7 @@ type HandlerSwitcher struct { } // ServeTCP forwards the TCP connection to the current active handler -func (s *HandlerSwitcher) ServeTCP(conn net.Conn) { +func (s *HandlerSwitcher) ServeTCP(conn WriteCloser) { handler := s.router.Get() h, ok := handler.(Handler) if ok { diff --git a/pkg/tcp/tls.go b/pkg/tcp/tls.go index 1bfa3d598..6debaad3d 100644 --- a/pkg/tcp/tls.go +++ b/pkg/tcp/tls.go @@ -2,7 +2,6 @@ package tcp import ( "crypto/tls" - "net" ) // TLSHandler handles TLS connections @@ -12,6 +11,6 @@ type TLSHandler struct { } // ServeTCP terminates the TLS connection -func (t *TLSHandler) ServeTCP(conn net.Conn) { +func (t *TLSHandler) ServeTCP(conn WriteCloser) { t.Next.ServeTCP(tls.Server(conn, t.Config)) }