Avoid overwriting already received UDP messages
This commit is contained in:
parent
fb90a7889a
commit
0d902671e5
2 changed files with 65 additions and 2 deletions
|
@ -128,9 +128,11 @@ func (l *Listener) Shutdown(graceTimeout time.Duration) error {
|
||||||
// we find that session, and otherwise we create a new one.
|
// we find that session, and otherwise we create a new one.
|
||||||
// We then send the data the session's readLoop.
|
// We then send the data the session's readLoop.
|
||||||
func (l *Listener) readLoop() {
|
func (l *Listener) readLoop() {
|
||||||
buf := make([]byte, receiveMTU)
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
// Allocating a new buffer for every read avoids
|
||||||
|
// overwriting data in c.msgs in case the next packet is received
|
||||||
|
// before c.msgs is emptied via Read()
|
||||||
|
buf := make([]byte, receiveMTU)
|
||||||
n, raddr, err := l.pConn.ReadFrom(buf)
|
n, raddr, err := l.pConn.ReadFrom(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
|
@ -10,6 +10,67 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestConsecutiveWrites(t *testing.T) {
|
||||||
|
addr, err := net.ResolveUDPAddr("udp", ":0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ln, err := Listen("udp", addr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
err := ln.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err == errClosedListener {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
b := make([]byte, 2048)
|
||||||
|
b2 := make([]byte, 2048)
|
||||||
|
var n int
|
||||||
|
var n2 int
|
||||||
|
|
||||||
|
n, err = conn.Read(b)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// Wait to make sure that the second packet is received
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
n2, err = conn.Read(b2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = conn.Write(b[:n])
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = conn.Write(b2[:n2])
|
||||||
|
require.NoError(t, err)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
udpConn, err := net.Dial("udp", ln.Addr().String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Send multiple packets of different content and length consecutively
|
||||||
|
// Read back packets afterwards and make sure that content matches
|
||||||
|
// This checks if any buffers are overwritten while the receiver is enqueuing multiple packets
|
||||||
|
b := make([]byte, 2048)
|
||||||
|
var n int
|
||||||
|
_, err = udpConn.Write([]byte("TESTLONG0"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = udpConn.Write([]byte("1TEST"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
n, err = udpConn.Read(b)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "TESTLONG0", string(b[:n]))
|
||||||
|
n, err = udpConn.Read(b)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "1TEST", string(b[:n]))
|
||||||
|
}
|
||||||
|
|
||||||
func TestListenNotBlocking(t *testing.T) {
|
func TestListenNotBlocking(t *testing.T) {
|
||||||
addr, err := net.ResolveUDPAddr("udp", ":0")
|
addr, err := net.ResolveUDPAddr("udp", ":0")
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue