diff --git a/pkg/tcp/proxy.go b/pkg/tcp/proxy.go index d3401a874..f66380d49 100644 --- a/pkg/tcp/proxy.go +++ b/pkg/tcp/proxy.go @@ -10,8 +10,10 @@ import ( // Proxy forwards a TCP request to a TCP service. type Proxy struct { + address string target *net.TCPAddr terminationDelay time.Duration + refreshTarget bool } // NewProxy creates a new Proxy. @@ -21,7 +23,18 @@ func NewProxy(address string, terminationDelay time.Duration) (*Proxy, error) { return nil, err } - return &Proxy{target: tcpAddr, terminationDelay: terminationDelay}, nil + // enable the refresh of the target only if the address in an IP + refreshTarget := false + if host, _, err := net.SplitHostPort(address); err == nil && net.ParseIP(host) == nil { + refreshTarget = true + } + + return &Proxy{ + address: address, + target: tcpAddr, + refreshTarget: refreshTarget, + terminationDelay: terminationDelay, + }, nil } // ServeTCP forwards the connection to a service. @@ -31,6 +44,15 @@ func (p *Proxy) ServeTCP(conn WriteCloser) { // needed because of e.g. server.trackedConnection defer conn.Close() + if p.refreshTarget { + tcpAddr, err := net.ResolveTCPAddr("tcp", p.address) + if err != nil { + log.Errorf("Error resolving tcp address: %v", err) + return + } + p.target = tcpAddr + } + connBackend, err := net.DialTCP("tcp", nil, p.target) if err != nil { log.Errorf("Error while connection to backend: %v", err) diff --git a/pkg/tcp/proxy_test.go b/pkg/tcp/proxy_test.go index aa7dd1eaa..34f42b28d 100644 --- a/pkg/tcp/proxy_test.go +++ b/pkg/tcp/proxy_test.go @@ -5,9 +5,11 @@ import ( "fmt" "io" "net" + "sync" "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -79,3 +81,74 @@ func TestCloseWrite(t *testing.T) { require.Equal(t, int64(4), n) require.Equal(t, "PONG", buffer.String()) } + +func TestLookupAddress(t *testing.T) { + testCases := []struct { + desc string + address string + expectSame assert.ComparisonAssertionFunc + }{ + { + desc: "IP doesn't need refresh", + address: "8.8.4.4:53", + expectSame: assert.Same, + }, + { + desc: "Hostname needs refresh", + address: "dns.google:53", + expectSame: assert.NotSame, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + proxy, err := NewProxy(test.address, 10*time.Millisecond) + require.NoError(t, err) + + require.NotNil(t, proxy.target) + + proxyListener, err := net.Listen("tcp", ":0") + require.NoError(t, err) + + var wg sync.WaitGroup + go func(wg *sync.WaitGroup) { + for { + conn, err := proxyListener.Accept() + require.NoError(t, err) + + proxy.ServeTCP(conn.(*net.TCPConn)) + + wg.Done() + } + }(&wg) + + var lastTarget *net.TCPAddr + + for i := 0; i < 3; i++ { + wg.Add(1) + + conn, err := net.Dial("tcp", proxyListener.Addr().String()) + require.NoError(t, err) + + _, err = conn.Write([]byte("ping\n")) + require.NoError(t, err) + + err = conn.Close() + require.NoError(t, err) + + wg.Wait() + + assert.NotNil(t, proxy.target) + + if lastTarget != nil { + test.expectSame(t, lastTarget, proxy.target) + } + + lastTarget = proxy.target + } + }) + } +}