Handle respondingtimeout and better shutdown tests.
Co-authored-by: Mathieu Lonjaret <mathieu.lonjaret@gmail.com>
This commit is contained in:
parent
0837ec9b70
commit
807dc46ad0
4 changed files with 258 additions and 139 deletions
|
@ -172,7 +172,7 @@ func setupServer(staticConfiguration *static.Configuration) (*server.Server, err
|
||||||
|
|
||||||
acmeProviders := initACMEProvider(staticConfiguration, &providerAggregator, tlsManager)
|
acmeProviders := initACMEProvider(staticConfiguration, &providerAggregator, tlsManager)
|
||||||
|
|
||||||
serverEntryPointsTCP, err := server.NewTCPEntryPoints(*staticConfiguration)
|
serverEntryPointsTCP, err := server.NewTCPEntryPoints(staticConfiguration.EntryPoints)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -52,9 +52,9 @@ func (h *httpForwarder) Accept() (net.Conn, error) {
|
||||||
type TCPEntryPoints map[string]*TCPEntryPoint
|
type TCPEntryPoints map[string]*TCPEntryPoint
|
||||||
|
|
||||||
// NewTCPEntryPoints creates a new TCPEntryPoints.
|
// NewTCPEntryPoints creates a new TCPEntryPoints.
|
||||||
func NewTCPEntryPoints(staticConfiguration static.Configuration) (TCPEntryPoints, error) {
|
func NewTCPEntryPoints(entryPointsConfig static.EntryPoints) (TCPEntryPoints, error) {
|
||||||
serverEntryPointsTCP := make(TCPEntryPoints)
|
serverEntryPointsTCP := make(TCPEntryPoints)
|
||||||
for entryPointName, config := range staticConfiguration.EntryPoints {
|
for entryPointName, config := range entryPointsConfig {
|
||||||
ctx := log.With(context.Background(), log.Str(log.EntryPointName, entryPointName))
|
ctx := log.With(context.Background(), log.Str(log.EntryPointName, entryPointName))
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
@ -171,6 +171,23 @@ func (e *TCPEntryPoint) StartTCP(ctx context.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
safe.Go(func() {
|
safe.Go(func() {
|
||||||
|
// Enforce read/write deadlines at the connection level,
|
||||||
|
// because when we're peeking the first byte to determine whether we are doing TLS,
|
||||||
|
// the deadlines at the server level are not taken into account.
|
||||||
|
if e.transportConfiguration.RespondingTimeouts.ReadTimeout > 0 {
|
||||||
|
err := writeCloser.SetReadDeadline(time.Now().Add(time.Duration(e.transportConfiguration.RespondingTimeouts.ReadTimeout)))
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Error while setting read deadline: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.transportConfiguration.RespondingTimeouts.WriteTimeout > 0 {
|
||||||
|
err = writeCloser.SetWriteDeadline(time.Now().Add(time.Duration(e.transportConfiguration.RespondingTimeouts.WriteTimeout)))
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Error while setting write deadline: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
e.switcher.ServeTCP(newTrackedConnection(writeCloser, e.tracker))
|
e.switcher.ServeTCP(newTrackedConnection(writeCloser, e.tracker))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -191,48 +208,48 @@ func (e *TCPEntryPoint) Shutdown(ctx context.Context) {
|
||||||
logger.Debugf("Waiting %s seconds before killing connections.", graceTimeOut)
|
logger.Debugf("Waiting %s seconds before killing connections.", graceTimeOut)
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
shutdownServer := func(server stoppableServer) {
|
||||||
|
defer wg.Done()
|
||||||
|
err := server.Shutdown(ctx)
|
||||||
|
if err == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if ctx.Err() == context.DeadlineExceeded {
|
||||||
|
logger.Debugf("Server failed to shutdown within deadline because: %s", err)
|
||||||
|
if err = server.Close(); err != nil {
|
||||||
|
logger.Error(err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.Error(err)
|
||||||
|
// We expect Close to fail again because Shutdown most likely failed when trying to close a listener.
|
||||||
|
// We still call it however, to make sure that all connections get closed as well.
|
||||||
|
server.Close()
|
||||||
|
}
|
||||||
|
|
||||||
if e.httpServer.Server != nil {
|
if e.httpServer.Server != nil {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go shutdownServer(e.httpServer.Server)
|
||||||
defer wg.Done()
|
|
||||||
if err := e.httpServer.Server.Shutdown(ctx); err != nil {
|
|
||||||
if ctx.Err() == context.DeadlineExceeded {
|
|
||||||
logger.Debugf("Wait server shutdown is overdue to: %s", err)
|
|
||||||
err = e.httpServer.Server.Close()
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.httpsServer.Server != nil {
|
if e.httpsServer.Server != nil {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go shutdownServer(e.httpsServer.Server)
|
||||||
defer wg.Done()
|
|
||||||
if err := e.httpsServer.Server.Shutdown(ctx); err != nil {
|
|
||||||
if ctx.Err() == context.DeadlineExceeded {
|
|
||||||
logger.Debugf("Wait server shutdown is overdue to: %s", err)
|
|
||||||
err = e.httpsServer.Server.Close()
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.tracker != nil {
|
if e.tracker != nil {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
if err := e.tracker.Shutdown(ctx); err != nil {
|
err := e.tracker.Shutdown(ctx)
|
||||||
if ctx.Err() == context.DeadlineExceeded {
|
if err == nil {
|
||||||
logger.Debugf("Wait hijack connection is overdue to: %s", err)
|
return
|
||||||
e.tracker.Close()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
if ctx.Err() == context.DeadlineExceeded {
|
||||||
|
logger.Debugf("Server failed to shutdown before deadline because: %s", err)
|
||||||
|
}
|
||||||
|
e.tracker.Close()
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -459,8 +476,11 @@ func createHTTPServer(ctx context.Context, ln net.Listener, configuration *stati
|
||||||
}
|
}
|
||||||
|
|
||||||
serverHTTP := &http.Server{
|
serverHTTP := &http.Server{
|
||||||
Handler: handler,
|
Handler: handler,
|
||||||
ErrorLog: httpServerLogger,
|
ErrorLog: httpServerLogger,
|
||||||
|
ReadTimeout: time.Duration(configuration.Transport.RespondingTimeouts.ReadTimeout),
|
||||||
|
WriteTimeout: time.Duration(configuration.Transport.RespondingTimeouts.WriteTimeout),
|
||||||
|
IdleTimeout: time.Duration(configuration.Transport.RespondingTimeouts.IdleTimeout),
|
||||||
}
|
}
|
||||||
|
|
||||||
listener := newHTTPForwarder(ln)
|
listener := newHTTPForwarder(ln)
|
||||||
|
|
|
@ -3,8 +3,11 @@ package server
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -15,128 +18,206 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestShutdownHTTP(t *testing.T) {
|
func TestShutdownHijacked(t *testing.T) {
|
||||||
entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{
|
|
||||||
Address: ":0",
|
|
||||||
Transport: &static.EntryPointsTransport{
|
|
||||||
LifeCycle: &static.LifeCycle{
|
|
||||||
RequestAcceptGraceTimeout: 0,
|
|
||||||
GraceTimeOut: types.Duration(5 * time.Second),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
ForwardedHeaders: &static.ForwardedHeaders{},
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
go entryPoint.StartTCP(context.Background())
|
|
||||||
|
|
||||||
router := &tcp.Router{}
|
|
||||||
router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
rw.WriteHeader(http.StatusOK)
|
|
||||||
}))
|
|
||||||
entryPoint.SwitchRouter(router)
|
|
||||||
|
|
||||||
conn, err := net.Dial("tcp", entryPoint.listener.Addr().String())
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
go entryPoint.Shutdown(context.Background())
|
|
||||||
|
|
||||||
request, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8082", nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
err = request.Write(conn)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
resp, err := http.ReadResponse(bufio.NewReader(conn), request)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, resp.StatusCode, http.StatusOK)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestShutdownHTTPHijacked(t *testing.T) {
|
|
||||||
entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{
|
|
||||||
Address: ":0",
|
|
||||||
Transport: &static.EntryPointsTransport{
|
|
||||||
LifeCycle: &static.LifeCycle{
|
|
||||||
RequestAcceptGraceTimeout: 0,
|
|
||||||
GraceTimeOut: types.Duration(5 * time.Second),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
ForwardedHeaders: &static.ForwardedHeaders{},
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
go entryPoint.StartTCP(context.Background())
|
|
||||||
|
|
||||||
router := &tcp.Router{}
|
router := &tcp.Router{}
|
||||||
router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
conn, _, err := rw.(http.Hijacker).Hijack()
|
conn, _, err := rw.(http.Hijacker).Hijack()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
|
|
||||||
resp := http.Response{StatusCode: http.StatusOK}
|
resp := http.Response{StatusCode: http.StatusOK}
|
||||||
err = resp.Write(conn)
|
err = resp.Write(conn)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}))
|
}))
|
||||||
|
testShutdown(t, router)
|
||||||
entryPoint.SwitchRouter(router)
|
|
||||||
|
|
||||||
conn, err := net.Dial("tcp", entryPoint.listener.Addr().String())
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
go entryPoint.Shutdown(context.Background())
|
|
||||||
|
|
||||||
request, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8082", nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
err = request.Write(conn)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
resp, err := http.ReadResponse(bufio.NewReader(conn), request)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, resp.StatusCode, http.StatusOK)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShutdownTCPConn(t *testing.T) {
|
func TestShutdownHTTP(t *testing.T) {
|
||||||
|
router := &tcp.Router{}
|
||||||
|
router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
rw.WriteHeader(http.StatusOK)
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
}))
|
||||||
|
testShutdown(t, router)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShutdownTCP(t *testing.T) {
|
||||||
|
router := &tcp.Router{}
|
||||||
|
router.AddCatchAllNoTLS(tcp.HandlerFunc(func(conn tcp.WriteCloser) {
|
||||||
|
for {
|
||||||
|
_, err := http.ReadRequest(bufio.NewReader(conn))
|
||||||
|
|
||||||
|
if err == io.EOF || (err != nil && strings.HasSuffix(err.Error(), "use of closed network connection")) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
resp := http.Response{StatusCode: http.StatusOK}
|
||||||
|
err = resp.Write(conn)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
testShutdown(t, router)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testShutdown(t *testing.T, router *tcp.Router) {
|
||||||
|
epConfig := &static.EntryPointsTransport{}
|
||||||
|
epConfig.SetDefaults()
|
||||||
|
|
||||||
|
epConfig.LifeCycle.RequestAcceptGraceTimeout = 0
|
||||||
|
epConfig.LifeCycle.GraceTimeOut = types.Duration(5 * time.Second)
|
||||||
|
|
||||||
entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{
|
entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{
|
||||||
Address: ":0",
|
// We explicitly use an IPV4 address because on Alpine, with an IPV6 address
|
||||||
Transport: &static.EntryPointsTransport{
|
// there seems to be shenanigans related to properly cleaning up file descriptors
|
||||||
LifeCycle: &static.LifeCycle{
|
Address: "127.0.0.1:0",
|
||||||
RequestAcceptGraceTimeout: 0,
|
Transport: epConfig,
|
||||||
GraceTimeOut: types.Duration(5 * time.Second),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
ForwardedHeaders: &static.ForwardedHeaders{},
|
ForwardedHeaders: &static.ForwardedHeaders{},
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
go entryPoint.StartTCP(context.Background())
|
conn, err := startEntrypoint(entryPoint, router)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
router := &tcp.Router{}
|
epAddr := entryPoint.listener.Addr().String()
|
||||||
router.AddCatchAllNoTLS(tcp.HandlerFunc(func(conn tcp.WriteCloser) {
|
|
||||||
_, err := http.ReadRequest(bufio.NewReader(conn))
|
|
||||||
require.NoError(t, err)
|
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
|
|
||||||
resp := http.Response{StatusCode: http.StatusOK}
|
request, err := http.NewRequest(http.MethodHead, "http://127.0.0.1:8082", nil)
|
||||||
err = resp.Write(conn)
|
require.NoError(t, err)
|
||||||
require.NoError(t, err)
|
|
||||||
}))
|
|
||||||
|
|
||||||
entryPoint.SwitchRouter(router)
|
time.Sleep(time.Millisecond * 100)
|
||||||
|
|
||||||
conn, err := net.Dial("tcp", entryPoint.listener.Addr().String())
|
// We need to do a write on the conn before the shutdown to make it "exist".
|
||||||
|
// Because the connection indeed exists as far as TCP is concerned,
|
||||||
|
// but since we only pass it along to the HTTP server after at least one byte is peaked,
|
||||||
|
// the HTTP server (and hence its shutdown) does not know about the connection until that first byte peaking.
|
||||||
|
err = request.Write(conn)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
go entryPoint.Shutdown(context.Background())
|
go entryPoint.Shutdown(context.Background())
|
||||||
|
|
||||||
request, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8082", nil)
|
// Make sure that new connections are not permitted anymore.
|
||||||
require.NoError(t, err)
|
// Note that this should be true not only after Shutdown has returned,
|
||||||
|
// but technically also as early as the Shutdown has closed the listener,
|
||||||
|
// i.e. during the shutdown and before the gracetime is over.
|
||||||
|
var testOk bool
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
loopConn, err := net.Dial("tcp", epAddr)
|
||||||
|
if err == nil {
|
||||||
|
loopConn.Close()
|
||||||
|
time.Sleep(time.Millisecond * 100)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !strings.HasSuffix(err.Error(), "connection refused") && !strings.HasSuffix(err.Error(), "reset by peer") {
|
||||||
|
t.Fatalf(`unexpected error: got %v, wanted "connection refused" or "reset by peer"`, err)
|
||||||
|
}
|
||||||
|
testOk = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if !testOk {
|
||||||
|
t.Fatal("entry point never closed")
|
||||||
|
}
|
||||||
|
|
||||||
err = request.Write(conn)
|
// And make sure that the connection we had opened before shutting things down is still operational
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
resp, err := http.ReadResponse(bufio.NewReader(conn), request)
|
resp, err := http.ReadResponse(bufio.NewReader(conn), request)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, resp.StatusCode, http.StatusOK)
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func startEntrypoint(entryPoint *TCPEntryPoint, router *tcp.Router) (net.Conn, error) {
|
||||||
|
go entryPoint.StartTCP(context.Background())
|
||||||
|
|
||||||
|
entryPoint.SwitchRouter(router)
|
||||||
|
|
||||||
|
var conn net.Conn
|
||||||
|
var err error
|
||||||
|
var epStarted bool
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
conn, err = net.Dial("tcp", entryPoint.listener.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
time.Sleep(time.Millisecond * 100)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
epStarted = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if !epStarted {
|
||||||
|
return nil, errors.New("entry point never started")
|
||||||
|
}
|
||||||
|
return conn, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadTimeoutWithoutFirstByte(t *testing.T) {
|
||||||
|
epConfig := &static.EntryPointsTransport{}
|
||||||
|
epConfig.SetDefaults()
|
||||||
|
epConfig.RespondingTimeouts.ReadTimeout = types.Duration(time.Second * 2)
|
||||||
|
|
||||||
|
entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{
|
||||||
|
Address: ":0",
|
||||||
|
Transport: epConfig,
|
||||||
|
ForwardedHeaders: &static.ForwardedHeaders{},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
router := &tcp.Router{}
|
||||||
|
router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
rw.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
conn, err := startEntrypoint(entryPoint, router)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
errChan := make(chan error)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
b := make([]byte, 2048)
|
||||||
|
_, err := conn.Read(b)
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errChan:
|
||||||
|
require.Equal(t, io.EOF, err)
|
||||||
|
case <-time.Tick(time.Second * 5):
|
||||||
|
t.Error("Timeout while read")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadTimeoutWithFirstByte(t *testing.T) {
|
||||||
|
epConfig := &static.EntryPointsTransport{}
|
||||||
|
epConfig.SetDefaults()
|
||||||
|
epConfig.RespondingTimeouts.ReadTimeout = types.Duration(time.Second * 2)
|
||||||
|
|
||||||
|
entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{
|
||||||
|
Address: ":0",
|
||||||
|
Transport: epConfig,
|
||||||
|
ForwardedHeaders: &static.ForwardedHeaders{},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
router := &tcp.Router{}
|
||||||
|
router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
rw.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
conn, err := startEntrypoint(entryPoint, router)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = conn.Write([]byte("GET /some HTTP/1.1\r\n"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
errChan := make(chan error)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
b := make([]byte, 2048)
|
||||||
|
_, err := conn.Read(b)
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errChan:
|
||||||
|
require.Equal(t, io.EOF, err)
|
||||||
|
case <-time.Tick(time.Second * 5):
|
||||||
|
t.Error("Timeout while read")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/containous/traefik/v2/pkg/log"
|
"github.com/containous/traefik/v2/pkg/log"
|
||||||
)
|
)
|
||||||
|
@ -34,7 +35,23 @@ func (r *Router) ServeTCP(conn WriteCloser) {
|
||||||
}
|
}
|
||||||
|
|
||||||
br := bufio.NewReader(conn)
|
br := bufio.NewReader(conn)
|
||||||
serverName, tls, peeked := clientHelloServerName(br)
|
serverName, tls, peeked, err := clientHelloServerName(br)
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove read/write deadline and delegate this to underlying tcp server (for now only handled by HTTP Server)
|
||||||
|
err = conn.SetReadDeadline(time.Time{})
|
||||||
|
if err != nil {
|
||||||
|
log.WithoutContext().Errorf("Error while setting read deadline: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = conn.SetWriteDeadline(time.Time{})
|
||||||
|
if err != nil {
|
||||||
|
log.WithoutContext().Errorf("Error while setting write deadline: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if !tls {
|
if !tls {
|
||||||
switch {
|
switch {
|
||||||
case r.catchAllNoTLS != nil:
|
case r.catchAllNoTLS != nil:
|
||||||
|
@ -176,33 +193,34 @@ func (c *Conn) Read(p []byte) (n int, err error) {
|
||||||
// clientHelloServerName returns the SNI server name inside the TLS ClientHello,
|
// clientHelloServerName returns the SNI server name inside the TLS ClientHello,
|
||||||
// without consuming any bytes from br.
|
// without consuming any bytes from br.
|
||||||
// On any error, the empty string is returned.
|
// On any error, the empty string is returned.
|
||||||
func clientHelloServerName(br *bufio.Reader) (string, bool, string) {
|
func clientHelloServerName(br *bufio.Reader) (string, bool, string, error) {
|
||||||
hdr, err := br.Peek(1)
|
hdr, err := br.Peek(1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != io.EOF {
|
opErr, ok := err.(*net.OpError)
|
||||||
log.Errorf("Error while Peeking first byte: %s", err)
|
if err != io.EOF && (!ok || !opErr.Timeout()) {
|
||||||
|
log.WithoutContext().Errorf("Error while Peeking first byte: %s", err)
|
||||||
}
|
}
|
||||||
return "", false, ""
|
return "", false, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
const recordTypeHandshake = 0x16
|
const recordTypeHandshake = 0x16
|
||||||
if hdr[0] != recordTypeHandshake {
|
if hdr[0] != recordTypeHandshake {
|
||||||
// log.Errorf("Error not tls")
|
// log.Errorf("Error not tls")
|
||||||
return "", false, getPeeked(br) // Not TLS.
|
return "", false, getPeeked(br), nil // Not TLS.
|
||||||
}
|
}
|
||||||
|
|
||||||
const recordHeaderLen = 5
|
const recordHeaderLen = 5
|
||||||
hdr, err = br.Peek(recordHeaderLen)
|
hdr, err = br.Peek(recordHeaderLen)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Error while Peeking hello: %s", err)
|
log.Errorf("Error while Peeking hello: %s", err)
|
||||||
return "", false, getPeeked(br)
|
return "", false, getPeeked(br), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
recLen := int(hdr[3])<<8 | int(hdr[4]) // ignoring version in hdr[1:3]
|
recLen := int(hdr[3])<<8 | int(hdr[4]) // ignoring version in hdr[1:3]
|
||||||
helloBytes, err := br.Peek(recordHeaderLen + recLen)
|
helloBytes, err := br.Peek(recordHeaderLen + recLen)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Error while Hello: %s", err)
|
log.Errorf("Error while Hello: %s", err)
|
||||||
return "", true, getPeeked(br)
|
return "", true, getPeeked(br), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
sni := ""
|
sni := ""
|
||||||
|
@ -214,7 +232,7 @@ func clientHelloServerName(br *bufio.Reader) (string, bool, string) {
|
||||||
})
|
})
|
||||||
_ = server.Handshake()
|
_ = server.Handshake()
|
||||||
|
|
||||||
return sni, true, getPeeked(br)
|
return sni, true, getPeeked(br), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPeeked(br *bufio.Reader) string {
|
func getPeeked(br *bufio.Reader) string {
|
||||||
|
|
Loading…
Reference in a new issue