avoid retries when any data was written to the backend

This commit is contained in:
Marco Jantke 2018-06-19 13:56:04 +02:00 committed by Traefiker Bot
parent 586ba31120
commit e31c85aace
5 changed files with 161 additions and 250 deletions

View file

@ -2,10 +2,10 @@ package middlewares
import ( import (
"bufio" "bufio"
"context"
"io/ioutil" "io/ioutil"
"net" "net"
"net/http" "net/http"
"net/http/httptrace"
"github.com/containous/traefik/log" "github.com/containous/traefik/log"
) )
@ -40,11 +40,24 @@ func (retry *Retry) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
attempts := 1 attempts := 1
for { for {
netErrorOccurred := false attemptsExhausted := attempts >= retry.attempts
// We pass in a pointer to netErrorOccurred so that we can set it to true on network errors // Websocket requests can't be retried at this point in time.
// when proxying the HTTP requests to the backends. This happens in the custom RecordingErrorHandler. // This is due to the fact that gorilla/websocket doesn't use the request
newCtx := context.WithValue(r.Context(), defaultNetErrCtxKey, &netErrorOccurred) // context and so we don't get httptrace information.
retryResponseWriter := newRetryResponseWriter(rw, attempts >= retry.attempts, &netErrorOccurred) // Websocket clients should however retry on their own anyway.
shouldRetry := !attemptsExhausted && !isWebsocketRequest(r)
retryResponseWriter := newRetryResponseWriter(rw, shouldRetry)
// Disable retries when the backend already received request data
trace := &httptrace.ClientTrace{
WroteHeaders: func() {
retryResponseWriter.DisableRetries()
},
WroteRequest: func(httptrace.WroteRequestInfo) {
retryResponseWriter.DisableRetries()
},
}
newCtx := httptrace.WithClientTrace(r.Context(), trace)
retry.next.ServeHTTP(retryResponseWriter, r.WithContext(newCtx)) retry.next.ServeHTTP(retryResponseWriter, r.WithContext(newCtx))
if !retryResponseWriter.ShouldRetry() { if !retryResponseWriter.ShouldRetry() {
@ -57,31 +70,6 @@ func (retry *Retry) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
} }
} }
// netErrorCtxKey is a custom type that is used as key for the context.
type netErrorCtxKey string
// defaultNetErrCtxKey is the actual key which value is used to record network errors.
var defaultNetErrCtxKey netErrorCtxKey = "NetErrCtxKey"
// NetErrorRecorder is an interface to record net errors.
type NetErrorRecorder interface {
// Record can be used to signal the retry middleware that an network error happened
// and therefore the request should be retried.
Record(ctx context.Context)
}
// DefaultNetErrorRecorder is the default NetErrorRecorder implementation.
type DefaultNetErrorRecorder struct{}
// Record is recording network errors by setting the context value for the defaultNetErrCtxKey to true.
func (DefaultNetErrorRecorder) Record(ctx context.Context) {
val := ctx.Value(defaultNetErrCtxKey)
if netErrorOccurred, isBoolPointer := val.(*bool); isBoolPointer {
*netErrorOccurred = true
}
}
// RetryListener is used to inform about retry attempts. // RetryListener is used to inform about retry attempts.
type RetryListener interface { type RetryListener interface {
// Retried will be called when a retry happens, with the request attempt passed to it. // Retried will be called when a retry happens, with the request attempt passed to it.
@ -104,13 +92,13 @@ type retryResponseWriter interface {
http.ResponseWriter http.ResponseWriter
http.Flusher http.Flusher
ShouldRetry() bool ShouldRetry() bool
DisableRetries()
} }
func newRetryResponseWriter(rw http.ResponseWriter, attemptsExhausted bool, netErrorOccured *bool) retryResponseWriter { func newRetryResponseWriter(rw http.ResponseWriter, shouldRetry bool) retryResponseWriter {
responseWriter := &retryResponseWriterWithoutCloseNotify{ responseWriter := &retryResponseWriterWithoutCloseNotify{
responseWriter: rw, responseWriter: rw,
attemptsExhausted: attemptsExhausted, shouldRetry: shouldRetry,
netErrorOccured: netErrorOccured,
} }
if _, ok := rw.(http.CloseNotifier); ok { if _, ok := rw.(http.CloseNotifier); ok {
return &retryResponseWriterWithCloseNotify{responseWriter} return &retryResponseWriterWithCloseNotify{responseWriter}
@ -119,13 +107,16 @@ func newRetryResponseWriter(rw http.ResponseWriter, attemptsExhausted bool, netE
} }
type retryResponseWriterWithoutCloseNotify struct { type retryResponseWriterWithoutCloseNotify struct {
responseWriter http.ResponseWriter responseWriter http.ResponseWriter
attemptsExhausted bool shouldRetry bool
netErrorOccured *bool
} }
func (rr *retryResponseWriterWithoutCloseNotify) ShouldRetry() bool { func (rr *retryResponseWriterWithoutCloseNotify) ShouldRetry() bool {
return *rr.netErrorOccured && !rr.attemptsExhausted return rr.shouldRetry
}
func (rr *retryResponseWriterWithoutCloseNotify) DisableRetries() {
rr.shouldRetry = false
} }
func (rr *retryResponseWriterWithoutCloseNotify) Header() http.Header { func (rr *retryResponseWriterWithoutCloseNotify) Header() http.Header {
@ -143,6 +134,15 @@ func (rr *retryResponseWriterWithoutCloseNotify) Write(buf []byte) (int, error)
} }
func (rr *retryResponseWriterWithoutCloseNotify) WriteHeader(code int) { func (rr *retryResponseWriterWithoutCloseNotify) WriteHeader(code int) {
if rr.ShouldRetry() && code == http.StatusServiceUnavailable {
// We get a 503 HTTP Status Code when there is no backend server in the pool
// to which the request could be sent. Also, note that rr.ShouldRetry()
// will never return true in case there was a connetion established to
// the backend server and so we can be sure that the 503 was produced
// inside Traefik already and we don't have to retry in this cases.
rr.DisableRetries()
}
if rr.ShouldRetry() { if rr.ShouldRetry() {
return return
} }

View file

@ -1,91 +1,155 @@
package middlewares package middlewares
import ( import (
"context"
"fmt"
"io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/containous/traefik/testhelpers"
"github.com/vulcand/oxy/forward"
"github.com/vulcand/oxy/roundrobin"
) )
func TestRetry(t *testing.T) { func TestRetry(t *testing.T) {
testCases := []struct { testCases := []struct {
failAtCalls []int desc string
attempts int maxRequestAttempts int
responseStatus int wantRetryAttempts int
listener *countingRetryListener wantResponseStatus int
retriedCount int amountFaultyEndpoints int
isWebsocketHandshakeRequest bool
}{ }{
{ {
failAtCalls: []int{1, 2}, desc: "no retry on success",
attempts: 3, maxRequestAttempts: 1,
responseStatus: http.StatusOK, wantRetryAttempts: 0,
listener: &countingRetryListener{}, wantResponseStatus: http.StatusOK,
retriedCount: 2, amountFaultyEndpoints: 0,
}, },
{ {
failAtCalls: []int{1, 2}, desc: "no retry when max request attempts is one",
attempts: 2, maxRequestAttempts: 1,
responseStatus: http.StatusBadGateway, wantRetryAttempts: 0,
listener: &countingRetryListener{}, wantResponseStatus: http.StatusInternalServerError,
retriedCount: 1, amountFaultyEndpoints: 1,
},
{
desc: "one retry when one server is faulty",
maxRequestAttempts: 2,
wantRetryAttempts: 1,
wantResponseStatus: http.StatusOK,
amountFaultyEndpoints: 1,
},
{
desc: "two retries when two servers are faulty",
maxRequestAttempts: 3,
wantRetryAttempts: 2,
wantResponseStatus: http.StatusOK,
amountFaultyEndpoints: 2,
},
{
desc: "max attempts exhausted delivers the 5xx response",
maxRequestAttempts: 3,
wantRetryAttempts: 2,
wantResponseStatus: http.StatusInternalServerError,
amountFaultyEndpoints: 3,
},
{
desc: "websocket request should not be retried",
maxRequestAttempts: 3,
wantRetryAttempts: 0,
wantResponseStatus: http.StatusBadGateway,
amountFaultyEndpoints: 1,
isWebsocketHandshakeRequest: true,
}, },
} }
for _, tc := range testCases { backendServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
// bind tc locally rw.WriteHeader(http.StatusOK)
tc := tc rw.Write([]byte("OK"))
tcName := fmt.Sprintf("FailAtCalls(%v) RetryAttempts(%v)", tc.failAtCalls, tc.attempts) }))
t.Run(tcName, func(t *testing.T) { forwarder, err := forward.New()
if err != nil {
t.Fatalf("Error creating forwarder: %s", err)
}
for _, tc := range testCases {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
t.Parallel() t.Parallel()
var httpHandler http.Handler = &networkFailingHTTPHandler{failAtCalls: tc.failAtCalls, netErrorRecorder: &DefaultNetErrorRecorder{}} loadBalancer, err := roundrobin.New(forwarder)
httpHandler = NewRetry(tc.attempts, httpHandler, tc.listener) if err != nil {
t.Fatalf("Error creating load balancer: %s", err)
}
basePort := 33444
for i := 0; i < tc.amountFaultyEndpoints; i++ {
// 192.0.2.0 is a non-routable IP for testing purposes.
// See: https://stackoverflow.com/questions/528538/non-routable-ip-address/18436928#18436928
// We only use the port specification here because the URL is used as identifier
// in the load balancer and using the exact same URL would not add a new server.
loadBalancer.UpsertServer(testhelpers.MustParseURL("http://192.0.2.0:" + string(basePort+i)))
}
// add the functioning server to the end of the load balancer list
loadBalancer.UpsertServer(testhelpers.MustParseURL(backendServer.URL))
retryListener := &countingRetryListener{}
retry := NewRetry(tc.maxRequestAttempts, loadBalancer, retryListener)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
req, err := http.NewRequest(http.MethodGet, "http://localhost:3000/ok", ioutil.NopCloser(nil)) req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/ok", nil)
if err != nil {
t.Fatalf("could not create request: %+v", err) if tc.isWebsocketHandshakeRequest {
req.Header.Add("Connection", "Upgrade")
req.Header.Add("Upgrade", "websocket")
} }
httpHandler.ServeHTTP(recorder, req) retry.ServeHTTP(recorder, req)
if tc.responseStatus != recorder.Code { if tc.wantResponseStatus != recorder.Code {
t.Errorf("wrong status code %d, want %d", recorder.Code, tc.responseStatus) t.Errorf("got status code %d, want %d", recorder.Code, tc.wantResponseStatus)
} }
if tc.retriedCount != tc.listener.timesCalled { if tc.wantRetryAttempts != retryListener.timesCalled {
t.Errorf("RetryListener called %d times, want %d times", tc.listener.timesCalled, tc.retriedCount) t.Errorf("retry listener called %d time(s), want %d time(s)", retryListener.timesCalled, tc.wantRetryAttempts)
} }
}) })
} }
} }
func TestDefaultNetErrorRecorderSuccess(t *testing.T) { func TestRetryEmptyServerList(t *testing.T) {
boolNetErrorOccurred := false forwarder, err := forward.New()
recorder := DefaultNetErrorRecorder{} if err != nil {
recorder.Record(context.WithValue(context.Background(), defaultNetErrCtxKey, &boolNetErrorOccurred)) t.Fatalf("Error creating forwarder: %s", err)
if !boolNetErrorOccurred {
t.Errorf("got %v after recording net error, wanted %v", boolNetErrorOccurred, true)
} }
}
func TestDefaultNetErrorRecorderInvalidValueType(t *testing.T) { loadBalancer, err := roundrobin.New(forwarder)
stringNetErrorOccured := "nonsense" if err != nil {
recorder := DefaultNetErrorRecorder{} t.Fatalf("Error creating load balancer: %s", err)
recorder.Record(context.WithValue(context.Background(), defaultNetErrCtxKey, &stringNetErrorOccured))
if stringNetErrorOccured != "nonsense" {
t.Errorf("got %v after recording net error, wanted %v", stringNetErrorOccured, "nonsense")
} }
}
func TestDefaultNetErrorRecorderNilValue(t *testing.T) { // The EmptyBackendHandler middleware ensures that there is a 503
nilNetErrorOccured := interface{}(nil) // response status set when there is no backend server in the pool.
recorder := DefaultNetErrorRecorder{} next := NewEmptyBackendHandler(loadBalancer)
recorder.Record(context.WithValue(context.Background(), defaultNetErrCtxKey, &nilNetErrorOccured))
if nilNetErrorOccured != interface{}(nil) { retryListener := &countingRetryListener{}
t.Errorf("got %v after recording net error, wanted %v", nilNetErrorOccured, interface{}(nil)) retry := NewRetry(3, next, retryListener)
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/ok", nil)
retry.ServeHTTP(recorder, req)
const wantResponseStatus = http.StatusServiceUnavailable
if wantResponseStatus != recorder.Code {
t.Errorf("got status code %d, want %d", recorder.Code, wantResponseStatus)
}
const wantRetryAttempts = 0
if wantRetryAttempts != retryListener.timesCalled {
t.Errorf("retry listener called %d time(s), want %d time(s)", retryListener.timesCalled, wantRetryAttempts)
} }
} }
@ -99,33 +163,11 @@ func TestRetryListeners(t *testing.T) {
for _, retryListener := range retryListeners { for _, retryListener := range retryListeners {
listener := retryListener.(*countingRetryListener) listener := retryListener.(*countingRetryListener)
if listener.timesCalled != 2 { if listener.timesCalled != 2 {
t.Errorf("retry listener was called %d times, want %d", listener.timesCalled, 2) t.Errorf("retry listener was called %d time(s), want %d time(s)", listener.timesCalled, 2)
} }
} }
} }
// networkFailingHTTPHandler is an http.Handler implementation you can use to test retries.
type networkFailingHTTPHandler struct {
netErrorRecorder NetErrorRecorder
failAtCalls []int
callNumber int
}
func (handler *networkFailingHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
handler.callNumber++
for _, failAtCall := range handler.failAtCalls {
if handler.callNumber == failAtCall {
handler.netErrorRecorder.Record(r.Context())
w.WriteHeader(http.StatusBadGateway)
return
}
}
w.WriteHeader(http.StatusOK)
}
// countingRetryListener is a RetryListener implementation to count the times the Retried fn is called. // countingRetryListener is a RetryListener implementation to count the times the Retried fn is called.
type countingRetryListener struct { type countingRetryListener struct {
timesCalled int timesCalled int

View file

@ -1,40 +0,0 @@
package server
import (
"io"
"net"
"net/http"
"github.com/containous/traefik/middlewares"
)
// RecordingErrorHandler is an error handler, implementing the vulcand/oxy
// error handler interface, which is recording network errors by using the netErrorRecorder.
// In addition it sets a proper HTTP status code and body, depending on the type of error occurred.
type RecordingErrorHandler struct {
netErrorRecorder middlewares.NetErrorRecorder
}
// NewRecordingErrorHandler creates and returns a new instance of RecordingErrorHandler.
func NewRecordingErrorHandler(recorder middlewares.NetErrorRecorder) *RecordingErrorHandler {
return &RecordingErrorHandler{recorder}
}
func (eh *RecordingErrorHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) {
statusCode := http.StatusInternalServerError
if e, ok := err.(net.Error); ok {
eh.netErrorRecorder.Record(req.Context())
if e.Timeout() {
statusCode = http.StatusGatewayTimeout
} else {
statusCode = http.StatusBadGateway
}
} else if err == io.EOF {
eh.netErrorRecorder.Record(req.Context())
statusCode = http.StatusBadGateway
}
w.WriteHeader(statusCode)
w.Write([]byte(http.StatusText(statusCode)))
}

View file

@ -1,88 +0,0 @@
package server
import (
"context"
"errors"
"io"
"net"
"net/http"
"net/http/httptest"
"testing"
)
type timeoutError struct{}
func (e *timeoutError) Error() string { return "i/o timeout" }
func (e *timeoutError) Timeout() bool { return true }
func (e *timeoutError) Temporary() bool { return true }
func TestServeHTTP(t *testing.T) {
tests := []struct {
name string
err error
wantHTTPStatus int
wantNetErrRecorded bool
}{
{
name: "net.Error",
err: net.UnknownNetworkError("any network error"),
wantHTTPStatus: http.StatusBadGateway,
wantNetErrRecorded: true,
},
{
name: "net.Error with Timeout",
err: &timeoutError{},
wantHTTPStatus: http.StatusGatewayTimeout,
wantNetErrRecorded: true,
},
{
name: "io.EOF",
err: io.EOF,
wantHTTPStatus: http.StatusBadGateway,
wantNetErrRecorded: true,
},
{
name: "custom error",
err: errors.New("any error"),
wantHTTPStatus: http.StatusInternalServerError,
wantNetErrRecorded: false,
},
{
name: "nil error",
err: nil,
wantHTTPStatus: http.StatusInternalServerError,
wantNetErrRecorded: false,
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
recorder := httptest.NewRecorder()
errorRecorder := &netErrorRecorder{}
req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/any", nil)
recordingErrorHandler := NewRecordingErrorHandler(errorRecorder)
recordingErrorHandler.ServeHTTP(recorder, req, test.err)
if recorder.Code != test.wantHTTPStatus {
t.Errorf("got HTTP status code %v, wanted %v", recorder.Code, test.wantHTTPStatus)
}
if errorRecorder.netErrorWasRecorded != test.wantNetErrRecorded {
t.Errorf("net error recording wrong, got %v wanted %v", errorRecorder.netErrorWasRecorded, test.wantNetErrRecorded)
}
})
}
}
type netErrorRecorder struct {
netErrorWasRecorded bool
}
func (recorder *netErrorRecorder) Record(ctx context.Context) {
recorder.netErrorWasRecorded = true
}

View file

@ -23,7 +23,6 @@ import (
"github.com/eapache/channels" "github.com/eapache/channels"
"github.com/urfave/negroni" "github.com/urfave/negroni"
"github.com/vulcand/oxy/forward" "github.com/vulcand/oxy/forward"
"github.com/vulcand/oxy/utils"
) )
// loadConfiguration manages dynamically frontends, backends and TLS configurations // loadConfiguration manages dynamically frontends, backends and TLS configurations
@ -80,7 +79,6 @@ func (s *Server) loadConfig(configurations types.Configurations, globalConfigura
} }
serverEntryPoints := s.buildServerEntryPoints() serverEntryPoints := s.buildServerEntryPoints()
errorHandler := NewRecordingErrorHandler(middlewares.DefaultNetErrorRecorder{})
backendsHandlers := map[string]http.Handler{} backendsHandlers := map[string]http.Handler{}
backendsHealthCheck := map[string]*healthcheck.BackendConfig{} backendsHealthCheck := map[string]*healthcheck.BackendConfig{}
@ -92,7 +90,7 @@ func (s *Server) loadConfig(configurations types.Configurations, globalConfigura
for _, frontendName := range frontendNames { for _, frontendName := range frontendNames {
frontendPostConfigs, err := s.loadFrontendConfig(providerName, frontendName, config, frontendPostConfigs, err := s.loadFrontendConfig(providerName, frontendName, config,
redirectHandlers, serverEntryPoints, errorHandler, redirectHandlers, serverEntryPoints,
backendsHandlers, backendsHealthCheck) backendsHandlers, backendsHealthCheck)
if err != nil { if err != nil {
log.Errorf("%v. Skipping frontend %s...", err, frontendName) log.Errorf("%v. Skipping frontend %s...", err, frontendName)
@ -131,7 +129,7 @@ func (s *Server) loadConfig(configurations types.Configurations, globalConfigura
func (s *Server) loadFrontendConfig( func (s *Server) loadFrontendConfig(
providerName string, frontendName string, config *types.Configuration, providerName string, frontendName string, config *types.Configuration,
redirectHandlers map[string]negroni.Handler, serverEntryPoints map[string]*serverEntryPoint, errorHandler *RecordingErrorHandler, redirectHandlers map[string]negroni.Handler, serverEntryPoints map[string]*serverEntryPoint,
backendsHandlers map[string]http.Handler, backendsHealthCheck map[string]*healthcheck.BackendConfig, backendsHandlers map[string]http.Handler, backendsHealthCheck map[string]*healthcheck.BackendConfig,
) ([]handlerPostConfig, error) { ) ([]handlerPostConfig, error) {
@ -170,7 +168,7 @@ func (s *Server) loadFrontendConfig(
postConfigs = append(postConfigs, postConfig) postConfigs = append(postConfigs, postConfig)
} }
fwd, err := s.buildForwarder(entryPointName, entryPoint, frontendName, frontend, errorHandler, responseModifier) fwd, err := s.buildForwarder(entryPointName, entryPoint, frontendName, frontend, responseModifier)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create the forwarder for frontend %s: %v", frontendName, err) return nil, fmt.Errorf("failed to create the forwarder for frontend %s: %v", frontendName, err)
} }
@ -222,7 +220,7 @@ func (s *Server) loadFrontendConfig(
func (s *Server) buildForwarder(entryPointName string, entryPoint *configuration.EntryPoint, func (s *Server) buildForwarder(entryPointName string, entryPoint *configuration.EntryPoint,
frontendName string, frontend *types.Frontend, frontendName string, frontend *types.Frontend,
errorHandler utils.ErrorHandler, responseModifier modifyResponse) (http.Handler, error) { responseModifier modifyResponse) (http.Handler, error) {
roundTripper, err := s.getRoundTripper(entryPointName, frontend.PassTLSCert, entryPoint.TLS) roundTripper, err := s.getRoundTripper(entryPointName, frontend.PassTLSCert, entryPoint.TLS)
if err != nil { if err != nil {
@ -239,7 +237,6 @@ func (s *Server) buildForwarder(entryPointName string, entryPoint *configuration
forward.Stream(true), forward.Stream(true),
forward.PassHostHeader(frontend.PassHostHeader), forward.PassHostHeader(frontend.PassHostHeader),
forward.RoundTripper(roundTripper), forward.RoundTripper(roundTripper),
forward.ErrorHandler(errorHandler),
forward.Rewriter(rewriter), forward.Rewriter(rewriter),
forward.ResponseModifier(responseModifier), forward.ResponseModifier(responseModifier),
forward.BufferPool(s.bufferPool), forward.BufferPool(s.bufferPool),