From f79317a435bf5d0252a1e5aa9833413efa380bab Mon Sep 17 00:00:00 2001 From: Marco Jantke Date: Wed, 3 May 2017 10:20:33 +0200 Subject: [PATCH] retry only on real network errors Now retries only happen when actual network errors occur and not only anymore based on the HTTP status code. This is because the backend could also send this status codes as their normal interface and in that case we don't want to retry. --- middlewares/prometheus_test.go | 2 +- middlewares/retry.go | 36 ++++++++++++-- middlewares/retry_test.go | 37 ++++++++++++-- server/errorhandler.go | 40 ++++++++++++++++ server/errorhandler_test.go | 88 ++++++++++++++++++++++++++++++++++ server/server.go | 8 +++- 6 files changed, 202 insertions(+), 9 deletions(-) create mode 100644 server/errorhandler.go create mode 100644 server/errorhandler_test.go diff --git a/middlewares/prometheus_test.go b/middlewares/prometheus_test.go index d85547c2b..4d2ef9223 100644 --- a/middlewares/prometheus_test.go +++ b/middlewares/prometheus_test.go @@ -151,7 +151,7 @@ func TestPrometheusRegisterMetricsMultipleTimes(t *testing.T) { func setupTestHTTPHandler() http.Handler { serveMux := http.NewServeMux() serveMux.Handle("/metrics", promhttp.Handler()) - serveMux.Handle("/ok", &networkFailingHTTPHandler{failAtCalls: []int{1}}) + serveMux.Handle("/ok", &networkFailingHTTPHandler{failAtCalls: []int{1}, netErrorRecorder: &DefaultNetErrorRecorder{}}) metrics, _ := newPrometheusMetrics() diff --git a/middlewares/retry.go b/middlewares/retry.go index 41c8a225f..4f2720173 100644 --- a/middlewares/retry.go +++ b/middlewares/retry.go @@ -3,6 +3,7 @@ package middlewares import ( "bufio" "bytes" + "context" "io/ioutil" "net" "net/http" @@ -42,10 +43,16 @@ func (retry *Retry) ServeHTTP(rw http.ResponseWriter, r *http.Request) { } attempts := 1 for { + netErrorOccurred := false + // We pass in a pointer to netErrorOccurred so that we can set it to true on network errors + // when proxying the HTTP requests to the backends. This happens in the custom RecordingErrorHandler. + newCtx := context.WithValue(r.Context(), defaultNetErrCtxKey, &netErrorOccurred) + recorder := newRetryResponseRecorder() recorder.responseWriter = rw - retry.next.ServeHTTP(recorder, r) - if !isNetworkError(recorder.Code) || attempts >= retry.attempts { + + retry.next.ServeHTTP(recorder, r.WithContext(newCtx)) + if !netErrorOccurred || attempts >= retry.attempts { utils.CopyHeaders(rw.Header(), recorder.Header()) rw.WriteHeader(recorder.Code) rw.Write(recorder.Body.Bytes()) @@ -57,8 +64,29 @@ func (retry *Retry) ServeHTTP(rw http.ResponseWriter, r *http.Request) { } } -func isNetworkError(status int) bool { - return status == http.StatusBadGateway || status == http.StatusGatewayTimeout +// netErrorCtxKey is a custom type that is used as key for the context. +type netErrorCtxKey string + +// defaultNetErrCtxKey is the actual key which value is used to record network errors. +var defaultNetErrCtxKey netErrorCtxKey = "NetErrCtxKey" + +// NetErrorRecorder is an interface to record net errors. +type NetErrorRecorder interface { + // Record can be used to signal the retry middleware that an network error happened + // and therefore the request should be retried. + Record(ctx context.Context) +} + +// DefaultNetErrorRecorder is the default NetErrorRecorder implementation. +type DefaultNetErrorRecorder struct{} + +// Record is recording network errors by setting the context value for the defaultNetErrCtxKey to true. +func (DefaultNetErrorRecorder) Record(ctx context.Context) { + val := ctx.Value(defaultNetErrCtxKey) + + if netErrorOccurred, isBoolPointer := val.(*bool); isBoolPointer { + *netErrorOccurred = true + } } // RetryListener is used to inform about retry attempts. diff --git a/middlewares/retry_test.go b/middlewares/retry_test.go index cd161dc44..bee18f378 100644 --- a/middlewares/retry_test.go +++ b/middlewares/retry_test.go @@ -1,6 +1,7 @@ package middlewares import ( + "context" "fmt" "io/ioutil" "net/http" @@ -41,7 +42,7 @@ func TestRetry(t *testing.T) { t.Parallel() var httpHandler http.Handler - httpHandler = &networkFailingHTTPHandler{failAtCalls: tc.failAtCalls} + httpHandler = &networkFailingHTTPHandler{failAtCalls: tc.failAtCalls, netErrorRecorder: &DefaultNetErrorRecorder{}} httpHandler = NewRetry(tc.attempts, httpHandler, tc.listener) recorder := httptest.NewRecorder() @@ -62,10 +63,38 @@ func TestRetry(t *testing.T) { } } +func TestDefaultNetErrorRecorderSuccess(t *testing.T) { + boolNetErrorOccurred := false + recorder := DefaultNetErrorRecorder{} + recorder.Record(context.WithValue(context.Background(), defaultNetErrCtxKey, &boolNetErrorOccurred)) + if !boolNetErrorOccurred { + t.Errorf("got %v after recording net error, wanted %v", boolNetErrorOccurred, true) + } +} + +func TestDefaultNetErrorRecorderInvalidValueType(t *testing.T) { + stringNetErrorOccured := "nonsense" + recorder := DefaultNetErrorRecorder{} + recorder.Record(context.WithValue(context.Background(), defaultNetErrCtxKey, &stringNetErrorOccured)) + if stringNetErrorOccured != "nonsense" { + t.Errorf("got %v after recording net error, wanted %v", stringNetErrorOccured, "nonsense") + } +} + +func TestDefaultNetErrorRecorderNilValue(t *testing.T) { + nilNetErrorOccured := interface{}(nil) + recorder := DefaultNetErrorRecorder{} + recorder.Record(context.WithValue(context.Background(), defaultNetErrCtxKey, &nilNetErrorOccured)) + if nilNetErrorOccured != interface{}(nil) { + t.Errorf("got %v after recording net error, wanted %v", nilNetErrorOccured, interface{}(nil)) + } +} + // networkFailingHTTPHandler is an http.Handler implementation you can use to test retries. type networkFailingHTTPHandler struct { - failAtCalls []int - callNumber int + netErrorRecorder NetErrorRecorder + failAtCalls []int + callNumber int } func (handler *networkFailingHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -73,6 +102,8 @@ func (handler *networkFailingHTTPHandler) ServeHTTP(w http.ResponseWriter, r *ht for _, failAtCall := range handler.failAtCalls { if handler.callNumber == failAtCall { + handler.netErrorRecorder.Record(r.Context()) + w.WriteHeader(http.StatusBadGateway) return } diff --git a/server/errorhandler.go b/server/errorhandler.go new file mode 100644 index 000000000..80cc9fae6 --- /dev/null +++ b/server/errorhandler.go @@ -0,0 +1,40 @@ +package server + +import ( + "io" + "net" + "net/http" + + "github.com/containous/traefik/middlewares" +) + +// RecordingErrorHandler is an error handler, implementing the vulcand/oxy +// error handler interface, which is recording network errors by using the netErrorRecorder. +// In addition it sets a proper HTTP status code and body, depending on the type of error occurred. +type RecordingErrorHandler struct { + netErrorRecorder middlewares.NetErrorRecorder +} + +// NewRecordingErrorHandler creates and returns a new instance of RecordingErrorHandler. +func NewRecordingErrorHandler(recorder middlewares.NetErrorRecorder) *RecordingErrorHandler { + return &RecordingErrorHandler{recorder} +} + +func (eh *RecordingErrorHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) { + statusCode := http.StatusInternalServerError + + if e, ok := err.(net.Error); ok { + eh.netErrorRecorder.Record(req.Context()) + if e.Timeout() { + statusCode = http.StatusGatewayTimeout + } else { + statusCode = http.StatusBadGateway + } + } else if err == io.EOF { + eh.netErrorRecorder.Record(req.Context()) + statusCode = http.StatusBadGateway + } + + w.WriteHeader(statusCode) + w.Write([]byte(http.StatusText(statusCode))) +} diff --git a/server/errorhandler_test.go b/server/errorhandler_test.go new file mode 100644 index 000000000..0ff0a0255 --- /dev/null +++ b/server/errorhandler_test.go @@ -0,0 +1,88 @@ +package server + +import ( + "context" + "errors" + "io" + "net" + "net/http" + "net/http/httptest" + "testing" +) + +type timeoutError struct{} + +func (e *timeoutError) Error() string { return "i/o timeout" } +func (e *timeoutError) Timeout() bool { return true } +func (e *timeoutError) Temporary() bool { return true } + +func TestServeHTTP(t *testing.T) { + tests := []struct { + name string + err error + wantHTTPStatus int + wantNetErrRecorded bool + }{ + { + name: "net.Error", + err: net.UnknownNetworkError("any network error"), + wantHTTPStatus: http.StatusBadGateway, + wantNetErrRecorded: true, + }, + { + name: "net.Error with Timeout", + err: &timeoutError{}, + wantHTTPStatus: http.StatusGatewayTimeout, + wantNetErrRecorded: true, + }, + { + name: "io.EOF", + err: io.EOF, + wantHTTPStatus: http.StatusBadGateway, + wantNetErrRecorded: true, + }, + { + name: "custom error", + err: errors.New("any error"), + wantHTTPStatus: http.StatusInternalServerError, + wantNetErrRecorded: false, + }, + { + name: "nil error", + err: nil, + wantHTTPStatus: http.StatusInternalServerError, + wantNetErrRecorded: false, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + recorder := httptest.NewRecorder() + + errorRecorder := &netErrorRecorder{} + req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/any", nil) + + recordingErrorHandler := NewRecordingErrorHandler(errorRecorder) + recordingErrorHandler.ServeHTTP(recorder, req, test.err) + + if recorder.Code != test.wantHTTPStatus { + t.Errorf("got HTTP status code %v, wanted %v", recorder.Code, test.wantHTTPStatus) + } + if errorRecorder.netErrorWasRecorded != test.wantNetErrRecorded { + t.Errorf("net error recording wrong, got %v wanted %v", errorRecorder.netErrorWasRecorded, test.wantNetErrRecorded) + } + }) + } +} + +type netErrorRecorder struct { + netErrorWasRecorded bool +} + +func (recorder *netErrorRecorder) Record(ctx context.Context) { + recorder.netErrorWasRecorded = true +} diff --git a/server/server.go b/server/server.go index 7c117114c..1c02958bb 100644 --- a/server/server.go +++ b/server/server.go @@ -595,6 +595,7 @@ func (server *Server) loadConfig(configurations configs, globalConfiguration Glo redirectHandlers := make(map[string]negroni.Handler) backends := map[string]http.Handler{} backendsHealthcheck := map[string]*healthcheck.BackendHealthCheck{} + errorHandler := NewRecordingErrorHandler(middlewares.DefaultNetErrorRecorder{}) for _, configuration := range configurations { frontendNames := sortedFrontendNamesForConfig(configuration) @@ -669,7 +670,12 @@ func (server *Server) loadConfig(configurations configs, globalConfiguration Glo // passing nil will use the roundtripper http.DefaultTransport rt := clientTLSRoundTripper(tlsConfig) - fwd, err := forward.New(forward.Logger(oxyLogger), forward.PassHostHeader(frontend.PassHostHeader), forward.RoundTripper(rt)) + fwd, err := forward.New( + forward.Logger(oxyLogger), + forward.PassHostHeader(frontend.PassHostHeader), + forward.RoundTripper(rt), + forward.ErrorHandler(errorHandler), + ) if err != nil { log.Errorf("Error creating forwarder for frontend %s: %v", frontendName, err) log.Errorf("Skipping frontend %s...", frontendName)