Only calculate basic auth hashes once for concurrent requests
This commit is contained in:
parent
a7502c8700
commit
6f469ee1ec
3 changed files with 73 additions and 10 deletions
2
go.mod
2
go.mod
|
@ -86,6 +86,7 @@ require (
|
||||||
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 // No tag on the repo.
|
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 // No tag on the repo.
|
||||||
golang.org/x/mod v0.21.0
|
golang.org/x/mod v0.21.0
|
||||||
golang.org/x/net v0.29.0
|
golang.org/x/net v0.29.0
|
||||||
|
golang.org/x/sync v0.8.0
|
||||||
golang.org/x/sys v0.25.0
|
golang.org/x/sys v0.25.0
|
||||||
golang.org/x/text v0.18.0
|
golang.org/x/text v0.18.0
|
||||||
golang.org/x/time v0.5.0
|
golang.org/x/time v0.5.0
|
||||||
|
@ -343,7 +344,6 @@ require (
|
||||||
golang.org/x/arch v0.4.0 // indirect
|
golang.org/x/arch v0.4.0 // indirect
|
||||||
golang.org/x/crypto v0.27.0 // indirect
|
golang.org/x/crypto v0.27.0 // indirect
|
||||||
golang.org/x/oauth2 v0.21.0 // indirect
|
golang.org/x/oauth2 v0.21.0 // indirect
|
||||||
golang.org/x/sync v0.8.0 // indirect
|
|
||||||
golang.org/x/term v0.24.0 // indirect
|
golang.org/x/term v0.24.0 // indirect
|
||||||
google.golang.org/api v0.172.0 // indirect
|
google.golang.org/api v0.172.0 // indirect
|
||||||
google.golang.org/genproto v0.0.0-20240227224415-6ceb2ff114de // indirect
|
google.golang.org/genproto v0.0.0-20240227224415-6ceb2ff114de // indirect
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"github.com/traefik/traefik/v3/pkg/middlewares/accesslog"
|
"github.com/traefik/traefik/v3/pkg/middlewares/accesslog"
|
||||||
"github.com/traefik/traefik/v3/pkg/middlewares/observability"
|
"github.com/traefik/traefik/v3/pkg/middlewares/observability"
|
||||||
"go.opentelemetry.io/otel/trace"
|
"go.opentelemetry.io/otel/trace"
|
||||||
|
"golang.org/x/sync/singleflight"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -26,6 +27,9 @@ type basicAuth struct {
|
||||||
headerField string
|
headerField string
|
||||||
removeHeader bool
|
removeHeader bool
|
||||||
name string
|
name string
|
||||||
|
|
||||||
|
checkSecret func(password, secret string) bool
|
||||||
|
singleflightGroup *singleflight.Group
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewBasic creates a basicAuth middleware.
|
// NewBasic creates a basicAuth middleware.
|
||||||
|
@ -38,11 +42,13 @@ func NewBasic(ctx context.Context, next http.Handler, authConfig dynamic.BasicAu
|
||||||
}
|
}
|
||||||
|
|
||||||
ba := &basicAuth{
|
ba := &basicAuth{
|
||||||
next: next,
|
next: next,
|
||||||
users: users,
|
users: users,
|
||||||
headerField: authConfig.HeaderField,
|
headerField: authConfig.HeaderField,
|
||||||
removeHeader: authConfig.RemoveHeader,
|
removeHeader: authConfig.RemoveHeader,
|
||||||
name: name,
|
name: name,
|
||||||
|
checkSecret: goauth.CheckSecret,
|
||||||
|
singleflightGroup: new(singleflight.Group),
|
||||||
}
|
}
|
||||||
|
|
||||||
realm := defaultRealm
|
realm := defaultRealm
|
||||||
|
@ -64,10 +70,7 @@ func (b *basicAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
|
||||||
user, password, ok := req.BasicAuth()
|
user, password, ok := req.BasicAuth()
|
||||||
if ok {
|
if ok {
|
||||||
secret := b.auth.Secrets(user, b.auth.Realm)
|
ok = b.checkPassword(user, password)
|
||||||
if secret == "" || !goauth.CheckSecret(password, secret) {
|
|
||||||
ok = false
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
logData := accesslog.GetLogData(req)
|
logData := accesslog.GetLogData(req)
|
||||||
|
@ -97,6 +100,20 @@ func (b *basicAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
b.next.ServeHTTP(rw, req)
|
b.next.ServeHTTP(rw, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *basicAuth) checkPassword(user, password string) bool {
|
||||||
|
secret := b.auth.Secrets(user, b.auth.Realm)
|
||||||
|
if secret == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
key := password + secret
|
||||||
|
match, _, _ := b.singleflightGroup.Do(key, func() (any, error) {
|
||||||
|
return b.checkSecret(password, secret), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return match.(bool)
|
||||||
|
}
|
||||||
|
|
||||||
func (b *basicAuth) secretBasic(user, realm string) string {
|
func (b *basicAuth) secretBasic(user, realm string) string {
|
||||||
if secret, ok := b.users[user]; ok {
|
if secret, ok := b.users[user]; ok {
|
||||||
return secret
|
return secret
|
||||||
|
|
|
@ -7,7 +7,9 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
@ -167,6 +169,50 @@ func TestBasicAuthHeaderPresent(t *testing.T) {
|
||||||
assert.Equal(t, "traefik\n", string(body))
|
assert.Equal(t, "traefik\n", string(body))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBasicAuthConcurrentHashOnce(t *testing.T) {
|
||||||
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fmt.Fprintln(w, "traefik")
|
||||||
|
})
|
||||||
|
auth := dynamic.BasicAuth{
|
||||||
|
Users: []string{"test:$2a$04$.8sTYfcxbSplCtoxt5TdJOgpBYkarKtZYsYfYxQ1edbYRuO1DNi0e"},
|
||||||
|
}
|
||||||
|
|
||||||
|
authMiddleware, err := NewBasic(context.Background(), next, auth, "authName")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
hashCount := 0
|
||||||
|
ba := authMiddleware.(*basicAuth)
|
||||||
|
ba.checkSecret = func(password, secret string) bool {
|
||||||
|
hashCount++
|
||||||
|
// delay to ensure the second request arrives
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
ts := httptest.NewServer(authMiddleware)
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
|
||||||
|
for range 2 {
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||||
|
req.SetBasicAuth("test", "test")
|
||||||
|
|
||||||
|
res, err := http.DefaultClient.Do(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
assert.Equal(t, 1, hashCount)
|
||||||
|
}
|
||||||
|
|
||||||
func TestBasicAuthUsersFromFile(t *testing.T) {
|
func TestBasicAuthUsersFromFile(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
desc string
|
desc string
|
||||||
|
|
Loading…
Reference in a new issue