159 lines
3.1 KiB
Go
159 lines
3.1 KiB
Go
package key
|
|
|
|
import (
|
|
"errors"
|
|
"log"
|
|
"time"
|
|
|
|
ptime "github.com/coreos/pkg/timeutil"
|
|
"github.com/jonboulle/clockwork"
|
|
)
|
|
|
|
// ErrorPrivateKeysExpired is returned by PrivateKeyRotator.Healthy when the
// stored private key set is past its expiry time.
var ErrorPrivateKeysExpired = errors.New("private keys have expired")
|
|
|
|
func NewPrivateKeyRotator(repo PrivateKeySetRepo, ttl time.Duration) *PrivateKeyRotator {
|
|
return &PrivateKeyRotator{
|
|
repo: repo,
|
|
ttl: ttl,
|
|
|
|
keep: 2,
|
|
generateKey: GeneratePrivateKey,
|
|
clock: clockwork.NewRealClock(),
|
|
}
|
|
}
|
|
|
|
type PrivateKeyRotator struct {
|
|
repo PrivateKeySetRepo
|
|
generateKey GeneratePrivateKeyFunc
|
|
clock clockwork.Clock
|
|
keep int
|
|
ttl time.Duration
|
|
}
|
|
|
|
func (r *PrivateKeyRotator) expiresAt() time.Time {
|
|
return r.clock.Now().UTC().Add(r.ttl)
|
|
}
|
|
|
|
func (r *PrivateKeyRotator) Healthy() error {
|
|
pks, err := r.privateKeySet()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if r.clock.Now().After(pks.ExpiresAt()) {
|
|
return ErrorPrivateKeysExpired
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *PrivateKeyRotator) privateKeySet() (*PrivateKeySet, error) {
|
|
ks, err := r.repo.Get()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
pks, ok := ks.(*PrivateKeySet)
|
|
if !ok {
|
|
return nil, errors.New("unable to cast to PrivateKeySet")
|
|
}
|
|
return pks, nil
|
|
}
|
|
|
|
func (r *PrivateKeyRotator) nextRotation() (time.Duration, error) {
|
|
pks, err := r.privateKeySet()
|
|
if err == ErrorNoKeys {
|
|
return 0, nil
|
|
}
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
now := r.clock.Now()
|
|
|
|
// Ideally, we want to rotate after half the TTL has elapsed.
|
|
idealRotationTime := pks.ExpiresAt().Add(-r.ttl / 2)
|
|
|
|
// If we are past the ideal rotation time, rotate immediatly.
|
|
return max(0, idealRotationTime.Sub(now)), nil
|
|
}
|
|
|
|
// max returns the larger of the durations a and b.
func max(a, b time.Duration) time.Duration {
	if b >= a {
		return b
	}
	return a
}
|
|
|
|
func (r *PrivateKeyRotator) Run() chan struct{} {
|
|
attempt := func() {
|
|
k, err := r.generateKey()
|
|
if err != nil {
|
|
log.Printf("go-oidc: failed generating signing key: %v", err)
|
|
return
|
|
}
|
|
|
|
exp := r.expiresAt()
|
|
if err := rotatePrivateKeys(r.repo, k, r.keep, exp); err != nil {
|
|
log.Printf("go-oidc: key rotation failed: %v", err)
|
|
return
|
|
}
|
|
}
|
|
|
|
stop := make(chan struct{})
|
|
go func() {
|
|
for {
|
|
var nextRotation time.Duration
|
|
var sleep time.Duration
|
|
var err error
|
|
for {
|
|
if nextRotation, err = r.nextRotation(); err == nil {
|
|
break
|
|
}
|
|
sleep = ptime.ExpBackoff(sleep, time.Minute)
|
|
log.Printf("go-oidc: error getting nextRotation, retrying in %v: %v", sleep, err)
|
|
time.Sleep(sleep)
|
|
}
|
|
|
|
select {
|
|
case <-r.clock.After(nextRotation):
|
|
attempt()
|
|
case <-stop:
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
return stop
|
|
}
|
|
|
|
func rotatePrivateKeys(repo PrivateKeySetRepo, k *PrivateKey, keep int, exp time.Time) error {
|
|
ks, err := repo.Get()
|
|
if err != nil && err != ErrorNoKeys {
|
|
return err
|
|
}
|
|
|
|
var keys []*PrivateKey
|
|
if ks != nil {
|
|
pks, ok := ks.(*PrivateKeySet)
|
|
if !ok {
|
|
return errors.New("unable to cast to PrivateKeySet")
|
|
}
|
|
keys = pks.Keys()
|
|
}
|
|
|
|
keys = append([]*PrivateKey{k}, keys...)
|
|
if l := len(keys); l > keep {
|
|
keys = keys[0:keep]
|
|
}
|
|
|
|
nks := PrivateKeySet{
|
|
keys: keys,
|
|
ActiveKeyID: k.ID(),
|
|
expiresAt: exp,
|
|
}
|
|
|
|
return repo.Set(KeySet(&nks))
|
|
}
|