diff --git a/go.mod b/go.mod index 932266a29..50f0f0998 100644 --- a/go.mod +++ b/go.mod @@ -86,6 +86,7 @@ require ( golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 // No tag on the repo. golang.org/x/mod v0.21.0 golang.org/x/net v0.29.0 + golang.org/x/sync v0.8.0 golang.org/x/sys v0.25.0 golang.org/x/text v0.18.0 golang.org/x/time v0.5.0 @@ -343,7 +344,6 @@ require ( golang.org/x/arch v0.4.0 // indirect golang.org/x/crypto v0.27.0 // indirect golang.org/x/oauth2 v0.21.0 // indirect - golang.org/x/sync v0.8.0 // indirect golang.org/x/term v0.24.0 // indirect google.golang.org/api v0.172.0 // indirect google.golang.org/genproto v0.0.0-20240227224415-6ceb2ff114de // indirect diff --git a/pkg/middlewares/auth/basic_auth.go b/pkg/middlewares/auth/basic_auth.go index 0ec8f1e28..e6c175bcb 100644 --- a/pkg/middlewares/auth/basic_auth.go +++ b/pkg/middlewares/auth/basic_auth.go @@ -13,6 +13,7 @@ import ( "github.com/traefik/traefik/v3/pkg/middlewares/accesslog" "github.com/traefik/traefik/v3/pkg/middlewares/observability" "go.opentelemetry.io/otel/trace" + "golang.org/x/sync/singleflight" ) const ( @@ -26,6 +27,9 @@ type basicAuth struct { headerField string removeHeader bool name string + + checkSecret func(password, secret string) bool + singleflightGroup *singleflight.Group } // NewBasic creates a basicAuth middleware. @@ -38,11 +42,13 @@ func NewBasic(ctx context.Context, next http.Handler, authConfig dynamic.BasicAu } ba := &basicAuth{ - next: next, - users: users, - headerField: authConfig.HeaderField, - removeHeader: authConfig.RemoveHeader, - name: name, + next: next, + users: users, + headerField: authConfig.HeaderField, + removeHeader: authConfig.RemoveHeader, + name: name, + checkSecret: goauth.CheckSecret, + singleflightGroup: new(singleflight.Group), } realm := defaultRealm @@ -64,10 +70,7 @@ func (b *basicAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) { user, password, ok := req.BasicAuth() if ok { - secret := b.auth.Secrets(user, b.auth.Realm) - if secret == "" || !goauth.CheckSecret(password, secret) { - ok = false - } + ok = b.checkPassword(user, password) } logData := accesslog.GetLogData(req) @@ -97,6 +100,20 @@ func (b *basicAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) { b.next.ServeHTTP(rw, req) } +func (b *basicAuth) checkPassword(user, password string) bool { + secret := b.auth.Secrets(user, b.auth.Realm) + if secret == "" { + return false + } + + key := password + secret + match, _, _ := b.singleflightGroup.Do(key, func() (any, error) { + return b.checkSecret(password, secret), nil + }) + + return match.(bool) +} + func (b *basicAuth) secretBasic(user, realm string) string { if secret, ok := b.users[user]; ok { return secret diff --git a/pkg/middlewares/auth/basic_auth_test.go b/pkg/middlewares/auth/basic_auth_test.go index 6198988f6..6a59f0111 100644 --- a/pkg/middlewares/auth/basic_auth_test.go +++ b/pkg/middlewares/auth/basic_auth_test.go @@ -7,7 +7,9 @@ import ( "net/http" "net/http/httptest" "os" + "sync" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -167,6 +169,50 @@ func TestBasicAuthHeaderPresent(t *testing.T) { assert.Equal(t, "traefik\n", string(body)) } +func TestBasicAuthConcurrentHashOnce(t *testing.T) { + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "traefik") + }) + auth := dynamic.BasicAuth{ + Users: []string{"test:$2a$04$.8sTYfcxbSplCtoxt5TdJOgpBYkarKtZYsYfYxQ1edbYRuO1DNi0e"}, + } + + authMiddleware, err := NewBasic(context.Background(), next, auth, "authName") + require.NoError(t, err) + + hashCount := 0 + ba := authMiddleware.(*basicAuth) + ba.checkSecret = func(password, secret string) bool { + hashCount++ + // delay to ensure the second request arrives + time.Sleep(time.Millisecond) + return true + } + + ts := httptest.NewServer(authMiddleware) + defer ts.Close() + + var wg sync.WaitGroup + wg.Add(2) + + for range 2 { + go func() { + defer wg.Done() + req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil) + req.SetBasicAuth("test", "test") + + res, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer res.Body.Close() + + assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal") + }() + } + + wg.Wait() + assert.Equal(t, 1, hashCount) +} + func TestBasicAuthUsersFromFile(t *testing.T) { testCases := []struct { desc string