201 lines
5.1 KiB
Go
201 lines
5.1 KiB
Go
|
package oidc
|
||
|
|
||
|
import (
|
||
|
"encoding/json"
|
||
|
"fmt"
|
||
|
"io/ioutil"
|
||
|
"net/http"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"github.com/pquerna/cachecontrol"
|
||
|
"golang.org/x/net/context"
|
||
|
"golang.org/x/net/context/ctxhttp"
|
||
|
jose "gopkg.in/square/go-jose.v2"
|
||
|
)
|
||
|
|
||
|
// keysExpiryDelta is the clock skew tolerated between this client and the
// OpenID Connect server when judging whether cached keys are still usable.
//
// It is applied in two directions:
//
//   - Cached keys remain usable for this long after their recorded expiry.
//   - When a requested key ID is missing from a cache that has not yet
//     expired, the keys are refreshed only if the cache is due to expire
//     within this amount of time.
const keysExpiryDelta time.Duration = 30 * time.Second
|
||
|
|
||
|
func newRemoteKeySet(ctx context.Context, jwksURL string, now func() time.Time) *remoteKeySet {
|
||
|
if now == nil {
|
||
|
now = time.Now
|
||
|
}
|
||
|
return &remoteKeySet{jwksURL: jwksURL, ctx: ctx, now: now}
|
||
|
}
|
||
|
|
||
|
type remoteKeySet struct {
|
||
|
jwksURL string
|
||
|
ctx context.Context
|
||
|
now func() time.Time
|
||
|
|
||
|
// guard all other fields
|
||
|
mu sync.Mutex
|
||
|
|
||
|
// inflightCtx suppresses parallel execution of updateKeys and allows
|
||
|
// multiple goroutines to wait for its result.
|
||
|
// Its Err() method returns any errors encountered during updateKeys.
|
||
|
//
|
||
|
// If nil, there is no inflight updateKeys request.
|
||
|
inflightCtx *inflight
|
||
|
|
||
|
// A set of cached keys and their expiry.
|
||
|
cachedKeys []jose.JSONWebKey
|
||
|
expiry time.Time
|
||
|
}
|
||
|
|
||
|
// inflight lets several goroutines wait on a single in-flight request and
// read its outcome once it has finished.
type inflight struct {
	// done is closed by Cancel to signal that the request has finished.
	done chan struct{}
	// err is the request's outcome; Cancel stores it before closing done.
	err error
}

// Done returns a channel that is closed once the in-flight request finishes.
func (f *inflight) Done() <-chan struct{} { return f.done }

// Err returns the error recorded for the request by Cancel. It is nil when
// the request succeeded or when Cancel has not been called yet.
func (f *inflight) Err() error { return f.err }

// Cancel marks the in-flight request as finished with outcome err (which may
// be nil) and wakes everything waiting on Done.
//
// It must be called at most once per inflight value: a second call would
// close the already-closed channel and panic.
func (f *inflight) Cancel(err error) {
	// Record the outcome before closing so that waiters released by the
	// close observe it through Err.
	f.err = err
	close(f.done)
}
|
||
|
|
||
|
func (r *remoteKeySet) keysWithIDFromCache(keyIDs []string) ([]jose.JSONWebKey, bool) {
|
||
|
r.mu.Lock()
|
||
|
keys, expiry := r.cachedKeys, r.expiry
|
||
|
r.mu.Unlock()
|
||
|
|
||
|
// Have the keys expired?
|
||
|
if expiry.Add(keysExpiryDelta).Before(r.now()) {
|
||
|
return nil, false
|
||
|
}
|
||
|
|
||
|
var signingKeys []jose.JSONWebKey
|
||
|
for _, key := range keys {
|
||
|
if contains(keyIDs, key.KeyID) {
|
||
|
signingKeys = append(signingKeys, key)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if len(signingKeys) == 0 {
|
||
|
// Are the keys about to expire?
|
||
|
if r.now().Add(keysExpiryDelta).After(expiry) {
|
||
|
return nil, false
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return signingKeys, true
|
||
|
}
|
||
|
func (r *remoteKeySet) keysWithID(ctx context.Context, keyIDs []string) ([]jose.JSONWebKey, error) {
|
||
|
keys, ok := r.keysWithIDFromCache(keyIDs)
|
||
|
if ok {
|
||
|
return keys, nil
|
||
|
}
|
||
|
|
||
|
var inflightCtx *inflight
|
||
|
func() {
|
||
|
r.mu.Lock()
|
||
|
defer r.mu.Unlock()
|
||
|
|
||
|
// If there's not a current inflight request, create one.
|
||
|
if r.inflightCtx == nil {
|
||
|
inflightCtx := &inflight{make(chan struct{}), nil}
|
||
|
r.inflightCtx = inflightCtx
|
||
|
|
||
|
go func() {
|
||
|
// TODO(ericchiang): Upstream Kubernetes request that we recover every time
|
||
|
// we spawn a goroutine, because panics in a goroutine will bring down the
|
||
|
// entire program. There's no way to recover from another goroutine's panic.
|
||
|
//
|
||
|
// Most users actually want to let the panic propagate and bring down the
|
||
|
// program because it implies some unrecoverable state.
|
||
|
//
|
||
|
// Add a context key to allow the recover behavior.
|
||
|
//
|
||
|
// See: https://github.com/coreos/go-oidc/issues/89
|
||
|
|
||
|
// Sync keys and close inflightCtx when that's done.
|
||
|
// Use the remoteKeySet's context instead of the requests context
|
||
|
// because a re-sync is unique to the keys set and will span multiple
|
||
|
// requests.
|
||
|
inflightCtx.Cancel(r.updateKeys(r.ctx))
|
||
|
|
||
|
r.mu.Lock()
|
||
|
defer r.mu.Unlock()
|
||
|
r.inflightCtx = nil
|
||
|
}()
|
||
|
}
|
||
|
|
||
|
inflightCtx = r.inflightCtx
|
||
|
}()
|
||
|
|
||
|
select {
|
||
|
case <-ctx.Done():
|
||
|
return nil, ctx.Err()
|
||
|
case <-inflightCtx.Done():
|
||
|
if err := inflightCtx.Err(); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Since we've just updated keys, we don't care about the cache miss.
|
||
|
keys, _ = r.keysWithIDFromCache(keyIDs)
|
||
|
return keys, nil
|
||
|
}
|
||
|
|
||
|
func (r *remoteKeySet) updateKeys(ctx context.Context) error {
|
||
|
req, err := http.NewRequest("GET", r.jwksURL, nil)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("oidc: can't create request: %v", err)
|
||
|
}
|
||
|
|
||
|
resp, err := ctxhttp.Do(ctx, clientFromContext(ctx), req)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("oidc: get keys failed %v", err)
|
||
|
}
|
||
|
defer resp.Body.Close()
|
||
|
|
||
|
body, err := ioutil.ReadAll(resp.Body)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("oidc: read response body: %v", err)
|
||
|
}
|
||
|
if resp.StatusCode != http.StatusOK {
|
||
|
return fmt.Errorf("oidc: get keys failed: %s %s", resp.Status, body)
|
||
|
}
|
||
|
|
||
|
var keySet jose.JSONWebKeySet
|
||
|
if err := json.Unmarshal(body, &keySet); err != nil {
|
||
|
return fmt.Errorf("oidc: failed to decode keys: %v %s", err, body)
|
||
|
}
|
||
|
|
||
|
// If the server doesn't provide cache control headers, assume the
|
||
|
// keys expire immediately.
|
||
|
expiry := r.now()
|
||
|
|
||
|
_, e, err := cachecontrol.CachableResponse(req, resp, cachecontrol.Options{})
|
||
|
if err == nil && e.After(expiry) {
|
||
|
expiry = e
|
||
|
}
|
||
|
|
||
|
r.mu.Lock()
|
||
|
defer r.mu.Unlock()
|
||
|
r.cachedKeys = keySet.Keys
|
||
|
r.expiry = expiry
|
||
|
|
||
|
return nil
|
||
|
}
|