From fc8c24e9873be595a52ca5b8d2e5eb8035cbc40c Mon Sep 17 00:00:00 2001 From: Julien Levesy Date: Mon, 7 Jan 2019 23:18:03 +0100 Subject: [PATCH] Retry middleware : store headers per attempts and propagate them when responding. --- middlewares/retry/retry.go | 20 +++++++++----- middlewares/retry/retry_test.go | 46 +++++++++++++++++++++++++++++++++ old/middlewares/retry.go | 21 ++++++++++----- old/middlewares/retry_test.go | 44 +++++++++++++++++++++++++++++++ 4 files changed, 118 insertions(+), 13 deletions(-) diff --git a/middlewares/retry/retry.go b/middlewares/retry/retry.go index dd23b185b..11cc337bf 100644 --- a/middlewares/retry/retry.go +++ b/middlewares/retry/retry.go @@ -73,8 +73,7 @@ func (r *retry) ServeHTTP(rw http.ResponseWriter, req *http.Request) { attempts := 1 for { - attemptsExhausted := attempts >= r.attempts - shouldRetry := !attemptsExhausted + shouldRetry := attempts < r.attempts retryResponseWriter := newResponseWriter(rw, shouldRetry) // Disable retries when the backend already received request data @@ -118,6 +117,7 @@ type responseWriter interface { func newResponseWriter(rw http.ResponseWriter, shouldRetry bool) responseWriter { responseWriter := &responseWriterWithoutCloseNotify{ responseWriter: rw, + headers: make(http.Header), shouldRetry: shouldRetry, } if _, ok := rw.(http.CloseNotifier); ok { @@ -130,6 +130,7 @@ func newResponseWriter(rw http.ResponseWriter, shouldRetry bool) responseWriter type responseWriterWithoutCloseNotify struct { responseWriter http.ResponseWriter + headers http.Header shouldRetry bool } @@ -142,10 +143,7 @@ func (r *responseWriterWithoutCloseNotify) DisableRetries() { } func (r *responseWriterWithoutCloseNotify) Header() http.Header { - if r.ShouldRetry() { - return make(http.Header) - } - return r.responseWriter.Header() + return r.headers } func (r *responseWriterWithoutCloseNotify) Write(buf []byte) (int, error) { @@ -168,6 +166,16 @@ func (r *responseWriterWithoutCloseNotify) WriteHeader(code int) { if r.ShouldRetry() { return } + + // In that case retry case is set to false which means we at least managed + // to write headers to the backend : we are not going to perform any further retry. + // So it is now safe to alter current response headers with headers collected during + // the latest try before writing headers to client. + headers := r.responseWriter.Header() + for header, value := range r.headers { + headers[header] = value + } + r.responseWriter.WriteHeader(code) } diff --git a/middlewares/retry/retry_test.go b/middlewares/retry/retry_test.go index 7efba979e..8c77228f4 100644 --- a/middlewares/retry/retry_test.go +++ b/middlewares/retry/retry_test.go @@ -2,8 +2,10 @@ package retry import ( "context" + "fmt" "net/http" "net/http/httptest" + "net/http/httptrace" "strings" "testing" @@ -149,6 +151,50 @@ func TestRetryListeners(t *testing.T) { } } +func TestMultipleRetriesShouldNotLooseHeaders(t *testing.T) { + attempt := 0 + expectedHeaderName := "X-Foo-Test-2" + expectedHeaderValue := "bar" + + next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + headerName := fmt.Sprintf("X-Foo-Test-%d", attempt) + rw.Header().Add(headerName, expectedHeaderValue) + if attempt < 2 { + attempt++ + return + } + + // Request has been successfully written to backend + trace := httptrace.ContextClientTrace(req.Context()) + trace.WroteHeaders() + + // And we decide to answer to client + rw.WriteHeader(http.StatusNoContent) + }) + + retry, err := New(context.Background(), next, config.Retry{Attempts: 3}, &countingRetryListener{}, "traefikTest") + require.NoError(t, err) + + responseRecorder := httptest.NewRecorder() + retry.ServeHTTP(responseRecorder, testhelpers.MustNewRequest(http.MethodGet, "http://test", http.NoBody)) + + headerValue := responseRecorder.Header().Get(expectedHeaderName) + + // Validate if we have the correct header + if headerValue != expectedHeaderValue { + t.Errorf("Expected to have %s for header %s, got %s", expectedHeaderValue, expectedHeaderName, headerValue) + } + + // Validate that we don't have headers from previous attempts + for i := 0; i < attempt; i++ { + headerName := fmt.Sprintf("X-Foo-Test-%d", i) + headerValue = responseRecorder.Header().Get("headerName") + if headerValue != "" { + t.Errorf("Expected no value for header %s, got %s", headerName, headerValue) + } + } +} + // countingRetryListener is a Listener implementation to count the times the Retried fn is called. type countingRetryListener struct { timesCalled int diff --git a/old/middlewares/retry.go b/old/middlewares/retry.go index b5d3f29ab..b6407ac0e 100644 --- a/old/middlewares/retry.go +++ b/old/middlewares/retry.go @@ -44,9 +44,7 @@ func (retry *Retry) ServeHTTP(rw http.ResponseWriter, r *http.Request) { attempts := 1 for { - attemptsExhausted := attempts >= retry.attempts - - shouldRetry := !attemptsExhausted + shouldRetry := attempts < retry.attempts retryResponseWriter := newRetryResponseWriter(rw, shouldRetry) // Disable retries when the backend already received request data @@ -99,6 +97,7 @@ type retryResponseWriter interface { func newRetryResponseWriter(rw http.ResponseWriter, shouldRetry bool) retryResponseWriter { responseWriter := &retryResponseWriterWithoutCloseNotify{ responseWriter: rw, + headers: make(http.Header), shouldRetry: shouldRetry, } if _, ok := rw.(http.CloseNotifier); ok { @@ -109,6 +108,7 @@ func newRetryResponseWriter(rw http.ResponseWriter, shouldRetry bool) retryRespo type retryResponseWriterWithoutCloseNotify struct { responseWriter http.ResponseWriter + headers http.Header shouldRetry bool } @@ -121,10 +121,7 @@ func (rr *retryResponseWriterWithoutCloseNotify) DisableRetries() { } func (rr *retryResponseWriterWithoutCloseNotify) Header() http.Header { - if rr.ShouldRetry() { - return make(http.Header) - } - return rr.responseWriter.Header() + return rr.headers } func (rr *retryResponseWriterWithoutCloseNotify) Write(buf []byte) (int, error) { @@ -147,6 +144,16 @@ func (rr *retryResponseWriterWithoutCloseNotify) WriteHeader(code int) { if rr.ShouldRetry() { return } + + // In that case retry case is set to false which means we at least managed + // to write headers to the backend : we are not going to perform any further retry. + // So it is now safe to alter current response headers with headers collected during + // the latest try before writing headers to client. + headers := rr.responseWriter.Header() + for header, value := range rr.headers { + headers[header] = value + } + rr.responseWriter.WriteHeader(code) } diff --git a/old/middlewares/retry_test.go b/old/middlewares/retry_test.go index 792f3cdf7..e874d6222 100644 --- a/old/middlewares/retry_test.go +++ b/old/middlewares/retry_test.go @@ -1,8 +1,10 @@ package middlewares import ( + "fmt" "net/http" "net/http/httptest" + "net/http/httptrace" "strings" "testing" @@ -258,3 +260,45 @@ func TestRetryWithFlush(t *testing.T) { t.Errorf("Wrong body %q want %q", responseRecorder.Body.String(), "FULL DATA") } } + +func TestMultipleRetriesShouldNotLooseHeaders(t *testing.T) { + attempt := 0 + expectedHeaderName := "X-Foo-Test-2" + expectedHeaderValue := "bar" + + next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + headerName := fmt.Sprintf("X-Foo-Test-%d", attempt) + rw.Header().Add(headerName, expectedHeaderValue) + if attempt < 2 { + attempt++ + return + } + + // Request has been successfully written to backend + trace := httptrace.ContextClientTrace(req.Context()) + trace.WroteHeaders() + + // And we decide to answer to client + rw.WriteHeader(http.StatusNoContent) + }) + + retry := NewRetry(3, next, &countingRetryListener{}) + responseRecorder := httptest.NewRecorder() + retry.ServeHTTP(responseRecorder, &http.Request{}) + + headerValue := responseRecorder.Header().Get(expectedHeaderName) + + // Validate if we have the correct header + if headerValue != expectedHeaderValue { + t.Errorf("Expected to have %s for header %s, got %s", expectedHeaderValue, expectedHeaderName, headerValue) + } + + // Validate that we don't have headers from previous attempts + for i := 0; i < attempt; i++ { + headerName := fmt.Sprintf("X-Foo-Test-%d", i) + headerValue = responseRecorder.Header().Get("headerName") + if headerValue != "" { + t.Errorf("Expected no value for header %s, got %s", headerName, headerValue) + } + } +}