Improve host name resolution for TCP proxy
This commit is contained in:
parent
29908098e4
commit
8947f85ddd
2 changed files with 30 additions and 59 deletions
|
@ -31,7 +31,7 @@ func NewProxy(address string, terminationDelay time.Duration, proxyProtocol *dyn
|
||||||
return nil, fmt.Errorf("unknown proxyProtocol version: %d", proxyProtocol.Version)
|
return nil, fmt.Errorf("unknown proxyProtocol version: %d", proxyProtocol.Version)
|
||||||
}
|
}
|
||||||
|
|
||||||
// enable the refresh of the target only if the address in an IP
|
// enable the refresh of the target only if the address in not an IP
|
||||||
refreshTarget := false
|
refreshTarget := false
|
||||||
if host, _, err := net.SplitHostPort(address); err == nil && net.ParseIP(host) == nil {
|
if host, _, err := net.SplitHostPort(address); err == nil && net.ParseIP(host) == nil {
|
||||||
refreshTarget = true
|
refreshTarget = true
|
||||||
|
@ -48,23 +48,14 @@ func NewProxy(address string, terminationDelay time.Duration, proxyProtocol *dyn
|
||||||
|
|
||||||
// ServeTCP forwards the connection to a service.
|
// ServeTCP forwards the connection to a service.
|
||||||
func (p *Proxy) ServeTCP(conn WriteCloser) {
|
func (p *Proxy) ServeTCP(conn WriteCloser) {
|
||||||
log.Debugf("Handling connection from %s", conn.RemoteAddr())
|
log.WithoutContext().Debugf("Handling connection from %s", conn.RemoteAddr())
|
||||||
|
|
||||||
// needed because of e.g. server.trackedConnection
|
// needed because of e.g. server.trackedConnection
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
if p.refreshTarget {
|
connBackend, err := p.dialBackend()
|
||||||
tcpAddr, err := net.ResolveTCPAddr("tcp", p.address)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Error resolving tcp address: %v", err)
|
log.WithoutContext().Errorf("Error while connecting to backend: %v", err)
|
||||||
return
|
|
||||||
}
|
|
||||||
p.target = tcpAddr
|
|
||||||
}
|
|
||||||
|
|
||||||
connBackend, err := net.DialTCP("tcp", nil, p.target)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Error while connection to backend: %v", err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -91,6 +82,19 @@ func (p *Proxy) ServeTCP(conn WriteCloser) {
|
||||||
<-errChan
|
<-errChan
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p Proxy) dialBackend() (*net.TCPConn, error) {
|
||||||
|
if !p.refreshTarget {
|
||||||
|
return net.DialTCP("tcp", nil, p.target)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := net.Dial("tcp", p.address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn.(*net.TCPConn), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (p Proxy) connCopy(dst, src WriteCloser, errCh chan error) {
|
func (p Proxy) connCopy(dst, src WriteCloser, errCh chan error) {
|
||||||
_, err := io.Copy(dst, src)
|
_, err := io.Copy(dst, src)
|
||||||
errCh <- err
|
errCh <- err
|
||||||
|
|
|
@ -6,7 +6,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -177,17 +176,16 @@ func TestLookupAddress(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
desc string
|
desc string
|
||||||
address string
|
address string
|
||||||
expectSame assert.ComparisonAssertionFunc
|
expectRefresh bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
desc: "IP doesn't need refresh",
|
desc: "IP doesn't need refresh",
|
||||||
address: "8.8.4.4:53",
|
address: "8.8.4.4:53",
|
||||||
expectSame: assert.Same,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "Hostname needs refresh",
|
desc: "Hostname needs refresh",
|
||||||
address: "dns.google:53",
|
address: "dns.google:53",
|
||||||
expectSame: assert.NotSame,
|
expectRefresh: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -201,44 +199,13 @@ func TestLookupAddress(t *testing.T) {
|
||||||
|
|
||||||
require.NotNil(t, proxy.target)
|
require.NotNil(t, proxy.target)
|
||||||
|
|
||||||
proxyListener, err := net.Listen("tcp", ":0")
|
conn, err := proxy.dialBackend()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
if test.expectRefresh {
|
||||||
go func(wg *sync.WaitGroup) {
|
assert.NotEqual(t, test.address, conn.RemoteAddr().String())
|
||||||
for {
|
} else {
|
||||||
conn, err := proxyListener.Accept()
|
assert.Equal(t, test.address, conn.RemoteAddr().String())
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
proxy.ServeTCP(conn.(*net.TCPConn))
|
|
||||||
|
|
||||||
wg.Done()
|
|
||||||
}
|
|
||||||
}(&wg)
|
|
||||||
|
|
||||||
var lastTarget *net.TCPAddr
|
|
||||||
|
|
||||||
for i := 0; i < 3; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
|
|
||||||
conn, err := net.Dial("tcp", proxyListener.Addr().String())
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
_, err = conn.Write([]byte("ping\n"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
err = conn.Close()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
assert.NotNil(t, proxy.target)
|
|
||||||
|
|
||||||
if lastTarget != nil {
|
|
||||||
test.expectSame(t, lastTarget, proxy.target)
|
|
||||||
}
|
|
||||||
|
|
||||||
lastTarget = proxy.target
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue