82 lines
1.6 KiB
Go
82 lines
1.6 KiB
Go
|
package tcp
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"net"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"github.com/stretchr/testify/require"
|
||
|
)
|
||
|
|
||
|
// fakeRedis is a minimal stand-in for a Redis backend. It accepts
// connections on listener one at a time and, for each read whose data
// starts with "ping", waits 1ms and replies "PONG". A connection is served
// until a read or write on it fails (e.g. io.EOF once the peer half-closes
// its write side); it is then closed and the next connection is accepted.
// fakeRedis returns when Accept fails, i.e. once listener is closed.
//
// fakeRedis is meant to run in its own goroutine, so it deliberately does
// not report through t: t.FailNow (used by require) must only be called
// from the goroutine running the test, and using t after the test has
// completed panics. t is kept only so the signature stays unchanged.
func fakeRedis(t *testing.T, listener net.Listener) {
	buf := make([]byte, 64)
	for {
		conn, err := listener.Accept()
		if err != nil {
			// Normally this means the test closed the listener; stop serving.
			return
		}
		fmt.Println("Accept on server")

		for {
			n, readErr := conn.Read(buf)
			// A Read may return data together with an error, so answer the
			// ping before looking at readErr. Only the n bytes just read are
			// inspected, which makes reusing buf across reads safe.
			if bytes.HasPrefix(buf[:n], []byte("ping")) {
				time.Sleep(time.Millisecond)
				if _, err := conn.Write([]byte("PONG")); err != nil {
					break
				}
			}
			if readErr != nil {
				break
			}
		}
		conn.Close()
	}
}
|
||
|
|
||
|
func TestCloseWrite(t *testing.T) {
|
||
|
backendListener, err := net.Listen("tcp", ":0")
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
go fakeRedis(t, backendListener)
|
||
|
_, port, err := net.SplitHostPort(backendListener.Addr().String())
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
proxy, err := NewProxy(":"+port, 10*time.Millisecond)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
proxyListener, err := net.Listen("tcp", ":0")
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
go func() {
|
||
|
for {
|
||
|
conn, err := proxyListener.Accept()
|
||
|
require.NoError(t, err)
|
||
|
proxy.ServeTCP(conn.(*net.TCPConn))
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
_, port, err = net.SplitHostPort(proxyListener.Addr().String())
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
conn, err := net.Dial("tcp", ":"+port)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
_, err = conn.Write([]byte("ping\n"))
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
err = conn.(*net.TCPConn).CloseWrite()
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
var buf []byte
|
||
|
buffer := bytes.NewBuffer(buf)
|
||
|
n, err := io.Copy(buffer, conn)
|
||
|
require.NoError(t, err)
|
||
|
require.Equal(t, int64(4), n)
|
||
|
require.Equal(t, "PONG", buffer.String())
|
||
|
}
|