From 0d902671e509f52390163035e89c89f89296150e Mon Sep 17 00:00:00 2001 From: cbachert Date: Mon, 8 Jun 2020 17:12:04 +0100 Subject: [PATCH] Avoid overwriting already received UDP messages --- pkg/udp/conn.go | 6 +++-- pkg/udp/conn_test.go | 61 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 2 deletions(-) diff --git a/pkg/udp/conn.go b/pkg/udp/conn.go index 516212ec2..9aef31f2a 100644 --- a/pkg/udp/conn.go +++ b/pkg/udp/conn.go @@ -128,9 +128,11 @@ func (l *Listener) Shutdown(graceTimeout time.Duration) error { // we find that session, and otherwise we create a new one. // We then send the data the session's readLoop. func (l *Listener) readLoop() { - buf := make([]byte, receiveMTU) - for { + // Allocating a new buffer for every read avoids + // overwriting data in c.msgs in case the next packet is received + // before c.msgs is emptied via Read() + buf := make([]byte, receiveMTU) n, raddr, err := l.pConn.ReadFrom(buf) if err != nil { return diff --git a/pkg/udp/conn_test.go b/pkg/udp/conn_test.go index 91d3fe571..62c35d04a 100644 --- a/pkg/udp/conn_test.go +++ b/pkg/udp/conn_test.go @@ -10,6 +10,67 @@ import ( "github.com/stretchr/testify/require" ) +func TestConsecutiveWrites(t *testing.T) { + addr, err := net.ResolveUDPAddr("udp", ":0") + require.NoError(t, err) + + ln, err := Listen("udp", addr) + require.NoError(t, err) + defer func() { + err := ln.Close() + require.NoError(t, err) + }() + + go func() { + for { + conn, err := ln.Accept() + if err == errClosedListener { + return + } + require.NoError(t, err) + + go func() { + b := make([]byte, 2048) + b2 := make([]byte, 2048) + var n int + var n2 int + + n, err = conn.Read(b) + require.NoError(t, err) + // Wait to make sure that the second packet is received + time.Sleep(10 * time.Millisecond) + n2, err = conn.Read(b2) + require.NoError(t, err) + + _, err = conn.Write(b[:n]) + require.NoError(t, err) + _, err = conn.Write(b2[:n2]) + require.NoError(t, err) + }() + } + }() + + udpConn, err := net.Dial("udp", ln.Addr().String()) + require.NoError(t, err) + + // Send multiple packets of different content and length consecutively + // Read back packets afterwards and make sure that content matches + // This checks if any buffers are overwritten while the receiver is enqueuing multiple packets + b := make([]byte, 2048) + var n int + _, err = udpConn.Write([]byte("TESTLONG0")) + require.NoError(t, err) + _, err = udpConn.Write([]byte("1TEST")) + require.NoError(t, err) + + n, err = udpConn.Read(b) + require.NoError(t, err) + require.Equal(t, "TESTLONG0", string(b[:n])) + n, err = udpConn.Read(b) + require.NoError(t, err) + require.Equal(t, "1TEST", string(b[:n])) +} + func TestListenNotBlocking(t *testing.T) { addr, err := net.ResolveUDPAddr("udp", ":0")