Only calculate basic auth hashes once for concurrent requests
This commit is contained in:
parent
a7502c8700
commit
6f469ee1ec
3 changed files with 73 additions and 10 deletions
2
go.mod
2
go.mod
|
@ -86,6 +86,7 @@ require (
|
|||
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 // No tag on the repo.
|
||||
golang.org/x/mod v0.21.0
|
||||
golang.org/x/net v0.29.0
|
||||
golang.org/x/sync v0.8.0
|
||||
golang.org/x/sys v0.25.0
|
||||
golang.org/x/text v0.18.0
|
||||
golang.org/x/time v0.5.0
|
||||
|
@ -343,7 +344,6 @@ require (
|
|||
golang.org/x/arch v0.4.0 // indirect
|
||||
golang.org/x/crypto v0.27.0 // indirect
|
||||
golang.org/x/oauth2 v0.21.0 // indirect
|
||||
golang.org/x/sync v0.8.0 // indirect
|
||||
golang.org/x/term v0.24.0 // indirect
|
||||
google.golang.org/api v0.172.0 // indirect
|
||||
google.golang.org/genproto v0.0.0-20240227224415-6ceb2ff114de // indirect
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/traefik/traefik/v3/pkg/middlewares/accesslog"
|
||||
"github.com/traefik/traefik/v3/pkg/middlewares/observability"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -26,6 +27,9 @@ type basicAuth struct {
|
|||
headerField string
|
||||
removeHeader bool
|
||||
name string
|
||||
|
||||
checkSecret func(password, secret string) bool
|
||||
singleflightGroup *singleflight.Group
|
||||
}
|
||||
|
||||
// NewBasic creates a basicAuth middleware.
|
||||
|
@ -38,11 +42,13 @@ func NewBasic(ctx context.Context, next http.Handler, authConfig dynamic.BasicAu
|
|||
}
|
||||
|
||||
ba := &basicAuth{
|
||||
next: next,
|
||||
users: users,
|
||||
headerField: authConfig.HeaderField,
|
||||
removeHeader: authConfig.RemoveHeader,
|
||||
name: name,
|
||||
next: next,
|
||||
users: users,
|
||||
headerField: authConfig.HeaderField,
|
||||
removeHeader: authConfig.RemoveHeader,
|
||||
name: name,
|
||||
checkSecret: goauth.CheckSecret,
|
||||
singleflightGroup: new(singleflight.Group),
|
||||
}
|
||||
|
||||
realm := defaultRealm
|
||||
|
@ -64,10 +70,7 @@ func (b *basicAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||
|
||||
user, password, ok := req.BasicAuth()
|
||||
if ok {
|
||||
secret := b.auth.Secrets(user, b.auth.Realm)
|
||||
if secret == "" || !goauth.CheckSecret(password, secret) {
|
||||
ok = false
|
||||
}
|
||||
ok = b.checkPassword(user, password)
|
||||
}
|
||||
|
||||
logData := accesslog.GetLogData(req)
|
||||
|
@ -97,6 +100,20 @@ func (b *basicAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||
b.next.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
func (b *basicAuth) checkPassword(user, password string) bool {
|
||||
secret := b.auth.Secrets(user, b.auth.Realm)
|
||||
if secret == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
key := password + secret
|
||||
match, _, _ := b.singleflightGroup.Do(key, func() (any, error) {
|
||||
return b.checkSecret(password, secret), nil
|
||||
})
|
||||
|
||||
return match.(bool)
|
||||
}
|
||||
|
||||
func (b *basicAuth) secretBasic(user, realm string) string {
|
||||
if secret, ok := b.users[user]; ok {
|
||||
return secret
|
||||
|
|
|
@ -7,7 +7,9 @@ import (
|
|||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -167,6 +169,50 @@ func TestBasicAuthHeaderPresent(t *testing.T) {
|
|||
assert.Equal(t, "traefik\n", string(body))
|
||||
}
|
||||
|
||||
func TestBasicAuthConcurrentHashOnce(t *testing.T) {
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
auth := dynamic.BasicAuth{
|
||||
Users: []string{"test:$2a$04$.8sTYfcxbSplCtoxt5TdJOgpBYkarKtZYsYfYxQ1edbYRuO1DNi0e"},
|
||||
}
|
||||
|
||||
authMiddleware, err := NewBasic(context.Background(), next, auth, "authName")
|
||||
require.NoError(t, err)
|
||||
|
||||
hashCount := 0
|
||||
ba := authMiddleware.(*basicAuth)
|
||||
ba.checkSecret = func(password, secret string) bool {
|
||||
hashCount++
|
||||
// delay to ensure the second request arrives
|
||||
time.Sleep(time.Millisecond)
|
||||
return true
|
||||
}
|
||||
|
||||
ts := httptest.NewServer(authMiddleware)
|
||||
defer ts.Close()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
for range 2 {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth("test", "test")
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
assert.Equal(t, 1, hashCount)
|
||||
}
|
||||
|
||||
func TestBasicAuthUsersFromFile(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
|
|
Loading…
Add table
Reference in a new issue