avoid retries when any data was written to the backend
This commit is contained in:
parent
586ba31120
commit
e31c85aace
5 changed files with 161 additions and 250 deletions
|
@ -2,10 +2,10 @@ package middlewares
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"context"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httptrace"
|
||||||
|
|
||||||
"github.com/containous/traefik/log"
|
"github.com/containous/traefik/log"
|
||||||
)
|
)
|
||||||
|
@ -40,11 +40,24 @@ func (retry *Retry) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
attempts := 1
|
attempts := 1
|
||||||
for {
|
for {
|
||||||
netErrorOccurred := false
|
attemptsExhausted := attempts >= retry.attempts
|
||||||
// We pass in a pointer to netErrorOccurred so that we can set it to true on network errors
|
// Websocket requests can't be retried at this point in time.
|
||||||
// when proxying the HTTP requests to the backends. This happens in the custom RecordingErrorHandler.
|
// This is due to the fact that gorilla/websocket doesn't use the request
|
||||||
newCtx := context.WithValue(r.Context(), defaultNetErrCtxKey, &netErrorOccurred)
|
// context and so we don't get httptrace information.
|
||||||
retryResponseWriter := newRetryResponseWriter(rw, attempts >= retry.attempts, &netErrorOccurred)
|
// Websocket clients should however retry on their own anyway.
|
||||||
|
shouldRetry := !attemptsExhausted && !isWebsocketRequest(r)
|
||||||
|
retryResponseWriter := newRetryResponseWriter(rw, shouldRetry)
|
||||||
|
|
||||||
|
// Disable retries when the backend already received request data
|
||||||
|
trace := &httptrace.ClientTrace{
|
||||||
|
WroteHeaders: func() {
|
||||||
|
retryResponseWriter.DisableRetries()
|
||||||
|
},
|
||||||
|
WroteRequest: func(httptrace.WroteRequestInfo) {
|
||||||
|
retryResponseWriter.DisableRetries()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
newCtx := httptrace.WithClientTrace(r.Context(), trace)
|
||||||
|
|
||||||
retry.next.ServeHTTP(retryResponseWriter, r.WithContext(newCtx))
|
retry.next.ServeHTTP(retryResponseWriter, r.WithContext(newCtx))
|
||||||
if !retryResponseWriter.ShouldRetry() {
|
if !retryResponseWriter.ShouldRetry() {
|
||||||
|
@ -57,31 +70,6 @@ func (retry *Retry) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// netErrorCtxKey is a custom type that is used as key for the context.
|
|
||||||
type netErrorCtxKey string
|
|
||||||
|
|
||||||
// defaultNetErrCtxKey is the actual key which value is used to record network errors.
|
|
||||||
var defaultNetErrCtxKey netErrorCtxKey = "NetErrCtxKey"
|
|
||||||
|
|
||||||
// NetErrorRecorder is an interface to record net errors.
|
|
||||||
type NetErrorRecorder interface {
|
|
||||||
// Record can be used to signal the retry middleware that an network error happened
|
|
||||||
// and therefore the request should be retried.
|
|
||||||
Record(ctx context.Context)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DefaultNetErrorRecorder is the default NetErrorRecorder implementation.
|
|
||||||
type DefaultNetErrorRecorder struct{}
|
|
||||||
|
|
||||||
// Record is recording network errors by setting the context value for the defaultNetErrCtxKey to true.
|
|
||||||
func (DefaultNetErrorRecorder) Record(ctx context.Context) {
|
|
||||||
val := ctx.Value(defaultNetErrCtxKey)
|
|
||||||
|
|
||||||
if netErrorOccurred, isBoolPointer := val.(*bool); isBoolPointer {
|
|
||||||
*netErrorOccurred = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RetryListener is used to inform about retry attempts.
|
// RetryListener is used to inform about retry attempts.
|
||||||
type RetryListener interface {
|
type RetryListener interface {
|
||||||
// Retried will be called when a retry happens, with the request attempt passed to it.
|
// Retried will be called when a retry happens, with the request attempt passed to it.
|
||||||
|
@ -104,13 +92,13 @@ type retryResponseWriter interface {
|
||||||
http.ResponseWriter
|
http.ResponseWriter
|
||||||
http.Flusher
|
http.Flusher
|
||||||
ShouldRetry() bool
|
ShouldRetry() bool
|
||||||
|
DisableRetries()
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRetryResponseWriter(rw http.ResponseWriter, attemptsExhausted bool, netErrorOccured *bool) retryResponseWriter {
|
func newRetryResponseWriter(rw http.ResponseWriter, shouldRetry bool) retryResponseWriter {
|
||||||
responseWriter := &retryResponseWriterWithoutCloseNotify{
|
responseWriter := &retryResponseWriterWithoutCloseNotify{
|
||||||
responseWriter: rw,
|
responseWriter: rw,
|
||||||
attemptsExhausted: attemptsExhausted,
|
shouldRetry: shouldRetry,
|
||||||
netErrorOccured: netErrorOccured,
|
|
||||||
}
|
}
|
||||||
if _, ok := rw.(http.CloseNotifier); ok {
|
if _, ok := rw.(http.CloseNotifier); ok {
|
||||||
return &retryResponseWriterWithCloseNotify{responseWriter}
|
return &retryResponseWriterWithCloseNotify{responseWriter}
|
||||||
|
@ -120,12 +108,15 @@ func newRetryResponseWriter(rw http.ResponseWriter, attemptsExhausted bool, netE
|
||||||
|
|
||||||
type retryResponseWriterWithoutCloseNotify struct {
|
type retryResponseWriterWithoutCloseNotify struct {
|
||||||
responseWriter http.ResponseWriter
|
responseWriter http.ResponseWriter
|
||||||
attemptsExhausted bool
|
shouldRetry bool
|
||||||
netErrorOccured *bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rr *retryResponseWriterWithoutCloseNotify) ShouldRetry() bool {
|
func (rr *retryResponseWriterWithoutCloseNotify) ShouldRetry() bool {
|
||||||
return *rr.netErrorOccured && !rr.attemptsExhausted
|
return rr.shouldRetry
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rr *retryResponseWriterWithoutCloseNotify) DisableRetries() {
|
||||||
|
rr.shouldRetry = false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rr *retryResponseWriterWithoutCloseNotify) Header() http.Header {
|
func (rr *retryResponseWriterWithoutCloseNotify) Header() http.Header {
|
||||||
|
@ -143,6 +134,15 @@ func (rr *retryResponseWriterWithoutCloseNotify) Write(buf []byte) (int, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rr *retryResponseWriterWithoutCloseNotify) WriteHeader(code int) {
|
func (rr *retryResponseWriterWithoutCloseNotify) WriteHeader(code int) {
|
||||||
|
if rr.ShouldRetry() && code == http.StatusServiceUnavailable {
|
||||||
|
// We get a 503 HTTP Status Code when there is no backend server in the pool
|
||||||
|
// to which the request could be sent. Also, note that rr.ShouldRetry()
|
||||||
|
// will never return true in case there was a connetion established to
|
||||||
|
// the backend server and so we can be sure that the 503 was produced
|
||||||
|
// inside Traefik already and we don't have to retry in this cases.
|
||||||
|
rr.DisableRetries()
|
||||||
|
}
|
||||||
|
|
||||||
if rr.ShouldRetry() {
|
if rr.ShouldRetry() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,91 +1,155 @@
|
||||||
package middlewares
|
package middlewares
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/containous/traefik/testhelpers"
|
||||||
|
"github.com/vulcand/oxy/forward"
|
||||||
|
"github.com/vulcand/oxy/roundrobin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRetry(t *testing.T) {
|
func TestRetry(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
failAtCalls []int
|
desc string
|
||||||
attempts int
|
maxRequestAttempts int
|
||||||
responseStatus int
|
wantRetryAttempts int
|
||||||
listener *countingRetryListener
|
wantResponseStatus int
|
||||||
retriedCount int
|
amountFaultyEndpoints int
|
||||||
|
isWebsocketHandshakeRequest bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
failAtCalls: []int{1, 2},
|
desc: "no retry on success",
|
||||||
attempts: 3,
|
maxRequestAttempts: 1,
|
||||||
responseStatus: http.StatusOK,
|
wantRetryAttempts: 0,
|
||||||
listener: &countingRetryListener{},
|
wantResponseStatus: http.StatusOK,
|
||||||
retriedCount: 2,
|
amountFaultyEndpoints: 0,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
failAtCalls: []int{1, 2},
|
desc: "no retry when max request attempts is one",
|
||||||
attempts: 2,
|
maxRequestAttempts: 1,
|
||||||
responseStatus: http.StatusBadGateway,
|
wantRetryAttempts: 0,
|
||||||
listener: &countingRetryListener{},
|
wantResponseStatus: http.StatusInternalServerError,
|
||||||
retriedCount: 1,
|
amountFaultyEndpoints: 1,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
desc: "one retry when one server is faulty",
|
||||||
|
maxRequestAttempts: 2,
|
||||||
|
wantRetryAttempts: 1,
|
||||||
|
wantResponseStatus: http.StatusOK,
|
||||||
|
amountFaultyEndpoints: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "two retries when two servers are faulty",
|
||||||
|
maxRequestAttempts: 3,
|
||||||
|
wantRetryAttempts: 2,
|
||||||
|
wantResponseStatus: http.StatusOK,
|
||||||
|
amountFaultyEndpoints: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "max attempts exhausted delivers the 5xx response",
|
||||||
|
maxRequestAttempts: 3,
|
||||||
|
wantRetryAttempts: 2,
|
||||||
|
wantResponseStatus: http.StatusInternalServerError,
|
||||||
|
amountFaultyEndpoints: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "websocket request should not be retried",
|
||||||
|
maxRequestAttempts: 3,
|
||||||
|
wantRetryAttempts: 0,
|
||||||
|
wantResponseStatus: http.StatusBadGateway,
|
||||||
|
amountFaultyEndpoints: 1,
|
||||||
|
isWebsocketHandshakeRequest: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
backendServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
rw.WriteHeader(http.StatusOK)
|
||||||
|
rw.Write([]byte("OK"))
|
||||||
|
}))
|
||||||
|
|
||||||
|
forwarder, err := forward.New()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error creating forwarder: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
// bind tc locally
|
|
||||||
tc := tc
|
tc := tc
|
||||||
tcName := fmt.Sprintf("FailAtCalls(%v) RetryAttempts(%v)", tc.failAtCalls, tc.attempts)
|
|
||||||
|
|
||||||
t.Run(tcName, func(t *testing.T) {
|
t.Run(tc.desc, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
var httpHandler http.Handler = &networkFailingHTTPHandler{failAtCalls: tc.failAtCalls, netErrorRecorder: &DefaultNetErrorRecorder{}}
|
loadBalancer, err := roundrobin.New(forwarder)
|
||||||
httpHandler = NewRetry(tc.attempts, httpHandler, tc.listener)
|
if err != nil {
|
||||||
|
t.Fatalf("Error creating load balancer: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
basePort := 33444
|
||||||
|
for i := 0; i < tc.amountFaultyEndpoints; i++ {
|
||||||
|
// 192.0.2.0 is a non-routable IP for testing purposes.
|
||||||
|
// See: https://stackoverflow.com/questions/528538/non-routable-ip-address/18436928#18436928
|
||||||
|
// We only use the port specification here because the URL is used as identifier
|
||||||
|
// in the load balancer and using the exact same URL would not add a new server.
|
||||||
|
loadBalancer.UpsertServer(testhelpers.MustParseURL("http://192.0.2.0:" + string(basePort+i)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// add the functioning server to the end of the load balancer list
|
||||||
|
loadBalancer.UpsertServer(testhelpers.MustParseURL(backendServer.URL))
|
||||||
|
|
||||||
|
retryListener := &countingRetryListener{}
|
||||||
|
retry := NewRetry(tc.maxRequestAttempts, loadBalancer, retryListener)
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
req, err := http.NewRequest(http.MethodGet, "http://localhost:3000/ok", ioutil.NopCloser(nil))
|
req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/ok", nil)
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("could not create request: %+v", err)
|
if tc.isWebsocketHandshakeRequest {
|
||||||
|
req.Header.Add("Connection", "Upgrade")
|
||||||
|
req.Header.Add("Upgrade", "websocket")
|
||||||
}
|
}
|
||||||
|
|
||||||
httpHandler.ServeHTTP(recorder, req)
|
retry.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
if tc.responseStatus != recorder.Code {
|
if tc.wantResponseStatus != recorder.Code {
|
||||||
t.Errorf("wrong status code %d, want %d", recorder.Code, tc.responseStatus)
|
t.Errorf("got status code %d, want %d", recorder.Code, tc.wantResponseStatus)
|
||||||
}
|
}
|
||||||
if tc.retriedCount != tc.listener.timesCalled {
|
if tc.wantRetryAttempts != retryListener.timesCalled {
|
||||||
t.Errorf("RetryListener called %d times, want %d times", tc.listener.timesCalled, tc.retriedCount)
|
t.Errorf("retry listener called %d time(s), want %d time(s)", retryListener.timesCalled, tc.wantRetryAttempts)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDefaultNetErrorRecorderSuccess(t *testing.T) {
|
func TestRetryEmptyServerList(t *testing.T) {
|
||||||
boolNetErrorOccurred := false
|
forwarder, err := forward.New()
|
||||||
recorder := DefaultNetErrorRecorder{}
|
if err != nil {
|
||||||
recorder.Record(context.WithValue(context.Background(), defaultNetErrCtxKey, &boolNetErrorOccurred))
|
t.Fatalf("Error creating forwarder: %s", err)
|
||||||
if !boolNetErrorOccurred {
|
|
||||||
t.Errorf("got %v after recording net error, wanted %v", boolNetErrorOccurred, true)
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultNetErrorRecorderInvalidValueType(t *testing.T) {
|
loadBalancer, err := roundrobin.New(forwarder)
|
||||||
stringNetErrorOccured := "nonsense"
|
if err != nil {
|
||||||
recorder := DefaultNetErrorRecorder{}
|
t.Fatalf("Error creating load balancer: %s", err)
|
||||||
recorder.Record(context.WithValue(context.Background(), defaultNetErrCtxKey, &stringNetErrorOccured))
|
|
||||||
if stringNetErrorOccured != "nonsense" {
|
|
||||||
t.Errorf("got %v after recording net error, wanted %v", stringNetErrorOccured, "nonsense")
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultNetErrorRecorderNilValue(t *testing.T) {
|
// The EmptyBackendHandler middleware ensures that there is a 503
|
||||||
nilNetErrorOccured := interface{}(nil)
|
// response status set when there is no backend server in the pool.
|
||||||
recorder := DefaultNetErrorRecorder{}
|
next := NewEmptyBackendHandler(loadBalancer)
|
||||||
recorder.Record(context.WithValue(context.Background(), defaultNetErrCtxKey, &nilNetErrorOccured))
|
|
||||||
if nilNetErrorOccured != interface{}(nil) {
|
retryListener := &countingRetryListener{}
|
||||||
t.Errorf("got %v after recording net error, wanted %v", nilNetErrorOccured, interface{}(nil))
|
retry := NewRetry(3, next, retryListener)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/ok", nil)
|
||||||
|
|
||||||
|
retry.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
const wantResponseStatus = http.StatusServiceUnavailable
|
||||||
|
if wantResponseStatus != recorder.Code {
|
||||||
|
t.Errorf("got status code %d, want %d", recorder.Code, wantResponseStatus)
|
||||||
|
}
|
||||||
|
const wantRetryAttempts = 0
|
||||||
|
if wantRetryAttempts != retryListener.timesCalled {
|
||||||
|
t.Errorf("retry listener called %d time(s), want %d time(s)", retryListener.timesCalled, wantRetryAttempts)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -99,33 +163,11 @@ func TestRetryListeners(t *testing.T) {
|
||||||
for _, retryListener := range retryListeners {
|
for _, retryListener := range retryListeners {
|
||||||
listener := retryListener.(*countingRetryListener)
|
listener := retryListener.(*countingRetryListener)
|
||||||
if listener.timesCalled != 2 {
|
if listener.timesCalled != 2 {
|
||||||
t.Errorf("retry listener was called %d times, want %d", listener.timesCalled, 2)
|
t.Errorf("retry listener was called %d time(s), want %d time(s)", listener.timesCalled, 2)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// networkFailingHTTPHandler is an http.Handler implementation you can use to test retries.
|
|
||||||
type networkFailingHTTPHandler struct {
|
|
||||||
netErrorRecorder NetErrorRecorder
|
|
||||||
failAtCalls []int
|
|
||||||
callNumber int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (handler *networkFailingHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
||||||
handler.callNumber++
|
|
||||||
|
|
||||||
for _, failAtCall := range handler.failAtCalls {
|
|
||||||
if handler.callNumber == failAtCall {
|
|
||||||
handler.netErrorRecorder.Record(r.Context())
|
|
||||||
|
|
||||||
w.WriteHeader(http.StatusBadGateway)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}
|
|
||||||
|
|
||||||
// countingRetryListener is a RetryListener implementation to count the times the Retried fn is called.
|
// countingRetryListener is a RetryListener implementation to count the times the Retried fn is called.
|
||||||
type countingRetryListener struct {
|
type countingRetryListener struct {
|
||||||
timesCalled int
|
timesCalled int
|
||||||
|
|
|
@ -1,40 +0,0 @@
|
||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/containous/traefik/middlewares"
|
|
||||||
)
|
|
||||||
|
|
||||||
// RecordingErrorHandler is an error handler, implementing the vulcand/oxy
|
|
||||||
// error handler interface, which is recording network errors by using the netErrorRecorder.
|
|
||||||
// In addition it sets a proper HTTP status code and body, depending on the type of error occurred.
|
|
||||||
type RecordingErrorHandler struct {
|
|
||||||
netErrorRecorder middlewares.NetErrorRecorder
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewRecordingErrorHandler creates and returns a new instance of RecordingErrorHandler.
|
|
||||||
func NewRecordingErrorHandler(recorder middlewares.NetErrorRecorder) *RecordingErrorHandler {
|
|
||||||
return &RecordingErrorHandler{recorder}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (eh *RecordingErrorHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) {
|
|
||||||
statusCode := http.StatusInternalServerError
|
|
||||||
|
|
||||||
if e, ok := err.(net.Error); ok {
|
|
||||||
eh.netErrorRecorder.Record(req.Context())
|
|
||||||
if e.Timeout() {
|
|
||||||
statusCode = http.StatusGatewayTimeout
|
|
||||||
} else {
|
|
||||||
statusCode = http.StatusBadGateway
|
|
||||||
}
|
|
||||||
} else if err == io.EOF {
|
|
||||||
eh.netErrorRecorder.Record(req.Context())
|
|
||||||
statusCode = http.StatusBadGateway
|
|
||||||
}
|
|
||||||
|
|
||||||
w.WriteHeader(statusCode)
|
|
||||||
w.Write([]byte(http.StatusText(statusCode)))
|
|
||||||
}
|
|
|
@ -1,88 +0,0 @@
|
||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
type timeoutError struct{}
|
|
||||||
|
|
||||||
func (e *timeoutError) Error() string { return "i/o timeout" }
|
|
||||||
func (e *timeoutError) Timeout() bool { return true }
|
|
||||||
func (e *timeoutError) Temporary() bool { return true }
|
|
||||||
|
|
||||||
func TestServeHTTP(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
err error
|
|
||||||
wantHTTPStatus int
|
|
||||||
wantNetErrRecorded bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "net.Error",
|
|
||||||
err: net.UnknownNetworkError("any network error"),
|
|
||||||
wantHTTPStatus: http.StatusBadGateway,
|
|
||||||
wantNetErrRecorded: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "net.Error with Timeout",
|
|
||||||
err: &timeoutError{},
|
|
||||||
wantHTTPStatus: http.StatusGatewayTimeout,
|
|
||||||
wantNetErrRecorded: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "io.EOF",
|
|
||||||
err: io.EOF,
|
|
||||||
wantHTTPStatus: http.StatusBadGateway,
|
|
||||||
wantNetErrRecorded: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "custom error",
|
|
||||||
err: errors.New("any error"),
|
|
||||||
wantHTTPStatus: http.StatusInternalServerError,
|
|
||||||
wantNetErrRecorded: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "nil error",
|
|
||||||
err: nil,
|
|
||||||
wantHTTPStatus: http.StatusInternalServerError,
|
|
||||||
wantNetErrRecorded: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range tests {
|
|
||||||
test := test
|
|
||||||
|
|
||||||
t.Run(test.name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
|
||||||
|
|
||||||
errorRecorder := &netErrorRecorder{}
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/any", nil)
|
|
||||||
|
|
||||||
recordingErrorHandler := NewRecordingErrorHandler(errorRecorder)
|
|
||||||
recordingErrorHandler.ServeHTTP(recorder, req, test.err)
|
|
||||||
|
|
||||||
if recorder.Code != test.wantHTTPStatus {
|
|
||||||
t.Errorf("got HTTP status code %v, wanted %v", recorder.Code, test.wantHTTPStatus)
|
|
||||||
}
|
|
||||||
if errorRecorder.netErrorWasRecorded != test.wantNetErrRecorded {
|
|
||||||
t.Errorf("net error recording wrong, got %v wanted %v", errorRecorder.netErrorWasRecorded, test.wantNetErrRecorded)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type netErrorRecorder struct {
|
|
||||||
netErrorWasRecorded bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (recorder *netErrorRecorder) Record(ctx context.Context) {
|
|
||||||
recorder.netErrorWasRecorded = true
|
|
||||||
}
|
|
|
@ -23,7 +23,6 @@ import (
|
||||||
"github.com/eapache/channels"
|
"github.com/eapache/channels"
|
||||||
"github.com/urfave/negroni"
|
"github.com/urfave/negroni"
|
||||||
"github.com/vulcand/oxy/forward"
|
"github.com/vulcand/oxy/forward"
|
||||||
"github.com/vulcand/oxy/utils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// loadConfiguration manages dynamically frontends, backends and TLS configurations
|
// loadConfiguration manages dynamically frontends, backends and TLS configurations
|
||||||
|
@ -80,7 +79,6 @@ func (s *Server) loadConfig(configurations types.Configurations, globalConfigura
|
||||||
}
|
}
|
||||||
|
|
||||||
serverEntryPoints := s.buildServerEntryPoints()
|
serverEntryPoints := s.buildServerEntryPoints()
|
||||||
errorHandler := NewRecordingErrorHandler(middlewares.DefaultNetErrorRecorder{})
|
|
||||||
|
|
||||||
backendsHandlers := map[string]http.Handler{}
|
backendsHandlers := map[string]http.Handler{}
|
||||||
backendsHealthCheck := map[string]*healthcheck.BackendConfig{}
|
backendsHealthCheck := map[string]*healthcheck.BackendConfig{}
|
||||||
|
@ -92,7 +90,7 @@ func (s *Server) loadConfig(configurations types.Configurations, globalConfigura
|
||||||
|
|
||||||
for _, frontendName := range frontendNames {
|
for _, frontendName := range frontendNames {
|
||||||
frontendPostConfigs, err := s.loadFrontendConfig(providerName, frontendName, config,
|
frontendPostConfigs, err := s.loadFrontendConfig(providerName, frontendName, config,
|
||||||
redirectHandlers, serverEntryPoints, errorHandler,
|
redirectHandlers, serverEntryPoints,
|
||||||
backendsHandlers, backendsHealthCheck)
|
backendsHandlers, backendsHealthCheck)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("%v. Skipping frontend %s...", err, frontendName)
|
log.Errorf("%v. Skipping frontend %s...", err, frontendName)
|
||||||
|
@ -131,7 +129,7 @@ func (s *Server) loadConfig(configurations types.Configurations, globalConfigura
|
||||||
|
|
||||||
func (s *Server) loadFrontendConfig(
|
func (s *Server) loadFrontendConfig(
|
||||||
providerName string, frontendName string, config *types.Configuration,
|
providerName string, frontendName string, config *types.Configuration,
|
||||||
redirectHandlers map[string]negroni.Handler, serverEntryPoints map[string]*serverEntryPoint, errorHandler *RecordingErrorHandler,
|
redirectHandlers map[string]negroni.Handler, serverEntryPoints map[string]*serverEntryPoint,
|
||||||
backendsHandlers map[string]http.Handler, backendsHealthCheck map[string]*healthcheck.BackendConfig,
|
backendsHandlers map[string]http.Handler, backendsHealthCheck map[string]*healthcheck.BackendConfig,
|
||||||
) ([]handlerPostConfig, error) {
|
) ([]handlerPostConfig, error) {
|
||||||
|
|
||||||
|
@ -170,7 +168,7 @@ func (s *Server) loadFrontendConfig(
|
||||||
postConfigs = append(postConfigs, postConfig)
|
postConfigs = append(postConfigs, postConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
fwd, err := s.buildForwarder(entryPointName, entryPoint, frontendName, frontend, errorHandler, responseModifier)
|
fwd, err := s.buildForwarder(entryPointName, entryPoint, frontendName, frontend, responseModifier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create the forwarder for frontend %s: %v", frontendName, err)
|
return nil, fmt.Errorf("failed to create the forwarder for frontend %s: %v", frontendName, err)
|
||||||
}
|
}
|
||||||
|
@ -222,7 +220,7 @@ func (s *Server) loadFrontendConfig(
|
||||||
|
|
||||||
func (s *Server) buildForwarder(entryPointName string, entryPoint *configuration.EntryPoint,
|
func (s *Server) buildForwarder(entryPointName string, entryPoint *configuration.EntryPoint,
|
||||||
frontendName string, frontend *types.Frontend,
|
frontendName string, frontend *types.Frontend,
|
||||||
errorHandler utils.ErrorHandler, responseModifier modifyResponse) (http.Handler, error) {
|
responseModifier modifyResponse) (http.Handler, error) {
|
||||||
|
|
||||||
roundTripper, err := s.getRoundTripper(entryPointName, frontend.PassTLSCert, entryPoint.TLS)
|
roundTripper, err := s.getRoundTripper(entryPointName, frontend.PassTLSCert, entryPoint.TLS)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -239,7 +237,6 @@ func (s *Server) buildForwarder(entryPointName string, entryPoint *configuration
|
||||||
forward.Stream(true),
|
forward.Stream(true),
|
||||||
forward.PassHostHeader(frontend.PassHostHeader),
|
forward.PassHostHeader(frontend.PassHostHeader),
|
||||||
forward.RoundTripper(roundTripper),
|
forward.RoundTripper(roundTripper),
|
||||||
forward.ErrorHandler(errorHandler),
|
|
||||||
forward.Rewriter(rewriter),
|
forward.Rewriter(rewriter),
|
||||||
forward.ResponseModifier(responseModifier),
|
forward.ResponseModifier(responseModifier),
|
||||||
forward.BufferPool(s.bufferPool),
|
forward.BufferPool(s.bufferPool),
|
||||||
|
|
Loading…
Reference in a new issue