Merge pull request #1598 from containous/fix-stats-hijack

Fix stats hijack
This commit is contained in:
Emile Vauge 2017-05-15 15:04:23 +02:00 committed by GitHub
commit e3ab4e4d63
6 changed files with 115 additions and 5 deletions

33
middlewares/recover.go Normal file
View file

@ -0,0 +1,33 @@
package middlewares
import (
"net/http"
"github.com/codegangsta/negroni"
"github.com/containous/traefik/log"
)
// RecoverHandler recovers from a panic in http handlers
func RecoverHandler(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
defer recoverFunc(w)
next.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
// NegroniRecoverHandler recovers from a panic in negroni handlers
func NegroniRecoverHandler() negroni.Handler {
fn := func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
defer recoverFunc(w)
next.ServeHTTP(w, r)
}
return negroni.HandlerFunc(fn)
}
func recoverFunc(w http.ResponseWriter) {
if err := recover(); err != nil {
log.Errorf("Recovered from panic in http handler: %+v", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
}

View file

@ -0,0 +1,45 @@
package middlewares
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/codegangsta/negroni"
)
func TestRecoverHandler(t *testing.T) {
fn := func(w http.ResponseWriter, r *http.Request) {
panic("I love panicing!")
}
recoverHandler := RecoverHandler(http.HandlerFunc(fn))
server := httptest.NewServer(recoverHandler)
defer server.Close()
resp, err := http.Get(server.URL)
if err != nil {
t.Fatal(err)
}
if resp.StatusCode != http.StatusInternalServerError {
t.Fatalf("Received non-%d response: %d\n", http.StatusInternalServerError, resp.StatusCode)
}
}
func TestNegroniRecoverHandler(t *testing.T) {
n := negroni.New()
n.Use(NegroniRecoverHandler())
panicHandler := func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
panic("I love panicing!")
}
n.UseFunc(negroni.HandlerFunc(panicHandler))
server := httptest.NewServer(n)
defer server.Close()
resp, err := http.Get(server.URL)
if err != nil {
t.Fatal(err)
}
if resp.StatusCode != http.StatusInternalServerError {
t.Fatalf("Received non-%d response: %d\n", http.StatusInternalServerError, resp.StatusCode)
}
}

View file

@ -12,10 +12,7 @@ import (
) )
var ( var (
_ http.ResponseWriter = &ResponseRecorder{} _ Stateful = &ResponseRecorder{}
_ http.Hijacker = &ResponseRecorder{}
_ http.Flusher = &ResponseRecorder{}
_ http.CloseNotifier = &ResponseRecorder{}
) )
// Retry is a middleware that retries requests // Retry is a middleware that retries requests

12
middlewares/stateful.go Normal file
View file

@ -0,0 +1,12 @@
package middlewares
import "net/http"
// Stateful interface groups all http interfaces that must be
// implemented by a stateful middleware (ie: recorders)
type Stateful interface {
http.ResponseWriter
http.Hijacker
http.Flusher
http.CloseNotifier
}

View file

@ -1,11 +1,17 @@
package middlewares package middlewares
import ( import (
"bufio"
"net"
"net/http" "net/http"
"sync" "sync"
"time" "time"
) )
var (
_ Stateful = &responseRecorder{}
)
// StatsRecorder is an optional middleware that records more details statistics // StatsRecorder is an optional middleware that records more details statistics
// about requests and how they are processed. This currently consists of recent // about requests and how they are processed. This currently consists of recent
// requests that have caused errors (4xx and 5xx status codes), making it easy // requests that have caused errors (4xx and 5xx status codes), making it easy
@ -51,6 +57,23 @@ func (r *responseRecorder) WriteHeader(status int) {
r.statusCode = status r.statusCode = status
} }
// Hijack hijacks the connection
func (r *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return r.ResponseWriter.(http.Hijacker).Hijack()
}
// CloseNotify returns a channel that receives at most a
// single value (true) when the client connection has gone
// away.
func (r *responseRecorder) CloseNotify() <-chan bool {
return r.ResponseWriter.(http.CloseNotifier).CloseNotify()
}
// Flush sends any buffered data to the client.
func (r *responseRecorder) Flush() {
r.ResponseWriter.(http.Flusher).Flush()
}
// ServeHTTP silently extracts information from the request and response as it // ServeHTTP silently extracts information from the request and response as it
// is processed. If the response is 4xx or 5xx, add it to the list of 10 most // is processed. If the response is 4xx or 5xx, add it to the list of 10 most
// recent errors. // recent errors.

View file

@ -173,7 +173,7 @@ func (server *Server) startHTTPServers() {
server.serverEntryPoints = server.buildEntryPoints(server.globalConfiguration) server.serverEntryPoints = server.buildEntryPoints(server.globalConfiguration)
for newServerEntryPointName, newServerEntryPoint := range server.serverEntryPoints { for newServerEntryPointName, newServerEntryPoint := range server.serverEntryPoints {
serverMiddlewares := []negroni.Handler{server.loggerMiddleware, metrics} serverMiddlewares := []negroni.Handler{middlewares.NegroniRecoverHandler(), server.loggerMiddleware, metrics}
if server.globalConfiguration.Web != nil && server.globalConfiguration.Web.Metrics != nil { if server.globalConfiguration.Web != nil && server.globalConfiguration.Web.Metrics != nil {
if server.globalConfiguration.Web.Metrics.Prometheus != nil { if server.globalConfiguration.Web.Metrics.Prometheus != nil {
metricsMiddleware := middlewares.NewMetricsWrapper(middlewares.NewPrometheus(newServerEntryPointName, server.globalConfiguration.Web.Metrics.Prometheus)) metricsMiddleware := middlewares.NewMetricsWrapper(middlewares.NewPrometheus(newServerEntryPointName, server.globalConfiguration.Web.Metrics.Prometheus))