Fix initial tcp lookup when address is not available

This commit is contained in:
Douglas De Toni Machado 2022-05-19 11:40:09 -03:00 committed by GitHub
parent d134a993d0
commit 575d4ab431
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 27 additions and 25 deletions

View file

@ -14,33 +14,31 @@ import (
// Proxy forwards a TCP request to a TCP service. // Proxy forwards a TCP request to a TCP service.
type Proxy struct { type Proxy struct {
address string address string
target *net.TCPAddr tcpAddr *net.TCPAddr
terminationDelay time.Duration terminationDelay time.Duration
proxyProtocol *dynamic.ProxyProtocol proxyProtocol *dynamic.ProxyProtocol
refreshTarget bool
} }
// NewProxy creates a new Proxy. // NewProxy creates a new Proxy.
func NewProxy(address string, terminationDelay time.Duration, proxyProtocol *dynamic.ProxyProtocol) (*Proxy, error) { func NewProxy(address string, terminationDelay time.Duration, proxyProtocol *dynamic.ProxyProtocol) (*Proxy, error) {
tcpAddr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
return nil, err
}
if proxyProtocol != nil && (proxyProtocol.Version < 1 || proxyProtocol.Version > 2) { if proxyProtocol != nil && (proxyProtocol.Version < 1 || proxyProtocol.Version > 2) {
return nil, fmt.Errorf("unknown proxyProtocol version: %d", proxyProtocol.Version) return nil, fmt.Errorf("unknown proxyProtocol version: %d", proxyProtocol.Version)
} }
// enable the refresh of the target only if the address in not an IP // Creates the tcpAddr only for IP based addresses,
refreshTarget := false // because there is no need to resolve the name on every new connection,
if host, _, err := net.SplitHostPort(address); err == nil && net.ParseIP(host) == nil { // and building it should happen once.
refreshTarget = true var tcpAddr *net.TCPAddr
if host, _, err := net.SplitHostPort(address); err == nil && net.ParseIP(host) != nil {
tcpAddr, err = net.ResolveTCPAddr("tcp", address)
if err != nil {
return nil, err
}
} }
return &Proxy{ return &Proxy{
address: address, address: address,
target: tcpAddr, tcpAddr: tcpAddr,
refreshTarget: refreshTarget,
terminationDelay: terminationDelay, terminationDelay: terminationDelay,
proxyProtocol: proxyProtocol, proxyProtocol: proxyProtocol,
}, nil }, nil
@ -83,10 +81,14 @@ func (p *Proxy) ServeTCP(conn WriteCloser) {
} }
func (p Proxy) dialBackend() (*net.TCPConn, error) { func (p Proxy) dialBackend() (*net.TCPConn, error) {
if !p.refreshTarget { // Dial using directly the TCPAddr for IP based addresses.
return net.DialTCP("tcp", nil, p.target) if p.tcpAddr != nil {
return net.DialTCP("tcp", nil, p.tcpAddr)
} }
log.WithoutContext().Debugf("Dial with lookup to address %s", p.address)
// Dial with DNS lookup for host based addresses.
conn, err := net.Dial("tcp", p.address) conn, err := net.Dial("tcp", p.address)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -176,16 +176,20 @@ func TestLookupAddress(t *testing.T) {
testCases := []struct { testCases := []struct {
desc string desc string
address string address string
expectRefresh bool expectAddr assert.ComparisonAssertionFunc
expectRefresh assert.ValueAssertionFunc
}{ }{
{ {
desc: "IP doesn't need refresh", desc: "IP doesn't need refresh",
address: "8.8.4.4:53", address: "8.8.4.4:53",
expectAddr: assert.Equal,
expectRefresh: assert.NotNil,
}, },
{ {
desc: "Hostname needs refresh", desc: "Hostname needs refresh",
address: "dns.google:53", address: "dns.google:53",
expectRefresh: true, expectAddr: assert.NotEqual,
expectRefresh: assert.Nil,
}, },
} }
@ -197,16 +201,12 @@ func TestLookupAddress(t *testing.T) {
proxy, err := NewProxy(test.address, 10*time.Millisecond, nil) proxy, err := NewProxy(test.address, 10*time.Millisecond, nil)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, proxy.target) test.expectRefresh(t, proxy.tcpAddr)
conn, err := proxy.dialBackend() conn, err := proxy.dialBackend()
require.NoError(t, err) require.NoError(t, err)
if test.expectRefresh { test.expectAddr(t, test.address, conn.RemoteAddr().String())
assert.NotEqual(t, test.address, conn.RemoteAddr().String())
} else {
assert.Equal(t, test.address, conn.RemoteAddr().String())
}
}) })
} }
} }