Improve service name lookup on TCP routers

This commit is contained in:
Douglas De Toni Machado 2020-11-13 08:48:04 -03:00 committed by GitHub
parent 459200dd01
commit 598dcf6b62
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 96 additions and 1 deletions

View file

@ -10,8 +10,10 @@ import (
// Proxy forwards a TCP request to a TCP service. // Proxy forwards a TCP request to a TCP service.
type Proxy struct { type Proxy struct {
address string
target *net.TCPAddr target *net.TCPAddr
terminationDelay time.Duration terminationDelay time.Duration
refreshTarget bool
} }
// NewProxy creates a new Proxy. // NewProxy creates a new Proxy.
@ -21,7 +23,18 @@ func NewProxy(address string, terminationDelay time.Duration) (*Proxy, error) {
return nil, err return nil, err
} }
return &Proxy{target: tcpAddr, terminationDelay: terminationDelay}, nil // enable the refresh of the target only if the address in an IP
refreshTarget := false
if host, _, err := net.SplitHostPort(address); err == nil && net.ParseIP(host) == nil {
refreshTarget = true
}
return &Proxy{
address: address,
target: tcpAddr,
refreshTarget: refreshTarget,
terminationDelay: terminationDelay,
}, nil
} }
// ServeTCP forwards the connection to a service. // ServeTCP forwards the connection to a service.
@ -31,6 +44,15 @@ func (p *Proxy) ServeTCP(conn WriteCloser) {
// needed because of e.g. server.trackedConnection // needed because of e.g. server.trackedConnection
defer conn.Close() defer conn.Close()
if p.refreshTarget {
tcpAddr, err := net.ResolveTCPAddr("tcp", p.address)
if err != nil {
log.Errorf("Error resolving tcp address: %v", err)
return
}
p.target = tcpAddr
}
connBackend, err := net.DialTCP("tcp", nil, p.target) connBackend, err := net.DialTCP("tcp", nil, p.target)
if err != nil { if err != nil {
log.Errorf("Error while connection to backend: %v", err) log.Errorf("Error while connection to backend: %v", err)

View file

@ -5,9 +5,11 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"sync"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -79,3 +81,74 @@ func TestCloseWrite(t *testing.T) {
require.Equal(t, int64(4), n) require.Equal(t, int64(4), n)
require.Equal(t, "PONG", buffer.String()) require.Equal(t, "PONG", buffer.String())
} }
func TestLookupAddress(t *testing.T) {
testCases := []struct {
desc string
address string
expectSame assert.ComparisonAssertionFunc
}{
{
desc: "IP doesn't need refresh",
address: "8.8.4.4:53",
expectSame: assert.Same,
},
{
desc: "Hostname needs refresh",
address: "dns.google:53",
expectSame: assert.NotSame,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
proxy, err := NewProxy(test.address, 10*time.Millisecond)
require.NoError(t, err)
require.NotNil(t, proxy.target)
proxyListener, err := net.Listen("tcp", ":0")
require.NoError(t, err)
var wg sync.WaitGroup
go func(wg *sync.WaitGroup) {
for {
conn, err := proxyListener.Accept()
require.NoError(t, err)
proxy.ServeTCP(conn.(*net.TCPConn))
wg.Done()
}
}(&wg)
var lastTarget *net.TCPAddr
for i := 0; i < 3; i++ {
wg.Add(1)
conn, err := net.Dial("tcp", proxyListener.Addr().String())
require.NoError(t, err)
_, err = conn.Write([]byte("ping\n"))
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
wg.Wait()
assert.NotNil(t, proxy.target)
if lastTarget != nil {
test.expectSame(t, lastTarget, proxy.target)
}
lastTarget = proxy.target
}
})
}
}