Avoid overwriting already received UDP messages

This commit is contained in:
cbachert 2020-06-08 17:12:04 +01:00 committed by GitHub
parent fb90a7889a
commit 0d902671e5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 65 additions and 2 deletions

View file

@ -128,9 +128,11 @@ func (l *Listener) Shutdown(graceTimeout time.Duration) error {
// we find that session, and otherwise we create a new one.
// We then send the data the session's readLoop.
func (l *Listener) readLoop() {
buf := make([]byte, receiveMTU)
for {
// Allocating a new buffer for every read avoids
// overwriting data in c.msgs in case the next packet is received
// before c.msgs is emptied via Read()
buf := make([]byte, receiveMTU)
n, raddr, err := l.pConn.ReadFrom(buf)
if err != nil {
return

View file

@ -10,6 +10,67 @@ import (
"github.com/stretchr/testify/require"
)
func TestConsecutiveWrites(t *testing.T) {
addr, err := net.ResolveUDPAddr("udp", ":0")
require.NoError(t, err)
ln, err := Listen("udp", addr)
require.NoError(t, err)
defer func() {
err := ln.Close()
require.NoError(t, err)
}()
go func() {
for {
conn, err := ln.Accept()
if err == errClosedListener {
return
}
require.NoError(t, err)
go func() {
b := make([]byte, 2048)
b2 := make([]byte, 2048)
var n int
var n2 int
n, err = conn.Read(b)
require.NoError(t, err)
// Wait to make sure that the second packet is received
time.Sleep(10 * time.Millisecond)
n2, err = conn.Read(b2)
require.NoError(t, err)
_, err = conn.Write(b[:n])
require.NoError(t, err)
_, err = conn.Write(b2[:n2])
require.NoError(t, err)
}()
}
}()
udpConn, err := net.Dial("udp", ln.Addr().String())
require.NoError(t, err)
// Send multiple packets of different content and length consecutively
// Read back packets afterwards and make sure that content matches
// This checks if any buffers are overwritten while the receiver is enqueuing multiple packets
b := make([]byte, 2048)
var n int
_, err = udpConn.Write([]byte("TESTLONG0"))
require.NoError(t, err)
_, err = udpConn.Write([]byte("1TEST"))
require.NoError(t, err)
n, err = udpConn.Read(b)
require.NoError(t, err)
require.Equal(t, "TESTLONG0", string(b[:n]))
n, err = udpConn.Read(b)
require.NoError(t, err)
require.Equal(t, "1TEST", string(b[:n]))
}
func TestListenNotBlocking(t *testing.T) {
addr, err := net.ResolveUDPAddr("udp", ":0")