diff --git a/pkg/udp/conn.go b/pkg/udp/conn.go index d8b38f0e9..5753a9c13 100644 --- a/pkg/udp/conn.go +++ b/pkg/udp/conn.go @@ -8,7 +8,8 @@ import ( "time" ) -const receiveMTU = 8192 +// maxDatagramSize is the maximum size of a UDP datagram. +const maxDatagramSize = 65535 const closeRetryInterval = 500 * time.Millisecond @@ -135,7 +136,8 @@ func (l *Listener) readLoop() { // Allocating a new buffer for every read avoids // overwriting data in c.msgs in case the next packet is received // before c.msgs is emptied via Read() - buf := make([]byte, receiveMTU) + buf := make([]byte, maxDatagramSize) + n, raddr, err := l.pConn.ReadFrom(buf) if err != nil { return @@ -144,6 +146,7 @@ func (l *Listener) readLoop() { if err != nil { continue } + select { case conn.receiveCh <- buf[:n]: case <-conn.doneCh: @@ -249,7 +252,9 @@ func (c *Conn) readLoop() { } } -// Read implements io.Reader for a Conn. +// Read reads up to len(p) bytes into p from the connection. +// Each call corresponds to at most one datagram. +// If p is smaller than the datagram, the extra bytes will be discarded. func (c *Conn) Read(p []byte) (int, error) { select { case c.readCh <- p: @@ -258,22 +263,21 @@ func (c *Conn) Read(p []byte) (int, error) { c.lastActivity = time.Now() c.muActivity.Unlock() return n, nil + case <-c.doneCh: return 0, io.EOF } } -// Write implements io.Writer for a Conn. +// Write writes len(p) bytes from p to the underlying connection. +// Each call sends at most one datagram. +// It is an error to send a message larger than the system's max UDP datagram size. func (c *Conn) Write(p []byte) (n int, err error) { - l := c.listener - if l == nil { - return 0, io.EOF - } - c.muActivity.Lock() c.lastActivity = time.Now() c.muActivity.Unlock() - return l.pConn.WriteTo(p, c.rAddr) + + return c.listener.pConn.WriteTo(p, c.rAddr) } func (c *Conn) close() { diff --git a/pkg/udp/conn_test.go b/pkg/udp/conn_test.go index dce924bc9..6351e3309 100644 --- a/pkg/udp/conn_test.go +++ b/pkg/udp/conn_test.go @@ -1,9 +1,11 @@ package udp import ( + "crypto/rand" "errors" "io" "net" + "runtime" "testing" "time" @@ -317,6 +319,61 @@ func TestShutdown(t *testing.T) { } } +func TestReadLoopMaxDataSize(t *testing.T) { + if runtime.GOOS == "darwin" { + // sudo sysctl -w net.inet.udp.maxdgram=65507 + t.Skip("Skip test on darwin as the maximum dgram size is set to 9216 bytes by default") + } + + // Theoretical maximum size of data in a UDP datagram. + // 65535 − 8 (UDP header) − 20 (IP header). + dataSize := 65507 + + doneCh := make(chan struct{}) + + addr, err := net.ResolveUDPAddr("udp", ":0") + require.NoError(t, err) + + l, err := Listen("udp", addr, 3*time.Second) + require.NoError(t, err) + + defer func() { + err := l.Close() + require.NoError(t, err) + }() + + go func() { + defer close(doneCh) + + conn, err := l.Accept() + require.NoError(t, err) + + buffer := make([]byte, dataSize) + + n, err := conn.Read(buffer) + require.NoError(t, err) + + assert.Equal(t, dataSize, n) + }() + + c, err := net.Dial("udp", l.Addr().String()) + require.NoError(t, err) + + data := make([]byte, dataSize) + + _, err = rand.Read(data) + require.NoError(t, err) + + _, err = c.Write(data) + require.NoError(t, err) + + select { + case <-doneCh: + case <-time.Tick(5 * time.Second): + t.Fatal("Timeout waiting for datagram read") + } +} + // requireEcho tests that the conn session is live and functional, // by writing data through it, and expecting the same data as a response when reading on it. // It fatals if the read blocks longer than timeout, diff --git a/pkg/udp/proxy.go b/pkg/udp/proxy.go index 7e822cbf6..a69aafd93 100644 --- a/pkg/udp/proxy.go +++ b/pkg/udp/proxy.go @@ -20,14 +20,14 @@ func NewProxy(address string) (*Proxy, error) { // ServeUDP implements the Handler interface. func (p *Proxy) ServeUDP(conn *Conn) { - log.Debugf("Handling connection from %s", conn.rAddr) + log.WithoutContext().Debugf("Handling connection from %s", conn.rAddr) // needed because of e.g. server.trackedConnection defer conn.Close() connBackend, err := net.Dial("udp", p.target) if err != nil { - log.Errorf("Error while connecting to backend: %v", err) + log.WithoutContext().Errorf("Error while connecting to backend: %v", err) return } @@ -35,8 +35,8 @@ func (p *Proxy) ServeUDP(conn *Conn) { defer connBackend.Close() errChan := make(chan error) - go p.connCopy(conn, connBackend, errChan) - go p.connCopy(connBackend, conn, errChan) + go connCopy(conn, connBackend, errChan) + go connCopy(connBackend, conn, errChan) err = <-errChan if err != nil { @@ -46,8 +46,12 @@ func (p *Proxy) ServeUDP(conn *Conn) { <-errChan } -func (p Proxy) connCopy(dst io.WriteCloser, src io.Reader, errCh chan error) { - _, err := io.Copy(dst, src) +func connCopy(dst io.WriteCloser, src io.Reader, errCh chan error) { + // The buffer is initialized to the maximum UDP datagram size, + // to make sure that the whole UDP datagram is read or written atomically (no data is discarded). + buffer := make([]byte, maxDatagramSize) + + _, err := io.CopyBuffer(dst, src, buffer) errCh <- err if err := dst.Close(); err != nil { diff --git a/pkg/udp/proxy_test.go b/pkg/udp/proxy_test.go index 120cbc457..b3ce2ec2c 100644 --- a/pkg/udp/proxy_test.go +++ b/pkg/udp/proxy_test.go @@ -1,7 +1,9 @@ package udp import ( + "crypto/rand" "net" + "runtime" "testing" "time" @@ -9,13 +11,14 @@ import ( "github.com/stretchr/testify/require" ) -func TestUDPProxy(t *testing.T) { +func TestProxy_ServeUDP(t *testing.T) { backendAddr := ":8081" - go newServer(t, ":8081", HandlerFunc(func(conn *Conn) { + go newServer(t, backendAddr, HandlerFunc(func(conn *Conn) { for { b := make([]byte, 1024*1024) n, err := conn.Read(b) require.NoError(t, err) + _, err = conn.Write(b[:n]) require.NoError(t, err) } @@ -28,6 +31,7 @@ func TestUDPProxy(t *testing.T) { go newServer(t, proxyAddr, proxy) time.Sleep(time.Second) + udpConn, err := net.Dial("udp", proxyAddr) require.NoError(t, err) @@ -37,9 +41,58 @@ func TestUDPProxy(t *testing.T) { b := make([]byte, 1024*1024) n, err := udpConn.Read(b) require.NoError(t, err) + assert.Equal(t, "DATAWRITE", string(b[:n])) } +func TestProxy_ServeUDP_MaxDataSize(t *testing.T) { + if runtime.GOOS == "darwin" { + // sudo sysctl -w net.inet.udp.maxdgram=65507 + t.Skip("Skip test on darwin as the maximum dgram size is set to 9216 bytes by default") + } + + // Theoretical maximum size of data in a UDP datagram. + // 65535 − 8 (UDP header) − 20 (IP header). + dataSize := 65507 + + backendAddr := ":8083" + go newServer(t, backendAddr, HandlerFunc(func(conn *Conn) { + buffer := make([]byte, dataSize) + + n, err := conn.Read(buffer) + require.NoError(t, err) + + _, err = conn.Write(buffer[:n]) + require.NoError(t, err) + })) + + proxy, err := NewProxy(backendAddr) + require.NoError(t, err) + + proxyAddr := ":8082" + go newServer(t, proxyAddr, proxy) + + time.Sleep(time.Second) + + udpConn, err := net.Dial("udp", proxyAddr) + require.NoError(t, err) + + want := make([]byte, dataSize) + + _, err = rand.Read(want) + require.NoError(t, err) + + _, err = udpConn.Write(want) + require.NoError(t, err) + + got := make([]byte, dataSize) + + _, err = udpConn.Read(got) + require.NoError(t, err) + + assert.Equal(t, want, got) +} + func newServer(t *testing.T, addr string, handler Handler) { t.Helper() @@ -52,6 +105,7 @@ func newServer(t *testing.T, addr string, handler Handler) { for { conn, err := listener.Accept() require.NoError(t, err) + go handler.ServeUDP(conn) } }