Handle shutdown of Hijacked connections

This commit is contained in:
SALLEYRON Julien 2018-07-19 17:30:06 +02:00 committed by Traefiker Bot
parent d50b6a34bc
commit c8ae97fd38
4 changed files with 128 additions and 17 deletions

2
Gopkg.lock generated
View file

@ -1266,7 +1266,7 @@
"roundrobin",
"utils"
]
revision = "f0cbb9d6b797d92d168b95b5c443a31dfa67ccd0"
revision = "a3ed5f65204f4ffccbb56d58cec466cdb7ab730b"
[[projects]]
name = "github.com/vulcand/predicate"

View file

@ -40,6 +40,59 @@ import (
var httpServerLogger = stdlog.New(log.WriterLevel(logrus.DebugLevel), "", 0)
func newHijackConnectionTracker() *hijackConnectionTracker {
return &hijackConnectionTracker{
conns: make(map[net.Conn]struct{}),
}
}
type hijackConnectionTracker struct {
conns map[net.Conn]struct{}
lock sync.RWMutex
}
// AddHijackedConnection add a connection in the tracked connections list
func (h *hijackConnectionTracker) AddHijackedConnection(conn net.Conn) {
h.lock.Lock()
defer h.lock.Unlock()
h.conns[conn] = struct{}{}
}
// RemoveHijackedConnection remove a connection from the tracked connections list
func (h *hijackConnectionTracker) RemoveHijackedConnection(conn net.Conn) {
h.lock.Lock()
defer h.lock.Unlock()
delete(h.conns, conn)
}
// Shutdown wait for the connection closing
func (h *hijackConnectionTracker) Shutdown(ctx context.Context) error {
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
for {
h.lock.RLock()
if len(h.conns) == 0 {
return nil
}
h.lock.RUnlock()
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
}
}
}
// Close close all the connections in the tracked connections list
func (h *hijackConnectionTracker) Close() {
for conn := range h.conns {
if err := conn.Close(); err != nil {
log.Errorf("Error while closing Hijacked conn: %v", err)
}
delete(h.conns, conn)
}
}
// Server is the reverse-proxy/load-balancer engine
type Server struct {
serverEntryPoints serverEntryPoints
@ -80,6 +133,35 @@ type serverEntryPoint struct {
certs *traefiktls.CertificateStore
onDemandListener func(string) (*tls.Certificate, error)
tlsALPNGetter func(string) (*tls.Certificate, error)
hijackConnectionTracker *hijackConnectionTracker
}
func (s serverEntryPoint) Shutdown(ctx context.Context) {
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
if err := s.httpServer.Shutdown(ctx); err != nil {
if ctx.Err() == context.DeadlineExceeded {
log.Debugf("Wait server shutdown is over due to: %s", err)
err = s.httpServer.Close()
if err != nil {
log.Error(err)
}
}
}
}()
wg.Add(1)
go func() {
defer wg.Done()
if err := s.hijackConnectionTracker.Shutdown(ctx); err != nil {
if ctx.Err() == context.DeadlineExceeded {
log.Debugf("Wait hijack connection is over due to: %s", err)
s.hijackConnectionTracker.Close()
}
}
}()
wg.Wait()
}
// NewServer returns an initialized Server.
@ -187,13 +269,7 @@ func (s *Server) Stop() {
graceTimeOut := time.Duration(s.globalConfiguration.LifeCycle.GraceTimeOut)
ctx, cancel := context.WithTimeout(context.Background(), graceTimeOut)
log.Debugf("Waiting %s seconds before killing connections on entrypoint %s...", graceTimeOut, serverEntryPointName)
if err := serverEntryPoint.httpServer.Shutdown(ctx); err != nil {
log.Debugf("Wait is over due to: %s", err)
err = serverEntryPoint.httpServer.Close()
if err != nil {
log.Error(err)
}
}
serverEntryPoint.Shutdown(ctx)
cancel()
log.Debugf("Entrypoint %s closed", serverEntryPointName)
}(sepn, sep)
@ -447,6 +523,16 @@ func (s *Server) setupServerEntryPoint(newServerEntryPointName string, newServer
serverEntryPoint.httpServer = newSrv
serverEntryPoint.listener = listener
serverEntryPoint.hijackConnectionTracker = newHijackConnectionTracker()
serverEntryPoint.httpServer.ConnState = func(conn net.Conn, state http.ConnState) {
switch state {
case http.StateHijacked:
serverEntryPoint.hijackConnectionTracker.AddHijackedConnection(conn)
case http.StateClosed:
serverEntryPoint.hijackConnectionTracker.RemoveHijackedConnection(conn)
}
}
return serverEntryPoint
}

View file

@ -4,6 +4,7 @@ import (
"crypto/tls"
"encoding/json"
"fmt"
"net"
"net/http"
"reflect"
"sort"
@ -245,6 +246,15 @@ func (s *Server) buildForwarder(entryPointName string, entryPoint *configuration
forward.Rewriter(rewriter),
forward.ResponseModifier(responseModifier),
forward.BufferPool(s.bufferPool),
forward.WebsocketConnectionClosedHook(func(req *http.Request, conn net.Conn) {
server := req.Context().Value(http.ServerContextKey).(*http.Server)
if server != nil {
connState := server.ConnState
if connState != nil {
connState(conn, http.StateClosed)
}
}
}),
)
if err != nil {
return nil, fmt.Errorf("error creating forwarder for frontend %s: %v", frontendName, err)

View file

@ -7,6 +7,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/http/httputil"
@ -126,6 +127,14 @@ func StateListener(stateListener UrlForwardingStateListener) optSetter {
}
}
// WebsocketConnectionClosedHook defines a hook called when websocket connection is closed
func WebsocketConnectionClosedHook(hook func(req *http.Request, conn net.Conn)) optSetter {
return func(f *Forwarder) error {
f.httpForwarder.websocketConnectionClosedHook = hook
return nil
}
}
// ResponseModifier defines a response modifier for the HTTP forwarder
func ResponseModifier(responseModifier func(*http.Response) error) optSetter {
return func(f *Forwarder) error {
@ -189,6 +198,7 @@ type httpForwarder struct {
log OxyLogger
bufferPool httputil.BufferPool
websocketConnectionClosedHook func(req *http.Request, conn net.Conn)
}
const defaultFlushInterval = time.Duration(100) * time.Millisecond
@ -374,8 +384,13 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request,
log.Errorf("vulcand/oxy/forward/websocket: Error while upgrading connection : %v", err)
return
}
defer underlyingConn.Close()
defer targetConn.Close()
defer func() {
underlyingConn.Close()
targetConn.Close()
if f.websocketConnectionClosedHook != nil {
f.websocketConnectionClosedHook(req, underlyingConn.UnderlyingConn())
}
}()
errClient := make(chan error, 1)
errBackend := make(chan error, 1)