From a90b2a672e88a643e0bad2aa1f90e4f363289641 Mon Sep 17 00:00:00 2001 From: Julien Salleyron Date: Thu, 21 Jan 2021 10:04:04 +0100 Subject: [PATCH] perf: improve forwarded header and recovery middlewares Co-authored-by: Ludovic Fernandez --- .../forwardedheaders/forwarded_header.go | 66 +++++++++++++------ pkg/middlewares/recovery/recovery.go | 21 +++--- pkg/middlewares/recovery/recovery_test.go | 2 +- pkg/server/router/router.go | 6 +- 4 files changed, 59 insertions(+), 36 deletions(-) diff --git a/pkg/middlewares/forwardedheaders/forwarded_header.go b/pkg/middlewares/forwardedheaders/forwarded_header.go index 5917a5c69..5d2355dfd 100644 --- a/pkg/middlewares/forwardedheaders/forwarded_header.go +++ b/pkg/middlewares/forwardedheaders/forwarded_header.go @@ -84,19 +84,28 @@ func (x *XForwarded) isTrustedIP(ip string) bool { // removeIPv6Zone removes the zone if the given IP is an ipv6 address and it has {zone} information in it, // like "[fe80::d806:a55d:eb1b:49cc%vEthernet (vmxnet3 Ethernet Adapter - Virtual Switch)]:64692". func removeIPv6Zone(clientIP string) string { - return strings.Split(clientIP, "%")[0] + if idx := strings.Index(clientIP, "%"); idx != -1 { + return clientIP[:idx] + } + return clientIP } // isWebsocketRequest returns whether the specified HTTP request is a websocket handshake request. func isWebsocketRequest(req *http.Request) bool { containsHeader := func(name, value string) bool { - items := strings.Split(req.Header.Get(name), ",") - for _, item := range items { - if value == strings.ToLower(strings.TrimSpace(item)) { + h := unsafeHeader(req.Header).Get(name) + for { + pos := strings.Index(h, ",") + if pos == -1 { + return strings.EqualFold(value, strings.TrimSpace(h)) + } + + if strings.EqualFold(value, strings.TrimSpace(h[:pos])) { return true } + + h = h[pos:] } - return false } return containsHeader(connection, "upgrade") && containsHeader(upgrade, "websocket") } @@ -110,7 +119,7 @@ func forwardedPort(req *http.Request) string { return port } - if req.Header.Get(xForwardedProto) == "https" || req.Header.Get(xForwardedProto) == "wss" { + if unsafeHeader(req.Header).Get(xForwardedProto) == "https" || unsafeHeader(req.Header).Get(xForwardedProto) == "wss" { return "443" } @@ -125,38 +134,38 @@ func (x *XForwarded) rewrite(outreq *http.Request) { if clientIP, _, err := net.SplitHostPort(outreq.RemoteAddr); err == nil { clientIP = removeIPv6Zone(clientIP) - if outreq.Header.Get(xRealIP) == "" { - outreq.Header.Set(xRealIP, clientIP) + if unsafeHeader(outreq.Header).Get(xRealIP) == "" { + unsafeHeader(outreq.Header).Set(xRealIP, clientIP) } } - xfProto := outreq.Header.Get(xForwardedProto) + xfProto := unsafeHeader(outreq.Header).Get(xForwardedProto) if xfProto == "" { if isWebsocketRequest(outreq) { if outreq.TLS != nil { - outreq.Header.Set(xForwardedProto, "wss") + unsafeHeader(outreq.Header).Set(xForwardedProto, "wss") } else { - outreq.Header.Set(xForwardedProto, "ws") + unsafeHeader(outreq.Header).Set(xForwardedProto, "ws") } } else { if outreq.TLS != nil { - outreq.Header.Set(xForwardedProto, "https") + unsafeHeader(outreq.Header).Set(xForwardedProto, "https") } else { - outreq.Header.Set(xForwardedProto, "http") + unsafeHeader(outreq.Header).Set(xForwardedProto, "http") } } } - if xfPort := outreq.Header.Get(xForwardedPort); xfPort == "" { - outreq.Header.Set(xForwardedPort, forwardedPort(outreq)) + if xfPort := unsafeHeader(outreq.Header).Get(xForwardedPort); xfPort == "" { + unsafeHeader(outreq.Header).Set(xForwardedPort, forwardedPort(outreq)) } - if xfHost := outreq.Header.Get(xForwardedHost); xfHost == "" && outreq.Host != "" { - outreq.Header.Set(xForwardedHost, outreq.Host) + if xfHost := unsafeHeader(outreq.Header).Get(xForwardedHost); xfHost == "" && outreq.Host != "" { + unsafeHeader(outreq.Header).Set(xForwardedHost, outreq.Host) } if x.hostname != "" { - outreq.Header.Set(xForwardedServer, x.hostname) + unsafeHeader(outreq.Header).Set(xForwardedServer, x.hostname) } } @@ -164,7 +173,7 @@ func (x *XForwarded) rewrite(outreq *http.Request) { func (x *XForwarded) ServeHTTP(w http.ResponseWriter, r *http.Request) { if !x.insecure && !x.isTrustedIP(r.RemoteAddr) { for _, h := range xHeaders { - r.Header.Del(h) + unsafeHeader(r.Header).Del(h) } } @@ -172,3 +181,22 @@ func (x *XForwarded) ServeHTTP(w http.ResponseWriter, r *http.Request) { x.next.ServeHTTP(w, r) } + +// unsafeHeader allows to manage Header values. +// Must be used only when the header name is already a canonical key. +type unsafeHeader map[string][]string + +func (h unsafeHeader) Set(key, value string) { + h[key] = []string{value} +} + +func (h unsafeHeader) Get(key string) string { + if len(h[key]) == 0 { + return "" + } + return h[key][0] +} + +func (h unsafeHeader) Del(key string) { + delete(h, key) +} diff --git a/pkg/middlewares/recovery/recovery.go b/pkg/middlewares/recovery/recovery.go index 7040c91c2..753a5801b 100644 --- a/pkg/middlewares/recovery/recovery.go +++ b/pkg/middlewares/recovery/recovery.go @@ -10,42 +10,41 @@ import ( ) const ( - typeName = "Recovery" + typeName = "Recovery" + middlewareName = "traefik-internal-recovery" ) type recovery struct { next http.Handler - name string } // New creates recovery middleware. -func New(ctx context.Context, next http.Handler, name string) (http.Handler, error) { - log.FromContext(middlewares.GetLoggerCtx(ctx, name, typeName)).Debug("Creating middleware") +func New(ctx context.Context, next http.Handler) (http.Handler, error) { + log.FromContext(middlewares.GetLoggerCtx(ctx, middlewareName, typeName)).Debug("Creating middleware") return &recovery{ next: next, - name: name, }, nil } func (re *recovery) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - defer recoverFunc(middlewares.GetLoggerCtx(req.Context(), re.name, typeName), rw, req) + defer recoverFunc(rw, req) re.next.ServeHTTP(rw, req) } -func recoverFunc(ctx context.Context, rw http.ResponseWriter, r *http.Request) { +func recoverFunc(rw http.ResponseWriter, r *http.Request) { if err := recover(); err != nil { + logger := log.FromContext(middlewares.GetLoggerCtx(r.Context(), middlewareName, typeName)) if !shouldLogPanic(err) { - log.FromContext(ctx).Debugf("Request has been aborted [%s - %s]: %v", r.RemoteAddr, r.URL, err) + logger.Debugf("Request has been aborted [%s - %s]: %v", r.RemoteAddr, r.URL, err) return } - log.FromContext(ctx).Errorf("Recovered from panic in HTTP handler [%s - %s]: %+v", r.RemoteAddr, r.URL, err) - + logger.Errorf("Recovered from panic in HTTP handler [%s - %s]: %+v", r.RemoteAddr, r.URL, err) const size = 64 << 10 buf := make([]byte, size) buf = buf[:runtime.Stack(buf, false)] - log.FromContext(ctx).Errorf("Stack: %s", buf) + logger.Errorf("Stack: %s", buf) http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) } diff --git a/pkg/middlewares/recovery/recovery_test.go b/pkg/middlewares/recovery/recovery_test.go index 0871f3909..162a9734f 100644 --- a/pkg/middlewares/recovery/recovery_test.go +++ b/pkg/middlewares/recovery/recovery_test.go @@ -14,7 +14,7 @@ func TestRecoverHandler(t *testing.T) { fn := func(w http.ResponseWriter, r *http.Request) { panic("I love panicing!") } - recovery, err := New(context.Background(), http.HandlerFunc(fn), "foo-recovery") + recovery, err := New(context.Background(), http.HandlerFunc(fn)) require.NoError(t, err) server := httptest.NewServer(recovery) diff --git a/pkg/server/router/router.go b/pkg/server/router/router.go index 965dd9fa4..195b85159 100644 --- a/pkg/server/router/router.go +++ b/pkg/server/router/router.go @@ -16,10 +16,6 @@ import ( "github.com/traefik/traefik/v2/pkg/server/provider" ) -const ( - recoveryMiddlewareName = "traefik-internal-recovery" -) - type middlewareBuilder interface { BuildChain(ctx context.Context, names []string) *alice.Chain } @@ -130,7 +126,7 @@ func (m *Manager) buildEntryPointHandler(ctx context.Context, configs map[string chain := alice.New() chain = chain.Append(func(next http.Handler) (http.Handler, error) { - return recovery.New(ctx, next, recoveryMiddlewareName) + return recovery.New(ctx, next) }) return chain.Then(router)