perf: improve forwarded header and recovery middlewares
Co-authored-by: Ludovic Fernandez <ldez@users.noreply.github.com>
This commit is contained in:
parent
c74918321d
commit
a90b2a672e
4 changed files with 59 additions and 36 deletions
|
@ -84,19 +84,28 @@ func (x *XForwarded) isTrustedIP(ip string) bool {
|
||||||
// removeIPv6Zone removes the zone if the given IP is an ipv6 address and it has {zone} information in it,
|
// removeIPv6Zone removes the zone if the given IP is an ipv6 address and it has {zone} information in it,
|
||||||
// like "[fe80::d806:a55d:eb1b:49cc%vEthernet (vmxnet3 Ethernet Adapter - Virtual Switch)]:64692".
|
// like "[fe80::d806:a55d:eb1b:49cc%vEthernet (vmxnet3 Ethernet Adapter - Virtual Switch)]:64692".
|
||||||
func removeIPv6Zone(clientIP string) string {
|
func removeIPv6Zone(clientIP string) string {
|
||||||
return strings.Split(clientIP, "%")[0]
|
if idx := strings.Index(clientIP, "%"); idx != -1 {
|
||||||
|
return clientIP[:idx]
|
||||||
|
}
|
||||||
|
return clientIP
|
||||||
}
|
}
|
||||||
|
|
||||||
// isWebsocketRequest returns whether the specified HTTP request is a websocket handshake request.
|
// isWebsocketRequest returns whether the specified HTTP request is a websocket handshake request.
|
||||||
func isWebsocketRequest(req *http.Request) bool {
|
func isWebsocketRequest(req *http.Request) bool {
|
||||||
containsHeader := func(name, value string) bool {
|
containsHeader := func(name, value string) bool {
|
||||||
items := strings.Split(req.Header.Get(name), ",")
|
h := unsafeHeader(req.Header).Get(name)
|
||||||
for _, item := range items {
|
for {
|
||||||
if value == strings.ToLower(strings.TrimSpace(item)) {
|
pos := strings.Index(h, ",")
|
||||||
|
if pos == -1 {
|
||||||
|
return strings.EqualFold(value, strings.TrimSpace(h))
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.EqualFold(value, strings.TrimSpace(h[:pos])) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
h = h[pos:]
|
||||||
}
|
}
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
return containsHeader(connection, "upgrade") && containsHeader(upgrade, "websocket")
|
return containsHeader(connection, "upgrade") && containsHeader(upgrade, "websocket")
|
||||||
}
|
}
|
||||||
|
@ -110,7 +119,7 @@ func forwardedPort(req *http.Request) string {
|
||||||
return port
|
return port
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Header.Get(xForwardedProto) == "https" || req.Header.Get(xForwardedProto) == "wss" {
|
if unsafeHeader(req.Header).Get(xForwardedProto) == "https" || unsafeHeader(req.Header).Get(xForwardedProto) == "wss" {
|
||||||
return "443"
|
return "443"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -125,38 +134,38 @@ func (x *XForwarded) rewrite(outreq *http.Request) {
|
||||||
if clientIP, _, err := net.SplitHostPort(outreq.RemoteAddr); err == nil {
|
if clientIP, _, err := net.SplitHostPort(outreq.RemoteAddr); err == nil {
|
||||||
clientIP = removeIPv6Zone(clientIP)
|
clientIP = removeIPv6Zone(clientIP)
|
||||||
|
|
||||||
if outreq.Header.Get(xRealIP) == "" {
|
if unsafeHeader(outreq.Header).Get(xRealIP) == "" {
|
||||||
outreq.Header.Set(xRealIP, clientIP)
|
unsafeHeader(outreq.Header).Set(xRealIP, clientIP)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
xfProto := outreq.Header.Get(xForwardedProto)
|
xfProto := unsafeHeader(outreq.Header).Get(xForwardedProto)
|
||||||
if xfProto == "" {
|
if xfProto == "" {
|
||||||
if isWebsocketRequest(outreq) {
|
if isWebsocketRequest(outreq) {
|
||||||
if outreq.TLS != nil {
|
if outreq.TLS != nil {
|
||||||
outreq.Header.Set(xForwardedProto, "wss")
|
unsafeHeader(outreq.Header).Set(xForwardedProto, "wss")
|
||||||
} else {
|
} else {
|
||||||
outreq.Header.Set(xForwardedProto, "ws")
|
unsafeHeader(outreq.Header).Set(xForwardedProto, "ws")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if outreq.TLS != nil {
|
if outreq.TLS != nil {
|
||||||
outreq.Header.Set(xForwardedProto, "https")
|
unsafeHeader(outreq.Header).Set(xForwardedProto, "https")
|
||||||
} else {
|
} else {
|
||||||
outreq.Header.Set(xForwardedProto, "http")
|
unsafeHeader(outreq.Header).Set(xForwardedProto, "http")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if xfPort := outreq.Header.Get(xForwardedPort); xfPort == "" {
|
if xfPort := unsafeHeader(outreq.Header).Get(xForwardedPort); xfPort == "" {
|
||||||
outreq.Header.Set(xForwardedPort, forwardedPort(outreq))
|
unsafeHeader(outreq.Header).Set(xForwardedPort, forwardedPort(outreq))
|
||||||
}
|
}
|
||||||
|
|
||||||
if xfHost := outreq.Header.Get(xForwardedHost); xfHost == "" && outreq.Host != "" {
|
if xfHost := unsafeHeader(outreq.Header).Get(xForwardedHost); xfHost == "" && outreq.Host != "" {
|
||||||
outreq.Header.Set(xForwardedHost, outreq.Host)
|
unsafeHeader(outreq.Header).Set(xForwardedHost, outreq.Host)
|
||||||
}
|
}
|
||||||
|
|
||||||
if x.hostname != "" {
|
if x.hostname != "" {
|
||||||
outreq.Header.Set(xForwardedServer, x.hostname)
|
unsafeHeader(outreq.Header).Set(xForwardedServer, x.hostname)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -164,7 +173,7 @@ func (x *XForwarded) rewrite(outreq *http.Request) {
|
||||||
func (x *XForwarded) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (x *XForwarded) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
if !x.insecure && !x.isTrustedIP(r.RemoteAddr) {
|
if !x.insecure && !x.isTrustedIP(r.RemoteAddr) {
|
||||||
for _, h := range xHeaders {
|
for _, h := range xHeaders {
|
||||||
r.Header.Del(h)
|
unsafeHeader(r.Header).Del(h)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -172,3 +181,22 @@ func (x *XForwarded) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
x.next.ServeHTTP(w, r)
|
x.next.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// unsafeHeader allows to manage Header values.
|
||||||
|
// Must be used only when the header name is already a canonical key.
|
||||||
|
type unsafeHeader map[string][]string
|
||||||
|
|
||||||
|
func (h unsafeHeader) Set(key, value string) {
|
||||||
|
h[key] = []string{value}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h unsafeHeader) Get(key string) string {
|
||||||
|
if len(h[key]) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return h[key][0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h unsafeHeader) Del(key string) {
|
||||||
|
delete(h, key)
|
||||||
|
}
|
||||||
|
|
|
@ -11,41 +11,40 @@ import (
|
||||||
|
|
||||||
const (
|
const (
|
||||||
typeName = "Recovery"
|
typeName = "Recovery"
|
||||||
|
middlewareName = "traefik-internal-recovery"
|
||||||
)
|
)
|
||||||
|
|
||||||
type recovery struct {
|
type recovery struct {
|
||||||
next http.Handler
|
next http.Handler
|
||||||
name string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates recovery middleware.
|
// New creates recovery middleware.
|
||||||
func New(ctx context.Context, next http.Handler, name string) (http.Handler, error) {
|
func New(ctx context.Context, next http.Handler) (http.Handler, error) {
|
||||||
log.FromContext(middlewares.GetLoggerCtx(ctx, name, typeName)).Debug("Creating middleware")
|
log.FromContext(middlewares.GetLoggerCtx(ctx, middlewareName, typeName)).Debug("Creating middleware")
|
||||||
|
|
||||||
return &recovery{
|
return &recovery{
|
||||||
next: next,
|
next: next,
|
||||||
name: name,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (re *recovery) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
func (re *recovery) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
defer recoverFunc(middlewares.GetLoggerCtx(req.Context(), re.name, typeName), rw, req)
|
defer recoverFunc(rw, req)
|
||||||
re.next.ServeHTTP(rw, req)
|
re.next.ServeHTTP(rw, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func recoverFunc(ctx context.Context, rw http.ResponseWriter, r *http.Request) {
|
func recoverFunc(rw http.ResponseWriter, r *http.Request) {
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
|
logger := log.FromContext(middlewares.GetLoggerCtx(r.Context(), middlewareName, typeName))
|
||||||
if !shouldLogPanic(err) {
|
if !shouldLogPanic(err) {
|
||||||
log.FromContext(ctx).Debugf("Request has been aborted [%s - %s]: %v", r.RemoteAddr, r.URL, err)
|
logger.Debugf("Request has been aborted [%s - %s]: %v", r.RemoteAddr, r.URL, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.FromContext(ctx).Errorf("Recovered from panic in HTTP handler [%s - %s]: %+v", r.RemoteAddr, r.URL, err)
|
logger.Errorf("Recovered from panic in HTTP handler [%s - %s]: %+v", r.RemoteAddr, r.URL, err)
|
||||||
|
|
||||||
const size = 64 << 10
|
const size = 64 << 10
|
||||||
buf := make([]byte, size)
|
buf := make([]byte, size)
|
||||||
buf = buf[:runtime.Stack(buf, false)]
|
buf = buf[:runtime.Stack(buf, false)]
|
||||||
log.FromContext(ctx).Errorf("Stack: %s", buf)
|
logger.Errorf("Stack: %s", buf)
|
||||||
|
|
||||||
http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,7 @@ func TestRecoverHandler(t *testing.T) {
|
||||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||||
panic("I love panicing!")
|
panic("I love panicing!")
|
||||||
}
|
}
|
||||||
recovery, err := New(context.Background(), http.HandlerFunc(fn), "foo-recovery")
|
recovery, err := New(context.Background(), http.HandlerFunc(fn))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
server := httptest.NewServer(recovery)
|
server := httptest.NewServer(recovery)
|
||||||
|
|
|
@ -16,10 +16,6 @@ import (
|
||||||
"github.com/traefik/traefik/v2/pkg/server/provider"
|
"github.com/traefik/traefik/v2/pkg/server/provider"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
recoveryMiddlewareName = "traefik-internal-recovery"
|
|
||||||
)
|
|
||||||
|
|
||||||
type middlewareBuilder interface {
|
type middlewareBuilder interface {
|
||||||
BuildChain(ctx context.Context, names []string) *alice.Chain
|
BuildChain(ctx context.Context, names []string) *alice.Chain
|
||||||
}
|
}
|
||||||
|
@ -130,7 +126,7 @@ func (m *Manager) buildEntryPointHandler(ctx context.Context, configs map[string
|
||||||
|
|
||||||
chain := alice.New()
|
chain := alice.New()
|
||||||
chain = chain.Append(func(next http.Handler) (http.Handler, error) {
|
chain = chain.Append(func(next http.Handler) (http.Handler, error) {
|
||||||
return recovery.New(ctx, next, recoveryMiddlewareName)
|
return recovery.New(ctx, next)
|
||||||
})
|
})
|
||||||
|
|
||||||
return chain.Then(router)
|
return chain.Then(router)
|
||||||
|
|
Loading…
Reference in a new issue