Retry middleware : store headers per attempts and propagate them when responding.
This commit is contained in:
parent
d7bd69714d
commit
fc8c24e987
4 changed files with 118 additions and 13 deletions
|
@ -73,8 +73,7 @@ func (r *retry) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
|
||||||
attempts := 1
|
attempts := 1
|
||||||
for {
|
for {
|
||||||
attemptsExhausted := attempts >= r.attempts
|
shouldRetry := attempts < r.attempts
|
||||||
shouldRetry := !attemptsExhausted
|
|
||||||
retryResponseWriter := newResponseWriter(rw, shouldRetry)
|
retryResponseWriter := newResponseWriter(rw, shouldRetry)
|
||||||
|
|
||||||
// Disable retries when the backend already received request data
|
// Disable retries when the backend already received request data
|
||||||
|
@ -118,6 +117,7 @@ type responseWriter interface {
|
||||||
func newResponseWriter(rw http.ResponseWriter, shouldRetry bool) responseWriter {
|
func newResponseWriter(rw http.ResponseWriter, shouldRetry bool) responseWriter {
|
||||||
responseWriter := &responseWriterWithoutCloseNotify{
|
responseWriter := &responseWriterWithoutCloseNotify{
|
||||||
responseWriter: rw,
|
responseWriter: rw,
|
||||||
|
headers: make(http.Header),
|
||||||
shouldRetry: shouldRetry,
|
shouldRetry: shouldRetry,
|
||||||
}
|
}
|
||||||
if _, ok := rw.(http.CloseNotifier); ok {
|
if _, ok := rw.(http.CloseNotifier); ok {
|
||||||
|
@ -130,6 +130,7 @@ func newResponseWriter(rw http.ResponseWriter, shouldRetry bool) responseWriter
|
||||||
|
|
||||||
type responseWriterWithoutCloseNotify struct {
|
type responseWriterWithoutCloseNotify struct {
|
||||||
responseWriter http.ResponseWriter
|
responseWriter http.ResponseWriter
|
||||||
|
headers http.Header
|
||||||
shouldRetry bool
|
shouldRetry bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -142,10 +143,7 @@ func (r *responseWriterWithoutCloseNotify) DisableRetries() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *responseWriterWithoutCloseNotify) Header() http.Header {
|
func (r *responseWriterWithoutCloseNotify) Header() http.Header {
|
||||||
if r.ShouldRetry() {
|
return r.headers
|
||||||
return make(http.Header)
|
|
||||||
}
|
|
||||||
return r.responseWriter.Header()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *responseWriterWithoutCloseNotify) Write(buf []byte) (int, error) {
|
func (r *responseWriterWithoutCloseNotify) Write(buf []byte) (int, error) {
|
||||||
|
@ -168,6 +166,16 @@ func (r *responseWriterWithoutCloseNotify) WriteHeader(code int) {
|
||||||
if r.ShouldRetry() {
|
if r.ShouldRetry() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// In that case retry case is set to false which means we at least managed
|
||||||
|
// to write headers to the backend : we are not going to perform any further retry.
|
||||||
|
// So it is now safe to alter current response headers with headers collected during
|
||||||
|
// the latest try before writing headers to client.
|
||||||
|
headers := r.responseWriter.Header()
|
||||||
|
for header, value := range r.headers {
|
||||||
|
headers[header] = value
|
||||||
|
}
|
||||||
|
|
||||||
r.responseWriter.WriteHeader(code)
|
r.responseWriter.WriteHeader(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,8 +2,10 @@ package retry
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/http/httptrace"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -149,6 +151,50 @@ func TestRetryListeners(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMultipleRetriesShouldNotLooseHeaders(t *testing.T) {
|
||||||
|
attempt := 0
|
||||||
|
expectedHeaderName := "X-Foo-Test-2"
|
||||||
|
expectedHeaderValue := "bar"
|
||||||
|
|
||||||
|
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
headerName := fmt.Sprintf("X-Foo-Test-%d", attempt)
|
||||||
|
rw.Header().Add(headerName, expectedHeaderValue)
|
||||||
|
if attempt < 2 {
|
||||||
|
attempt++
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request has been successfully written to backend
|
||||||
|
trace := httptrace.ContextClientTrace(req.Context())
|
||||||
|
trace.WroteHeaders()
|
||||||
|
|
||||||
|
// And we decide to answer to client
|
||||||
|
rw.WriteHeader(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
|
||||||
|
retry, err := New(context.Background(), next, config.Retry{Attempts: 3}, &countingRetryListener{}, "traefikTest")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
responseRecorder := httptest.NewRecorder()
|
||||||
|
retry.ServeHTTP(responseRecorder, testhelpers.MustNewRequest(http.MethodGet, "http://test", http.NoBody))
|
||||||
|
|
||||||
|
headerValue := responseRecorder.Header().Get(expectedHeaderName)
|
||||||
|
|
||||||
|
// Validate if we have the correct header
|
||||||
|
if headerValue != expectedHeaderValue {
|
||||||
|
t.Errorf("Expected to have %s for header %s, got %s", expectedHeaderValue, expectedHeaderName, headerValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that we don't have headers from previous attempts
|
||||||
|
for i := 0; i < attempt; i++ {
|
||||||
|
headerName := fmt.Sprintf("X-Foo-Test-%d", i)
|
||||||
|
headerValue = responseRecorder.Header().Get("headerName")
|
||||||
|
if headerValue != "" {
|
||||||
|
t.Errorf("Expected no value for header %s, got %s", headerName, headerValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// countingRetryListener is a Listener implementation to count the times the Retried fn is called.
|
// countingRetryListener is a Listener implementation to count the times the Retried fn is called.
|
||||||
type countingRetryListener struct {
|
type countingRetryListener struct {
|
||||||
timesCalled int
|
timesCalled int
|
||||||
|
|
|
@ -44,9 +44,7 @@ func (retry *Retry) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
attempts := 1
|
attempts := 1
|
||||||
for {
|
for {
|
||||||
attemptsExhausted := attempts >= retry.attempts
|
shouldRetry := attempts < retry.attempts
|
||||||
|
|
||||||
shouldRetry := !attemptsExhausted
|
|
||||||
retryResponseWriter := newRetryResponseWriter(rw, shouldRetry)
|
retryResponseWriter := newRetryResponseWriter(rw, shouldRetry)
|
||||||
|
|
||||||
// Disable retries when the backend already received request data
|
// Disable retries when the backend already received request data
|
||||||
|
@ -99,6 +97,7 @@ type retryResponseWriter interface {
|
||||||
func newRetryResponseWriter(rw http.ResponseWriter, shouldRetry bool) retryResponseWriter {
|
func newRetryResponseWriter(rw http.ResponseWriter, shouldRetry bool) retryResponseWriter {
|
||||||
responseWriter := &retryResponseWriterWithoutCloseNotify{
|
responseWriter := &retryResponseWriterWithoutCloseNotify{
|
||||||
responseWriter: rw,
|
responseWriter: rw,
|
||||||
|
headers: make(http.Header),
|
||||||
shouldRetry: shouldRetry,
|
shouldRetry: shouldRetry,
|
||||||
}
|
}
|
||||||
if _, ok := rw.(http.CloseNotifier); ok {
|
if _, ok := rw.(http.CloseNotifier); ok {
|
||||||
|
@ -109,6 +108,7 @@ func newRetryResponseWriter(rw http.ResponseWriter, shouldRetry bool) retryRespo
|
||||||
|
|
||||||
type retryResponseWriterWithoutCloseNotify struct {
|
type retryResponseWriterWithoutCloseNotify struct {
|
||||||
responseWriter http.ResponseWriter
|
responseWriter http.ResponseWriter
|
||||||
|
headers http.Header
|
||||||
shouldRetry bool
|
shouldRetry bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -121,10 +121,7 @@ func (rr *retryResponseWriterWithoutCloseNotify) DisableRetries() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rr *retryResponseWriterWithoutCloseNotify) Header() http.Header {
|
func (rr *retryResponseWriterWithoutCloseNotify) Header() http.Header {
|
||||||
if rr.ShouldRetry() {
|
return rr.headers
|
||||||
return make(http.Header)
|
|
||||||
}
|
|
||||||
return rr.responseWriter.Header()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rr *retryResponseWriterWithoutCloseNotify) Write(buf []byte) (int, error) {
|
func (rr *retryResponseWriterWithoutCloseNotify) Write(buf []byte) (int, error) {
|
||||||
|
@ -147,6 +144,16 @@ func (rr *retryResponseWriterWithoutCloseNotify) WriteHeader(code int) {
|
||||||
if rr.ShouldRetry() {
|
if rr.ShouldRetry() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// In that case retry case is set to false which means we at least managed
|
||||||
|
// to write headers to the backend : we are not going to perform any further retry.
|
||||||
|
// So it is now safe to alter current response headers with headers collected during
|
||||||
|
// the latest try before writing headers to client.
|
||||||
|
headers := rr.responseWriter.Header()
|
||||||
|
for header, value := range rr.headers {
|
||||||
|
headers[header] = value
|
||||||
|
}
|
||||||
|
|
||||||
rr.responseWriter.WriteHeader(code)
|
rr.responseWriter.WriteHeader(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
package middlewares
|
package middlewares
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/http/httptrace"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -258,3 +260,45 @@ func TestRetryWithFlush(t *testing.T) {
|
||||||
t.Errorf("Wrong body %q want %q", responseRecorder.Body.String(), "FULL DATA")
|
t.Errorf("Wrong body %q want %q", responseRecorder.Body.String(), "FULL DATA")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMultipleRetriesShouldNotLooseHeaders(t *testing.T) {
|
||||||
|
attempt := 0
|
||||||
|
expectedHeaderName := "X-Foo-Test-2"
|
||||||
|
expectedHeaderValue := "bar"
|
||||||
|
|
||||||
|
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
headerName := fmt.Sprintf("X-Foo-Test-%d", attempt)
|
||||||
|
rw.Header().Add(headerName, expectedHeaderValue)
|
||||||
|
if attempt < 2 {
|
||||||
|
attempt++
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request has been successfully written to backend
|
||||||
|
trace := httptrace.ContextClientTrace(req.Context())
|
||||||
|
trace.WroteHeaders()
|
||||||
|
|
||||||
|
// And we decide to answer to client
|
||||||
|
rw.WriteHeader(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
|
||||||
|
retry := NewRetry(3, next, &countingRetryListener{})
|
||||||
|
responseRecorder := httptest.NewRecorder()
|
||||||
|
retry.ServeHTTP(responseRecorder, &http.Request{})
|
||||||
|
|
||||||
|
headerValue := responseRecorder.Header().Get(expectedHeaderName)
|
||||||
|
|
||||||
|
// Validate if we have the correct header
|
||||||
|
if headerValue != expectedHeaderValue {
|
||||||
|
t.Errorf("Expected to have %s for header %s, got %s", expectedHeaderValue, expectedHeaderName, headerValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that we don't have headers from previous attempts
|
||||||
|
for i := 0; i < attempt; i++ {
|
||||||
|
headerName := fmt.Sprintf("X-Foo-Test-%d", i)
|
||||||
|
headerValue = responseRecorder.Header().Get("headerName")
|
||||||
|
if headerValue != "" {
|
||||||
|
t.Errorf("Expected no value for header %s, got %s", headerName, headerValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue