diff --git a/pkg/middlewares/auth/forward.go b/pkg/middlewares/auth/forward.go index 4bd5bbaf6..a39abbea4 100644 --- a/pkg/middlewares/auth/forward.go +++ b/pkg/middlewares/auth/forward.go @@ -2,12 +2,12 @@ package auth import ( "context" - "crypto/tls" "fmt" "io/ioutil" "net" "net/http" "strings" + "time" "github.com/containous/traefik/v2/pkg/config/dynamic" "github.com/containous/traefik/v2/pkg/log" @@ -29,7 +29,7 @@ type forwardAuth struct { authResponseHeaders []string next http.Handler name string - tlsConfig *tls.Config + client http.Client trustForwardHeader bool } @@ -45,13 +45,23 @@ func NewForward(ctx context.Context, next http.Handler, config dynamic.ForwardAu trustForwardHeader: config.TrustForwardHeader, } + // Ensure our request client does not follow redirects + fa.client = http.Client{ + CheckRedirect: func(r *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + Timeout: 30 * time.Second, + } + if config.TLS != nil { tlsConfig, err := config.TLS.CreateTLSConfig() if err != nil { return nil, err } - fa.tlsConfig = tlsConfig + tr := http.DefaultTransport.(*http.Transport).Clone() + tr.TLSClientConfig = tlsConfig + fa.client.Transport = tr } return fa, nil @@ -64,19 +74,6 @@ func (fa *forwardAuth) GetTracingInformation() (string, ext.SpanKindEnum) { func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) { logger := log.FromContext(middlewares.GetLoggerCtx(req.Context(), fa.name, forwardedTypeName)) - // Ensure our request client does not follow redirects - httpClient := http.Client{ - CheckRedirect: func(r *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } - - if fa.tlsConfig != nil { - httpClient.Transport = &http.Transport{ - TLSClientConfig: fa.tlsConfig, - } - } - forwardReq, err := http.NewRequest(http.MethodGet, fa.address, nil) tracing.LogRequest(tracing.GetSpan(req), forwardReq) if err != nil { @@ -94,7 +91,7 @@ func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) { writeHeader(req, forwardReq, fa.trustForwardHeader) - forwardResponse, forwardErr := httpClient.Do(forwardReq) + forwardResponse, forwardErr := fa.client.Do(forwardReq) if forwardErr != nil { logMessage := fmt.Sprintf("Error calling %s. Cause: %s", fa.address, forwardErr) logger.Debug(logMessage)