diff --git a/middlewares/prometheus_test.go b/middlewares/prometheus_test.go index d85547c2b..4d2ef9223 100644 --- a/middlewares/prometheus_test.go +++ b/middlewares/prometheus_test.go @@ -151,7 +151,7 @@ func TestPrometheusRegisterMetricsMultipleTimes(t *testing.T) { func setupTestHTTPHandler() http.Handler { serveMux := http.NewServeMux() serveMux.Handle("/metrics", promhttp.Handler()) - serveMux.Handle("/ok", &networkFailingHTTPHandler{failAtCalls: []int{1}}) + serveMux.Handle("/ok", &networkFailingHTTPHandler{failAtCalls: []int{1}, netErrorRecorder: &DefaultNetErrorRecorder{}}) metrics, _ := newPrometheusMetrics() diff --git a/middlewares/retry.go b/middlewares/retry.go index 41c8a225f..4f2720173 100644 --- a/middlewares/retry.go +++ b/middlewares/retry.go @@ -3,6 +3,7 @@ package middlewares import ( "bufio" "bytes" + "context" "io/ioutil" "net" "net/http" @@ -42,10 +43,16 @@ func (retry *Retry) ServeHTTP(rw http.ResponseWriter, r *http.Request) { } attempts := 1 for { + netErrorOccurred := false + // We pass in a pointer to netErrorOccurred so that we can set it to true on network errors + // when proxying the HTTP requests to the backends. This happens in the custom RecordingErrorHandler. + newCtx := context.WithValue(r.Context(), defaultNetErrCtxKey, &netErrorOccurred) + recorder := newRetryResponseRecorder() recorder.responseWriter = rw - retry.next.ServeHTTP(recorder, r) - if !isNetworkError(recorder.Code) || attempts >= retry.attempts { + + retry.next.ServeHTTP(recorder, r.WithContext(newCtx)) + if !netErrorOccurred || attempts >= retry.attempts { utils.CopyHeaders(rw.Header(), recorder.Header()) rw.WriteHeader(recorder.Code) rw.Write(recorder.Body.Bytes()) @@ -57,8 +64,29 @@ func (retry *Retry) ServeHTTP(rw http.ResponseWriter, r *http.Request) { } } -func isNetworkError(status int) bool { - return status == http.StatusBadGateway || status == http.StatusGatewayTimeout +// netErrorCtxKey is a custom type that is used as key for the context. +type netErrorCtxKey string + +// defaultNetErrCtxKey is the actual key which value is used to record network errors. +var defaultNetErrCtxKey netErrorCtxKey = "NetErrCtxKey" + +// NetErrorRecorder is an interface to record net errors. +type NetErrorRecorder interface { + // Record can be used to signal the retry middleware that an network error happened + // and therefore the request should be retried. + Record(ctx context.Context) +} + +// DefaultNetErrorRecorder is the default NetErrorRecorder implementation. +type DefaultNetErrorRecorder struct{} + +// Record is recording network errors by setting the context value for the defaultNetErrCtxKey to true. +func (DefaultNetErrorRecorder) Record(ctx context.Context) { + val := ctx.Value(defaultNetErrCtxKey) + + if netErrorOccurred, isBoolPointer := val.(*bool); isBoolPointer { + *netErrorOccurred = true + } } // RetryListener is used to inform about retry attempts. diff --git a/middlewares/retry_test.go b/middlewares/retry_test.go index cd161dc44..bee18f378 100644 --- a/middlewares/retry_test.go +++ b/middlewares/retry_test.go @@ -1,6 +1,7 @@ package middlewares import ( + "context" "fmt" "io/ioutil" "net/http" @@ -41,7 +42,7 @@ func TestRetry(t *testing.T) { t.Parallel() var httpHandler http.Handler - httpHandler = &networkFailingHTTPHandler{failAtCalls: tc.failAtCalls} + httpHandler = &networkFailingHTTPHandler{failAtCalls: tc.failAtCalls, netErrorRecorder: &DefaultNetErrorRecorder{}} httpHandler = NewRetry(tc.attempts, httpHandler, tc.listener) recorder := httptest.NewRecorder() @@ -62,10 +63,38 @@ func TestRetry(t *testing.T) { } } +func TestDefaultNetErrorRecorderSuccess(t *testing.T) { + boolNetErrorOccurred := false + recorder := DefaultNetErrorRecorder{} + recorder.Record(context.WithValue(context.Background(), defaultNetErrCtxKey, &boolNetErrorOccurred)) + if !boolNetErrorOccurred { + t.Errorf("got %v after recording net error, wanted %v", boolNetErrorOccurred, true) + } +} + +func TestDefaultNetErrorRecorderInvalidValueType(t *testing.T) { + stringNetErrorOccured := "nonsense" + recorder := DefaultNetErrorRecorder{} + recorder.Record(context.WithValue(context.Background(), defaultNetErrCtxKey, &stringNetErrorOccured)) + if stringNetErrorOccured != "nonsense" { + t.Errorf("got %v after recording net error, wanted %v", stringNetErrorOccured, "nonsense") + } +} + +func TestDefaultNetErrorRecorderNilValue(t *testing.T) { + nilNetErrorOccured := interface{}(nil) + recorder := DefaultNetErrorRecorder{} + recorder.Record(context.WithValue(context.Background(), defaultNetErrCtxKey, &nilNetErrorOccured)) + if nilNetErrorOccured != interface{}(nil) { + t.Errorf("got %v after recording net error, wanted %v", nilNetErrorOccured, interface{}(nil)) + } +} + // networkFailingHTTPHandler is an http.Handler implementation you can use to test retries. type networkFailingHTTPHandler struct { - failAtCalls []int - callNumber int + netErrorRecorder NetErrorRecorder + failAtCalls []int + callNumber int } func (handler *networkFailingHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -73,6 +102,8 @@ func (handler *networkFailingHTTPHandler) ServeHTTP(w http.ResponseWriter, r *ht for _, failAtCall := range handler.failAtCalls { if handler.callNumber == failAtCall { + handler.netErrorRecorder.Record(r.Context()) + w.WriteHeader(http.StatusBadGateway) return } diff --git a/server/errorhandler.go b/server/errorhandler.go new file mode 100644 index 000000000..80cc9fae6 --- /dev/null +++ b/server/errorhandler.go @@ -0,0 +1,40 @@ +package server + +import ( + "io" + "net" + "net/http" + + "github.com/containous/traefik/middlewares" +) + +// RecordingErrorHandler is an error handler, implementing the vulcand/oxy +// error handler interface, which is recording network errors by using the netErrorRecorder. +// In addition it sets a proper HTTP status code and body, depending on the type of error occurred. +type RecordingErrorHandler struct { + netErrorRecorder middlewares.NetErrorRecorder +} + +// NewRecordingErrorHandler creates and returns a new instance of RecordingErrorHandler. +func NewRecordingErrorHandler(recorder middlewares.NetErrorRecorder) *RecordingErrorHandler { + return &RecordingErrorHandler{recorder} +} + +func (eh *RecordingErrorHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) { + statusCode := http.StatusInternalServerError + + if e, ok := err.(net.Error); ok { + eh.netErrorRecorder.Record(req.Context()) + if e.Timeout() { + statusCode = http.StatusGatewayTimeout + } else { + statusCode = http.StatusBadGateway + } + } else if err == io.EOF { + eh.netErrorRecorder.Record(req.Context()) + statusCode = http.StatusBadGateway + } + + w.WriteHeader(statusCode) + w.Write([]byte(http.StatusText(statusCode))) +} diff --git a/server/errorhandler_test.go b/server/errorhandler_test.go new file mode 100644 index 000000000..0ff0a0255 --- /dev/null +++ b/server/errorhandler_test.go @@ -0,0 +1,88 @@ +package server + +import ( + "context" + "errors" + "io" + "net" + "net/http" + "net/http/httptest" + "testing" +) + +type timeoutError struct{} + +func (e *timeoutError) Error() string { return "i/o timeout" } +func (e *timeoutError) Timeout() bool { return true } +func (e *timeoutError) Temporary() bool { return true } + +func TestServeHTTP(t *testing.T) { + tests := []struct { + name string + err error + wantHTTPStatus int + wantNetErrRecorded bool + }{ + { + name: "net.Error", + err: net.UnknownNetworkError("any network error"), + wantHTTPStatus: http.StatusBadGateway, + wantNetErrRecorded: true, + }, + { + name: "net.Error with Timeout", + err: &timeoutError{}, + wantHTTPStatus: http.StatusGatewayTimeout, + wantNetErrRecorded: true, + }, + { + name: "io.EOF", + err: io.EOF, + wantHTTPStatus: http.StatusBadGateway, + wantNetErrRecorded: true, + }, + { + name: "custom error", + err: errors.New("any error"), + wantHTTPStatus: http.StatusInternalServerError, + wantNetErrRecorded: false, + }, + { + name: "nil error", + err: nil, + wantHTTPStatus: http.StatusInternalServerError, + wantNetErrRecorded: false, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + recorder := httptest.NewRecorder() + + errorRecorder := &netErrorRecorder{} + req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/any", nil) + + recordingErrorHandler := NewRecordingErrorHandler(errorRecorder) + recordingErrorHandler.ServeHTTP(recorder, req, test.err) + + if recorder.Code != test.wantHTTPStatus { + t.Errorf("got HTTP status code %v, wanted %v", recorder.Code, test.wantHTTPStatus) + } + if errorRecorder.netErrorWasRecorded != test.wantNetErrRecorded { + t.Errorf("net error recording wrong, got %v wanted %v", errorRecorder.netErrorWasRecorded, test.wantNetErrRecorded) + } + }) + } +} + +type netErrorRecorder struct { + netErrorWasRecorded bool +} + +func (recorder *netErrorRecorder) Record(ctx context.Context) { + recorder.netErrorWasRecorded = true +} diff --git a/server/server.go b/server/server.go index 7c117114c..1c02958bb 100644 --- a/server/server.go +++ b/server/server.go @@ -595,6 +595,7 @@ func (server *Server) loadConfig(configurations configs, globalConfiguration Glo redirectHandlers := make(map[string]negroni.Handler) backends := map[string]http.Handler{} backendsHealthcheck := map[string]*healthcheck.BackendHealthCheck{} + errorHandler := NewRecordingErrorHandler(middlewares.DefaultNetErrorRecorder{}) for _, configuration := range configurations { frontendNames := sortedFrontendNamesForConfig(configuration) @@ -669,7 +670,12 @@ func (server *Server) loadConfig(configurations configs, globalConfiguration Glo // passing nil will use the roundtripper http.DefaultTransport rt := clientTLSRoundTripper(tlsConfig) - fwd, err := forward.New(forward.Logger(oxyLogger), forward.PassHostHeader(frontend.PassHostHeader), forward.RoundTripper(rt)) + fwd, err := forward.New( + forward.Logger(oxyLogger), + forward.PassHostHeader(frontend.PassHostHeader), + forward.RoundTripper(rt), + forward.ErrorHandler(errorHandler), + ) if err != nil { log.Errorf("Error creating forwarder for frontend %s: %v", frontendName, err) log.Errorf("Skipping frontend %s...", frontendName)