diff --git a/Gopkg.lock b/Gopkg.lock index c2a3bbbc8..f6718f283 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -1266,7 +1266,7 @@ "roundrobin", "utils" ] - revision = "f0cbb9d6b797d92d168b95b5c443a31dfa67ccd0" + revision = "a3ed5f65204f4ffccbb56d58cec466cdb7ab730b" [[projects]] name = "github.com/vulcand/predicate" diff --git a/server/server.go b/server/server.go index 0869b1300..2739aad5a 100644 --- a/server/server.go +++ b/server/server.go @@ -40,6 +40,59 @@ import ( var httpServerLogger = stdlog.New(log.WriterLevel(logrus.DebugLevel), "", 0) +func newHijackConnectionTracker() *hijackConnectionTracker { + return &hijackConnectionTracker{ + conns: make(map[net.Conn]struct{}), + } +} + +type hijackConnectionTracker struct { + conns map[net.Conn]struct{} + lock sync.RWMutex +} + +// AddHijackedConnection add a connection in the tracked connections list +func (h *hijackConnectionTracker) AddHijackedConnection(conn net.Conn) { + h.lock.Lock() + defer h.lock.Unlock() + h.conns[conn] = struct{}{} +} + +// RemoveHijackedConnection remove a connection from the tracked connections list +func (h *hijackConnectionTracker) RemoveHijackedConnection(conn net.Conn) { + h.lock.Lock() + defer h.lock.Unlock() + delete(h.conns, conn) +} + +// Shutdown wait for the connection closing +func (h *hijackConnectionTracker) Shutdown(ctx context.Context) error { + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + for { + h.lock.RLock() + if len(h.conns) == 0 { + return nil + } + h.lock.RUnlock() + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + } + } +} + +// Close close all the connections in the tracked connections list +func (h *hijackConnectionTracker) Close() { + for conn := range h.conns { + if err := conn.Close(); err != nil { + log.Errorf("Error while closing Hijacked conn: %v", err) + } + delete(h.conns, conn) + } +} + // Server is the reverse-proxy/load-balancer engine type Server struct { serverEntryPoints serverEntryPoints @@ -74,12 +127,41 @@ type EntryPoint struct { type serverEntryPoints map[string]*serverEntryPoint type serverEntryPoint struct { - httpServer *h2c.Server - listener net.Listener - httpRouter *middlewares.HandlerSwitcher - certs *traefiktls.CertificateStore - onDemandListener func(string) (*tls.Certificate, error) - tlsALPNGetter func(string) (*tls.Certificate, error) + httpServer *h2c.Server + listener net.Listener + httpRouter *middlewares.HandlerSwitcher + certs *traefiktls.CertificateStore + onDemandListener func(string) (*tls.Certificate, error) + tlsALPNGetter func(string) (*tls.Certificate, error) + hijackConnectionTracker *hijackConnectionTracker +} + +func (s serverEntryPoint) Shutdown(ctx context.Context) { + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + if err := s.httpServer.Shutdown(ctx); err != nil { + if ctx.Err() == context.DeadlineExceeded { + log.Debugf("Wait server shutdown is over due to: %s", err) + err = s.httpServer.Close() + if err != nil { + log.Error(err) + } + } + } + }() + wg.Add(1) + go func() { + defer wg.Done() + if err := s.hijackConnectionTracker.Shutdown(ctx); err != nil { + if ctx.Err() == context.DeadlineExceeded { + log.Debugf("Wait hijack connection is over due to: %s", err) + s.hijackConnectionTracker.Close() + } + } + }() + wg.Wait() } // NewServer returns an initialized Server. @@ -187,13 +269,7 @@ func (s *Server) Stop() { graceTimeOut := time.Duration(s.globalConfiguration.LifeCycle.GraceTimeOut) ctx, cancel := context.WithTimeout(context.Background(), graceTimeOut) log.Debugf("Waiting %s seconds before killing connections on entrypoint %s...", graceTimeOut, serverEntryPointName) - if err := serverEntryPoint.httpServer.Shutdown(ctx); err != nil { - log.Debugf("Wait is over due to: %s", err) - err = serverEntryPoint.httpServer.Close() - if err != nil { - log.Error(err) - } - } + serverEntryPoint.Shutdown(ctx) cancel() log.Debugf("Entrypoint %s closed", serverEntryPointName) }(sepn, sep) @@ -447,6 +523,16 @@ func (s *Server) setupServerEntryPoint(newServerEntryPointName string, newServer serverEntryPoint.httpServer = newSrv serverEntryPoint.listener = listener + serverEntryPoint.hijackConnectionTracker = newHijackConnectionTracker() + serverEntryPoint.httpServer.ConnState = func(conn net.Conn, state http.ConnState) { + switch state { + case http.StateHijacked: + serverEntryPoint.hijackConnectionTracker.AddHijackedConnection(conn) + case http.StateClosed: + serverEntryPoint.hijackConnectionTracker.RemoveHijackedConnection(conn) + } + } + return serverEntryPoint } diff --git a/server/server_configuration.go b/server/server_configuration.go index 5b4030b65..41ef7ab2c 100644 --- a/server/server_configuration.go +++ b/server/server_configuration.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "encoding/json" "fmt" + "net" "net/http" "reflect" "sort" @@ -245,6 +246,15 @@ func (s *Server) buildForwarder(entryPointName string, entryPoint *configuration forward.Rewriter(rewriter), forward.ResponseModifier(responseModifier), forward.BufferPool(s.bufferPool), + forward.WebsocketConnectionClosedHook(func(req *http.Request, conn net.Conn) { + server := req.Context().Value(http.ServerContextKey).(*http.Server) + if server != nil { + connState := server.ConnState + if connState != nil { + connState(conn, http.StateClosed) + } + } + }), ) if err != nil { return nil, fmt.Errorf("error creating forwarder for frontend %s: %v", frontendName, err) diff --git a/vendor/github.com/vulcand/oxy/forward/fwd.go b/vendor/github.com/vulcand/oxy/forward/fwd.go index ec4bea59f..337d5eff5 100644 --- a/vendor/github.com/vulcand/oxy/forward/fwd.go +++ b/vendor/github.com/vulcand/oxy/forward/fwd.go @@ -7,6 +7,7 @@ import ( "crypto/tls" "errors" "fmt" + "net" "net/http" "net/http/httptest" "net/http/httputil" @@ -126,6 +127,14 @@ func StateListener(stateListener UrlForwardingStateListener) optSetter { } } +// WebsocketConnectionClosedHook defines a hook called when websocket connection is closed +func WebsocketConnectionClosedHook(hook func(req *http.Request, conn net.Conn)) optSetter { + return func(f *Forwarder) error { + f.httpForwarder.websocketConnectionClosedHook = hook + return nil + } +} + // ResponseModifier defines a response modifier for the HTTP forwarder func ResponseModifier(responseModifier func(*http.Response) error) optSetter { return func(f *Forwarder) error { @@ -188,7 +197,8 @@ type httpForwarder struct { log OxyLogger - bufferPool httputil.BufferPool + bufferPool httputil.BufferPool + websocketConnectionClosedHook func(req *http.Request, conn net.Conn) } const defaultFlushInterval = time.Duration(100) * time.Millisecond @@ -374,8 +384,13 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request, log.Errorf("vulcand/oxy/forward/websocket: Error while upgrading connection : %v", err) return } - defer underlyingConn.Close() - defer targetConn.Close() + defer func() { + underlyingConn.Close() + targetConn.Close() + if f.websocketConnectionClosedHook != nil { + f.websocketConnectionClosedHook(req, underlyingConn.UnderlyingConn()) + } + }() errClient := make(chan error, 1) errBackend := make(chan error, 1)