191 lines
4.1 KiB
Go
191 lines
4.1 KiB
Go
// Package udp implements UDP test helpers. It lets you assert that certain
// strings must or must not be sent to a given local UDP listener.
|
|
package udp
|
|
|
|
import (
|
|
"net"
|
|
"runtime"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
var (
	// addr is the listen address supplied to SetAddr; it stays nil until
	// SetAddr is called.
	addr *string

	// listener is the UDP socket opened by start and closed by stop around
	// each assertion.
	listener *net.UDPConn

	// Timeout is how long each read waits for another datagram after the code
	// under test has run. Collection stops once a read waits this long without
	// receiving anything, so increase it if the code under test sends
	// asynchronously or runs on a slow machine.
	Timeout = time.Millisecond
)

// fn is the code under test. The helpers run it while the listener is open
// and then inspect what was received.
type fn func()
|
|
|
|
// SetAddr sets the UDP port that will be listened on.
|
|
func SetAddr(a string) {
|
|
addr = &a
|
|
}
|
|
|
|
func start(t *testing.T) {
|
|
resAddr, err := net.ResolveUDPAddr("udp", *addr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
listener, err = net.ListenUDP("udp", resAddr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func stop(t *testing.T) {
|
|
if err := listener.Close(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func getMessage(t *testing.T, body fn) string {
|
|
start(t)
|
|
defer stop(t)
|
|
|
|
body()
|
|
|
|
message := make([]byte, 1024*32)
|
|
var bufLen int
|
|
for {
|
|
listener.SetReadDeadline(time.Now().Add(Timeout))
|
|
n, _, _ := listener.ReadFrom(message[bufLen:])
|
|
if n == 0 {
|
|
break
|
|
} else {
|
|
bufLen += n
|
|
}
|
|
}
|
|
|
|
return string(message[0:bufLen])
|
|
}
|
|
|
|
func get(t *testing.T, match string, body fn) (got string, equals bool, contains bool) {
|
|
got = getMessage(t, body)
|
|
equals = got == match
|
|
contains = strings.Contains(got, match)
|
|
return got, equals, contains
|
|
}
|
|
|
|
// printLocation reports, as a test error, the file and line that called the
// exported assertion helper. It must be called directly from that helper,
// because the fixed stack depth skips exactly printLocation and the helper.
func printLocation(t *testing.T) {
	const callerDepth = 2 // 0 = printLocation, 1 = exported helper, 2 = its caller
	_, file, line, _ := runtime.Caller(callerDepth)
	t.Errorf("At: %s:%d", file, line)
}
|
|
|
|
// ShouldReceiveOnly will fire a test error if the given function doesn't send
|
|
// exactly the given string over UDP.
|
|
func ShouldReceiveOnly(t *testing.T, expected string, body fn) {
|
|
got, equals, _ := get(t, expected, body)
|
|
if !equals {
|
|
printLocation(t)
|
|
t.Errorf("Expected: %#v", expected)
|
|
t.Errorf("But got: %#v", got)
|
|
}
|
|
}
|
|
|
|
// ShouldNotReceiveOnly will fire a test error if the given function sends
|
|
// exactly the given string over UDP.
|
|
func ShouldNotReceiveOnly(t *testing.T, notExpected string, body fn) {
|
|
_, equals, _ := get(t, notExpected, body)
|
|
if equals {
|
|
printLocation(t)
|
|
t.Errorf("Expected not to get: %#v", notExpected)
|
|
}
|
|
}
|
|
|
|
// ShouldReceive will fire a test error if the given function doesn't send the
|
|
// given string over UDP.
|
|
func ShouldReceive(t *testing.T, expected string, body fn) {
|
|
got, _, contains := get(t, expected, body)
|
|
if !contains {
|
|
printLocation(t)
|
|
t.Errorf("Expected to find: %#v", expected)
|
|
t.Errorf("But got: %#v", got)
|
|
}
|
|
}
|
|
|
|
// ShouldNotReceive will fire a test error if the given function sends the
|
|
// given string over UDP.
|
|
func ShouldNotReceive(t *testing.T, expected string, body fn) {
|
|
got, _, contains := get(t, expected, body)
|
|
if contains {
|
|
printLocation(t)
|
|
t.Errorf("Expected not to find: %#v", expected)
|
|
t.Errorf("But got: %#v", got)
|
|
}
|
|
}
|
|
|
|
// ShouldReceiveAll will fire a test error unless all of the given strings are
|
|
// sent over UDP.
|
|
func ShouldReceiveAll(t *testing.T, expected []string, body fn) {
|
|
got := getMessage(t, body)
|
|
failed := false
|
|
|
|
for _, str := range expected {
|
|
if !strings.Contains(got, str) {
|
|
if !failed {
|
|
printLocation(t)
|
|
failed = true
|
|
}
|
|
t.Errorf("Expected to find: %#v", str)
|
|
}
|
|
}
|
|
|
|
if failed {
|
|
t.Errorf("But got: %#v", got)
|
|
}
|
|
}
|
|
|
|
// ShouldNotReceiveAny will fire a test error if any of the given strings are
|
|
// sent over UDP.
|
|
func ShouldNotReceiveAny(t *testing.T, unexpected []string, body fn) {
|
|
got := getMessage(t, body)
|
|
failed := false
|
|
|
|
for _, str := range unexpected {
|
|
if strings.Contains(got, str) {
|
|
if !failed {
|
|
printLocation(t)
|
|
failed = true
|
|
}
|
|
t.Errorf("Expected not to find: %#v", str)
|
|
}
|
|
}
|
|
|
|
if failed {
|
|
t.Errorf("But got: %#v", got)
|
|
}
|
|
}
|
|
|
|
func ShouldReceiveAllAndNotReceiveAny(t *testing.T, expected []string, unexpected []string, body fn) {
|
|
got := getMessage(t, body)
|
|
failed := false
|
|
|
|
for _, str := range expected {
|
|
if !strings.Contains(got, str) {
|
|
if !failed {
|
|
printLocation(t)
|
|
failed = true
|
|
}
|
|
t.Errorf("Expected to find: %#v", str)
|
|
}
|
|
}
|
|
for _, str := range unexpected {
|
|
if strings.Contains(got, str) {
|
|
if !failed {
|
|
printLocation(t)
|
|
failed = true
|
|
}
|
|
t.Errorf("Expected not to find: %#v", str)
|
|
}
|
|
}
|
|
|
|
if failed {
|
|
t.Errorf("but got: %#v", got)
|
|
}
|
|
}
|
|
|
|
func ReceiveString(t *testing.T, body fn) string {
|
|
return getMessage(t, body)
|
|
}
|