From 8947f85ddd791c3113bcc7419a9924d3f3d692c0 Mon Sep 17 00:00:00 2001 From: HMH Date: Tue, 23 Mar 2021 10:24:03 +0000 Subject: [PATCH] Improve host name resolution for TCP proxy --- pkg/tcp/proxy.go | 30 ++++++++++++---------- pkg/tcp/proxy_test.go | 59 ++++++++++--------------------------------- 2 files changed, 30 insertions(+), 59 deletions(-) diff --git a/pkg/tcp/proxy.go b/pkg/tcp/proxy.go index 552fe0f3b..06dc2884a 100644 --- a/pkg/tcp/proxy.go +++ b/pkg/tcp/proxy.go @@ -31,7 +31,7 @@ func NewProxy(address string, terminationDelay time.Duration, proxyProtocol *dyn return nil, fmt.Errorf("unknown proxyProtocol version: %d", proxyProtocol.Version) } - // enable the refresh of the target only if the address in an IP + // enable the refresh of the target only if the address in not an IP refreshTarget := false if host, _, err := net.SplitHostPort(address); err == nil && net.ParseIP(host) == nil { refreshTarget = true @@ -48,23 +48,14 @@ func NewProxy(address string, terminationDelay time.Duration, proxyProtocol *dyn // ServeTCP forwards the connection to a service. func (p *Proxy) ServeTCP(conn WriteCloser) { - log.Debugf("Handling connection from %s", conn.RemoteAddr()) + log.WithoutContext().Debugf("Handling connection from %s", conn.RemoteAddr()) // needed because of e.g. server.trackedConnection defer conn.Close() - if p.refreshTarget { - tcpAddr, err := net.ResolveTCPAddr("tcp", p.address) - if err != nil { - log.Errorf("Error resolving tcp address: %v", err) - return - } - p.target = tcpAddr - } - - connBackend, err := net.DialTCP("tcp", nil, p.target) + connBackend, err := p.dialBackend() if err != nil { - log.Errorf("Error while connection to backend: %v", err) + log.WithoutContext().Errorf("Error while connecting to backend: %v", err) return } @@ -91,6 +82,19 @@ func (p *Proxy) ServeTCP(conn WriteCloser) { <-errChan } +func (p Proxy) dialBackend() (*net.TCPConn, error) { + if !p.refreshTarget { + return net.DialTCP("tcp", nil, p.target) + } + + conn, err := net.Dial("tcp", p.address) + if err != nil { + return nil, err + } + + return conn.(*net.TCPConn), nil +} + func (p Proxy) connCopy(dst, src WriteCloser, errCh chan error) { _, err := io.Copy(dst, src) errCh <- err diff --git a/pkg/tcp/proxy_test.go b/pkg/tcp/proxy_test.go index 11fdb47c4..e70b36796 100644 --- a/pkg/tcp/proxy_test.go +++ b/pkg/tcp/proxy_test.go @@ -6,7 +6,6 @@ import ( "fmt" "io" "net" - "sync" "testing" "time" @@ -175,19 +174,18 @@ func TestProxyProtocol(t *testing.T) { func TestLookupAddress(t *testing.T) { testCases := []struct { - desc string - address string - expectSame assert.ComparisonAssertionFunc + desc string + address string + expectRefresh bool }{ { - desc: "IP doesn't need refresh", - address: "8.8.4.4:53", - expectSame: assert.Same, + desc: "IP doesn't need refresh", + address: "8.8.4.4:53", }, { - desc: "Hostname needs refresh", - address: "dns.google:53", - expectSame: assert.NotSame, + desc: "Hostname needs refresh", + address: "dns.google:53", + expectRefresh: true, }, } @@ -201,44 +199,13 @@ func TestLookupAddress(t *testing.T) { require.NotNil(t, proxy.target) - proxyListener, err := net.Listen("tcp", ":0") + conn, err := proxy.dialBackend() require.NoError(t, err) - var wg sync.WaitGroup - go func(wg *sync.WaitGroup) { - for { - conn, err := proxyListener.Accept() - require.NoError(t, err) - - proxy.ServeTCP(conn.(*net.TCPConn)) - - wg.Done() - } - }(&wg) - - var lastTarget *net.TCPAddr - - for i := 0; i < 3; i++ { - wg.Add(1) - - conn, err := net.Dial("tcp", proxyListener.Addr().String()) - require.NoError(t, err) - - _, err = conn.Write([]byte("ping\n")) - require.NoError(t, err) - - err = conn.Close() - require.NoError(t, err) - - wg.Wait() - - assert.NotNil(t, proxy.target) - - if lastTarget != nil { - test.expectSame(t, lastTarget, proxy.target) - } - - lastTarget = proxy.target + if test.expectRefresh { + assert.NotEqual(t, test.address, conn.RemoteAddr().String()) + } else { + assert.Equal(t, test.address, conn.RemoteAddr().String()) } }) }