Fix NTLM and Kerberos
This commit is contained in:
parent
8f9ad16f54
commit
e11ff98608
4 changed files with 161 additions and 3 deletions
|
@ -27,6 +27,7 @@ import (
|
||||||
"github.com/traefik/traefik/v2/pkg/safe"
|
"github.com/traefik/traefik/v2/pkg/safe"
|
||||||
"github.com/traefik/traefik/v2/pkg/server/router"
|
"github.com/traefik/traefik/v2/pkg/server/router"
|
||||||
tcprouter "github.com/traefik/traefik/v2/pkg/server/router/tcp"
|
tcprouter "github.com/traefik/traefik/v2/pkg/server/router/tcp"
|
||||||
|
"github.com/traefik/traefik/v2/pkg/server/service"
|
||||||
"github.com/traefik/traefik/v2/pkg/tcp"
|
"github.com/traefik/traefik/v2/pkg/tcp"
|
||||||
"github.com/traefik/traefik/v2/pkg/types"
|
"github.com/traefik/traefik/v2/pkg/types"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
|
@ -613,6 +614,16 @@ func createHTTPServer(ctx context.Context, ln net.Listener, configuration *stati
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
prevConnContext := serverHTTP.ConnContext
|
||||||
|
serverHTTP.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
|
||||||
|
// This adds an empty struct in order to store a RoundTripper in the ConnContext in case of Kerberos or NTLM.
|
||||||
|
ctx = service.AddTransportOnContext(ctx)
|
||||||
|
if prevConnContext != nil {
|
||||||
|
return prevConnContext(ctx, c)
|
||||||
|
}
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
// ConfigureServer configures HTTP/2 with the MaxConcurrentStreams option for the given server.
|
// ConfigureServer configures HTTP/2 with the MaxConcurrentStreams option for the given server.
|
||||||
// Also keeping behavior the same as
|
// Also keeping behavior the same as
|
||||||
// https://cs.opensource.google/go/go/+/refs/tags/go1.17.7:src/net/http/server.go;l=3262
|
// https://cs.opensource.google/go/go/+/refs/tags/go1.17.7:src/net/http/server.go;l=3262
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"errors"
|
"errors"
|
||||||
|
@ -8,6 +9,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -149,10 +151,71 @@ func createRoundTripper(cfg *dynamic.ServersTransport) (http.RoundTripper, error
|
||||||
|
|
||||||
// Return directly HTTP/1.1 transport when HTTP/2 is disabled
|
// Return directly HTTP/1.1 transport when HTTP/2 is disabled
|
||||||
if cfg.DisableHTTP2 {
|
if cfg.DisableHTTP2 {
|
||||||
return transport, nil
|
return &KerberosRoundTripper{
|
||||||
|
OriginalRoundTripper: transport,
|
||||||
|
new: func() http.RoundTripper {
|
||||||
|
return transport.Clone()
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return newSmartRoundTripper(transport, cfg.ForwardingTimeouts)
|
rt, err := newSmartRoundTripper(transport, cfg.ForwardingTimeouts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &KerberosRoundTripper{
|
||||||
|
OriginalRoundTripper: rt,
|
||||||
|
new: func() http.RoundTripper {
|
||||||
|
return rt.Clone()
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type KerberosRoundTripper struct {
|
||||||
|
new func() http.RoundTripper
|
||||||
|
OriginalRoundTripper http.RoundTripper
|
||||||
|
}
|
||||||
|
|
||||||
|
type stickyRoundTripper struct {
|
||||||
|
RoundTripper http.RoundTripper
|
||||||
|
}
|
||||||
|
|
||||||
|
type transportKeyType string
|
||||||
|
|
||||||
|
var transportKey transportKeyType = "transport"
|
||||||
|
|
||||||
|
func AddTransportOnContext(ctx context.Context) context.Context {
|
||||||
|
return context.WithValue(ctx, transportKey, &stickyRoundTripper{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *KerberosRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||||
|
value, ok := request.Context().Value(transportKey).(*stickyRoundTripper)
|
||||||
|
if !ok {
|
||||||
|
return k.OriginalRoundTripper.RoundTrip(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
if value.RoundTripper != nil {
|
||||||
|
return value.RoundTripper.RoundTrip(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := k.OriginalRoundTripper.RoundTrip(request)
|
||||||
|
|
||||||
|
// If we found that we are authenticating with Kerberos (Negotiate) or NTLM.
|
||||||
|
// We put a dedicated roundTripper in the ConnContext.
|
||||||
|
// This will stick the next calls to the same connection with the backend.
|
||||||
|
if err == nil && containsNTLMorNegotiate(resp.Header.Values("WWW-Authenticate")) {
|
||||||
|
value.RoundTripper = k.new()
|
||||||
|
}
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsNTLMorNegotiate(h []string) bool {
|
||||||
|
for _, s := range h {
|
||||||
|
if strings.HasPrefix(s, "NTLM") || strings.HasPrefix(s, "Negotiate") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func createRootCACertPool(rootCAs []traefiktls.FileOrContent) *x509.CertPool {
|
func createRootCACertPool(rootCAs []traefiktls.FileOrContent) *x509.CertPool {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"net"
|
"net"
|
||||||
|
@ -293,3 +294,80 @@ func TestDisableHTTP2(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type roundTripperFn func(req *http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
func (r roundTripperFn) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||||
|
return r(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKerberosRoundTripper(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
desc string
|
||||||
|
|
||||||
|
originalRoundTripperHeaders map[string][]string
|
||||||
|
|
||||||
|
expectedStatusCode []int
|
||||||
|
expectedDedicatedCount int
|
||||||
|
expectedOriginalCount int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "without special header",
|
||||||
|
expectedStatusCode: []int{http.StatusUnauthorized, http.StatusUnauthorized, http.StatusUnauthorized},
|
||||||
|
expectedOriginalCount: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "with Negotiate (Kerberos)",
|
||||||
|
originalRoundTripperHeaders: map[string][]string{"Www-Authenticate": {"Negotiate"}},
|
||||||
|
expectedStatusCode: []int{http.StatusUnauthorized, http.StatusOK, http.StatusOK},
|
||||||
|
expectedOriginalCount: 1,
|
||||||
|
expectedDedicatedCount: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "with NTLM",
|
||||||
|
originalRoundTripperHeaders: map[string][]string{"Www-Authenticate": {"NTLM"}},
|
||||||
|
expectedStatusCode: []int{http.StatusUnauthorized, http.StatusOK, http.StatusOK},
|
||||||
|
expectedOriginalCount: 1,
|
||||||
|
expectedDedicatedCount: 2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range testCases {
|
||||||
|
test := test
|
||||||
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
origCount := 0
|
||||||
|
dedicatedCount := 0
|
||||||
|
rt := KerberosRoundTripper{
|
||||||
|
new: func() http.RoundTripper {
|
||||||
|
return roundTripperFn(func(req *http.Request) (*http.Response, error) {
|
||||||
|
dedicatedCount++
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
},
|
||||||
|
OriginalRoundTripper: roundTripperFn(func(req *http.Request) (*http.Response, error) {
|
||||||
|
origCount++
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusUnauthorized,
|
||||||
|
Header: test.originalRoundTripperHeaders,
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := AddTransportOnContext(context.Background())
|
||||||
|
for _, expected := range test.expectedStatusCode {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://127.0.0.1", http.NoBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, expected, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, test.expectedOriginalCount, origCount)
|
||||||
|
require.Equal(t, test.expectedDedicatedCount, dedicatedCount)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -11,7 +11,7 @@ import (
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newSmartRoundTripper(transport *http.Transport, forwardingTimeouts *dynamic.ForwardingTimeouts) (http.RoundTripper, error) {
|
func newSmartRoundTripper(transport *http.Transport, forwardingTimeouts *dynamic.ForwardingTimeouts) (*smartRoundTripper, error) {
|
||||||
transportHTTP1 := transport.Clone()
|
transportHTTP1 := transport.Clone()
|
||||||
|
|
||||||
transportHTTP2, err := http2.ConfigureTransports(transport)
|
transportHTTP2, err := http2.ConfigureTransports(transport)
|
||||||
|
@ -53,6 +53,12 @@ type smartRoundTripper struct {
|
||||||
http *http.Transport
|
http *http.Transport
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *smartRoundTripper) Clone() http.RoundTripper {
|
||||||
|
h := m.http.Clone()
|
||||||
|
h2 := m.http2.Clone()
|
||||||
|
return &smartRoundTripper{http: h, http2: h2}
|
||||||
|
}
|
||||||
|
|
||||||
func (m *smartRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (m *smartRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
// If we have a connection upgrade, we don't use HTTP/2
|
// If we have a connection upgrade, we don't use HTTP/2
|
||||||
if httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") {
|
if httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") {
|
||||||
|
|
Loading…
Reference in a new issue