2019-09-13 15:46:04 +00:00
|
|
|
package tcp
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bytes"
|
|
|
|
"fmt"
|
|
|
|
"io"
|
|
|
|
"net"
|
|
|
|
"testing"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
"github.com/stretchr/testify/require"
|
|
|
|
)
|
|
|
|
|
|
|
|
// fakeRedis is a minimal stand-in for a Redis server used by the proxy tests.
//
// It accepts a single connection on listener and, for every read whose first
// four bytes are "ping", pauses for one millisecond and replies with "PONG".
// It returns, closing the connection, as soon as a read or a write on that
// connection fails.
//
// fakeRedis is started with `go`, so an Accept failure is reported with
// t.Errorf instead of require.NoError: the latter ends in t.FailNow, which
// must only be called from the goroutine running the test function.
func fakeRedis(t *testing.T, listener net.Listener) {
	conn, err := listener.Accept()
	fmt.Println("Accept on server")
	if err != nil {
		t.Errorf("accepting backend connection: %v", err)
		return
	}
	defer conn.Close()

	// One buffer is reused for every read; the n >= 4 guard below ensures
	// bytes left over from a previous read are never mistaken for a ping.
	buf := make([]byte, 64)
	for {
		n, readErr := conn.Read(buf)

		// A Read may return data together with an error, so answer any
		// received ping before acting on the error.
		if n >= 4 && string(buf[:4]) == "ping" {
			time.Sleep(1 * time.Millisecond)
			if _, err := conn.Write([]byte("PONG")); err != nil {
				return
			}
		}

		if readErr != nil {
			return
		}
	}
}
|
|
|
|
|
|
|
|
func TestCloseWrite(t *testing.T) {
|
|
|
|
backendListener, err := net.Listen("tcp", ":0")
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
go fakeRedis(t, backendListener)
|
|
|
|
_, port, err := net.SplitHostPort(backendListener.Addr().String())
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
proxy, err := NewProxy(":"+port, 10*time.Millisecond)
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
proxyListener, err := net.Listen("tcp", ":0")
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
go func() {
|
|
|
|
for {
|
|
|
|
conn, err := proxyListener.Accept()
|
|
|
|
require.NoError(t, err)
|
|
|
|
proxy.ServeTCP(conn.(*net.TCPConn))
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
|
|
|
|
_, port, err = net.SplitHostPort(proxyListener.Addr().String())
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
conn, err := net.Dial("tcp", ":"+port)
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
_, err = conn.Write([]byte("ping\n"))
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
err = conn.(*net.TCPConn).CloseWrite()
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
var buf []byte
|
|
|
|
buffer := bytes.NewBuffer(buf)
|
|
|
|
n, err := io.Copy(buffer, conn)
|
|
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, int64(4), n)
|
|
|
|
require.Equal(t, "PONG", buffer.String())
|
|
|
|
}
|