416 lines
10 KiB
Go
416 lines
10 KiB
Go
package oauth2
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"mime"
|
|
"net/http"
|
|
"net/url"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
|
|
phttp "github.com/coreos/go-oidc/http"
|
|
)
|
|
|
|
// ResponseTypesEqual reports whether two response_type values are equivalent.
//
// When both values contain a space they are interpreted as whitespace-separated,
// unordered lists, so "code id_token" and "id_token code" are considered equal.
// If at least one of the values contains no space, the two strings are compared
// verbatim.
func ResponseTypesEqual(r1, r2 string) bool {
	bothLists := strings.Contains(r1, " ") && strings.Contains(r2, " ")
	if !bothLists {
		// At most one side is a list, so plain string equality is enough.
		return r1 == r2
	}

	left, right := strings.Fields(r1), strings.Fields(r2)
	if len(left) != len(right) {
		return false
	}

	// Sorting both sides makes the comparison independent of element order.
	sort.Strings(left)
	sort.Strings(right)
	for i := range left {
		if left[i] != right[i] {
			return false
		}
	}
	return true
}
|
|
|
|
// OAuth 2.0 response_type values registered by OpenID Connect.
//
// See: https://openid.net/specs/oauth-v2-multiple-response-types-1_0.html#RegistryContents
const (
	// Single-valued response types.
	ResponseTypeCode    = "code"
	ResponseTypeIDToken = "id_token"
	ResponseTypeToken   = "token"
	ResponseTypeNone    = "none"

	// Composite response types; the individual values are space separated.
	ResponseTypeCodeIDToken      = "code id_token"
	ResponseTypeCodeIDTokenToken = "code id_token token"
	ResponseTypeIDTokenToken     = "id_token token"
)
|
|
|
|
// Values for the grant_type parameter of a token request.
const (
	GrantTypeAuthCode     = "authorization_code"
	GrantTypeClientCreds  = "client_credentials"
	GrantTypeUserCreds    = "password"
	GrantTypeImplicit     = "implicit"
	GrantTypeRefreshToken = "refresh_token"
)

// Client authentication method identifiers, used for Config.AuthMethod.
const (
	AuthMethodClientSecretPost  = "client_secret_post"
	AuthMethodClientSecretBasic = "client_secret_basic"
	AuthMethodClientSecretJWT   = "client_secret_jwt"
	AuthMethodPrivateKeyJWT     = "private_key_jwt"
)
|
|
|
|
type Config struct {
|
|
Credentials ClientCredentials
|
|
Scope []string
|
|
RedirectURL string
|
|
AuthURL string
|
|
TokenURL string
|
|
|
|
// Must be one of the AuthMethodXXX methods above. Right now, only
|
|
// AuthMethodClientSecretPost and AuthMethodClientSecretBasic are supported.
|
|
AuthMethod string
|
|
}
|
|
|
|
type Client struct {
|
|
hc phttp.Client
|
|
creds ClientCredentials
|
|
scope []string
|
|
authURL *url.URL
|
|
redirectURL *url.URL
|
|
tokenURL *url.URL
|
|
authMethod string
|
|
}
|
|
|
|
// ClientCredentials is the identifier/secret pair an OAuth2 client uses to
// authenticate itself.
type ClientCredentials struct {
	// ID is the client identifier, sent as client_id.
	ID string
	// Secret is the client secret.
	Secret string
}
|
|
|
|
func NewClient(hc phttp.Client, cfg Config) (c *Client, err error) {
|
|
if len(cfg.Credentials.ID) == 0 {
|
|
err = errors.New("missing client id")
|
|
return
|
|
}
|
|
|
|
if len(cfg.Credentials.Secret) == 0 {
|
|
err = errors.New("missing client secret")
|
|
return
|
|
}
|
|
|
|
if cfg.AuthMethod == "" {
|
|
cfg.AuthMethod = AuthMethodClientSecretBasic
|
|
} else if cfg.AuthMethod != AuthMethodClientSecretPost && cfg.AuthMethod != AuthMethodClientSecretBasic {
|
|
err = fmt.Errorf("auth method %q is not supported", cfg.AuthMethod)
|
|
return
|
|
}
|
|
|
|
au, err := phttp.ParseNonEmptyURL(cfg.AuthURL)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
tu, err := phttp.ParseNonEmptyURL(cfg.TokenURL)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// Allow empty redirect URL in the case where the client
|
|
// only needs to verify a given token.
|
|
ru, err := url.Parse(cfg.RedirectURL)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
c = &Client{
|
|
creds: cfg.Credentials,
|
|
scope: cfg.Scope,
|
|
redirectURL: ru,
|
|
authURL: au,
|
|
tokenURL: tu,
|
|
hc: hc,
|
|
authMethod: cfg.AuthMethod,
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// Return the embedded HTTP client
|
|
func (c *Client) HttpClient() phttp.Client {
|
|
return c.hc
|
|
}
|
|
|
|
// Generate the url for initial redirect to oauth provider.
|
|
func (c *Client) AuthCodeURL(state, accessType, prompt string) string {
|
|
v := c.commonURLValues()
|
|
v.Set("state", state)
|
|
if strings.ToLower(accessType) == "offline" {
|
|
v.Set("access_type", "offline")
|
|
}
|
|
|
|
if prompt != "" {
|
|
v.Set("prompt", prompt)
|
|
}
|
|
v.Set("response_type", "code")
|
|
|
|
q := v.Encode()
|
|
u := *c.authURL
|
|
if u.RawQuery == "" {
|
|
u.RawQuery = q
|
|
} else {
|
|
u.RawQuery += "&" + q
|
|
}
|
|
return u.String()
|
|
}
|
|
|
|
func (c *Client) commonURLValues() url.Values {
|
|
return url.Values{
|
|
"redirect_uri": {c.redirectURL.String()},
|
|
"scope": {strings.Join(c.scope, " ")},
|
|
"client_id": {c.creds.ID},
|
|
}
|
|
}
|
|
|
|
func (c *Client) newAuthenticatedRequest(urlToken string, values url.Values) (*http.Request, error) {
|
|
var req *http.Request
|
|
var err error
|
|
switch c.authMethod {
|
|
case AuthMethodClientSecretPost:
|
|
values.Set("client_secret", c.creds.Secret)
|
|
req, err = http.NewRequest("POST", urlToken, strings.NewReader(values.Encode()))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
case AuthMethodClientSecretBasic:
|
|
req, err = http.NewRequest("POST", urlToken, strings.NewReader(values.Encode()))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
encodedID := url.QueryEscape(c.creds.ID)
|
|
encodedSecret := url.QueryEscape(c.creds.Secret)
|
|
req.SetBasicAuth(encodedID, encodedSecret)
|
|
default:
|
|
panic("misconfigured client: auth method not supported")
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
return req, nil
|
|
|
|
}
|
|
|
|
// ClientCredsToken posts the client id and secret to obtain a token scoped to the OAuth2 client via the "client_credentials" grant type.
|
|
// May not be supported by all OAuth2 servers.
|
|
func (c *Client) ClientCredsToken(scope []string) (result TokenResponse, err error) {
|
|
v := url.Values{
|
|
"scope": {strings.Join(scope, " ")},
|
|
"grant_type": {GrantTypeClientCreds},
|
|
}
|
|
|
|
req, err := c.newAuthenticatedRequest(c.tokenURL.String(), v)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
resp, err := c.hc.Do(req)
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
return parseTokenResponse(resp)
|
|
}
|
|
|
|
// UserCredsToken posts the username and password to obtain a token scoped to the OAuth2 client via the "password" grant_type
|
|
// May not be supported by all OAuth2 servers.
|
|
func (c *Client) UserCredsToken(username, password string) (result TokenResponse, err error) {
|
|
v := url.Values{
|
|
"scope": {strings.Join(c.scope, " ")},
|
|
"grant_type": {GrantTypeUserCreds},
|
|
"username": {username},
|
|
"password": {password},
|
|
}
|
|
|
|
req, err := c.newAuthenticatedRequest(c.tokenURL.String(), v)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
resp, err := c.hc.Do(req)
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
return parseTokenResponse(resp)
|
|
}
|
|
|
|
// RequestToken requests a token from the Token Endpoint with the specified grantType.
|
|
// If 'grantType' == GrantTypeAuthCode, then 'value' should be the authorization code.
|
|
// If 'grantType' == GrantTypeRefreshToken, then 'value' should be the refresh token.
|
|
func (c *Client) RequestToken(grantType, value string) (result TokenResponse, err error) {
|
|
v := c.commonURLValues()
|
|
|
|
v.Set("grant_type", grantType)
|
|
v.Set("client_secret", c.creds.Secret)
|
|
switch grantType {
|
|
case GrantTypeAuthCode:
|
|
v.Set("code", value)
|
|
case GrantTypeRefreshToken:
|
|
v.Set("refresh_token", value)
|
|
default:
|
|
err = fmt.Errorf("unsupported grant_type: %v", grantType)
|
|
return
|
|
}
|
|
|
|
req, err := c.newAuthenticatedRequest(c.tokenURL.String(), v)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
resp, err := c.hc.Do(req)
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
return parseTokenResponse(resp)
|
|
}
|
|
|
|
func parseTokenResponse(resp *http.Response) (result TokenResponse, err error) {
|
|
body, err := ioutil.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return
|
|
}
|
|
badStatusCode := resp.StatusCode < 200 || resp.StatusCode > 299
|
|
|
|
contentType, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type"))
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
result = TokenResponse{
|
|
RawBody: body,
|
|
}
|
|
|
|
newError := func(typ, desc, state string) error {
|
|
if typ == "" {
|
|
return fmt.Errorf("unrecognized error %s", body)
|
|
}
|
|
return &Error{typ, desc, state}
|
|
}
|
|
|
|
if contentType == "application/x-www-form-urlencoded" || contentType == "text/plain" {
|
|
var vals url.Values
|
|
vals, err = url.ParseQuery(string(body))
|
|
if err != nil {
|
|
return
|
|
}
|
|
if error := vals.Get("error"); error != "" || badStatusCode {
|
|
err = newError(error, vals.Get("error_description"), vals.Get("state"))
|
|
return
|
|
}
|
|
e := vals.Get("expires_in")
|
|
if e == "" {
|
|
e = vals.Get("expires")
|
|
}
|
|
if e != "" {
|
|
result.Expires, err = strconv.Atoi(e)
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
result.AccessToken = vals.Get("access_token")
|
|
result.TokenType = vals.Get("token_type")
|
|
result.IDToken = vals.Get("id_token")
|
|
result.RefreshToken = vals.Get("refresh_token")
|
|
result.Scope = vals.Get("scope")
|
|
} else {
|
|
var r struct {
|
|
AccessToken string `json:"access_token"`
|
|
TokenType string `json:"token_type"`
|
|
IDToken string `json:"id_token"`
|
|
RefreshToken string `json:"refresh_token"`
|
|
Scope string `json:"scope"`
|
|
State string `json:"state"`
|
|
ExpiresIn json.Number `json:"expires_in"` // Azure AD returns string
|
|
Expires int `json:"expires"`
|
|
Error string `json:"error"`
|
|
Desc string `json:"error_description"`
|
|
}
|
|
if err = json.Unmarshal(body, &r); err != nil {
|
|
return
|
|
}
|
|
if r.Error != "" || badStatusCode {
|
|
err = newError(r.Error, r.Desc, r.State)
|
|
return
|
|
}
|
|
result.AccessToken = r.AccessToken
|
|
result.TokenType = r.TokenType
|
|
result.IDToken = r.IDToken
|
|
result.RefreshToken = r.RefreshToken
|
|
result.Scope = r.Scope
|
|
if expiresIn, err := r.ExpiresIn.Int64(); err != nil {
|
|
result.Expires = r.Expires
|
|
} else {
|
|
result.Expires = int(expiresIn)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// TokenResponse holds the fields extracted from a token endpoint response.
type TokenResponse struct {
	// AccessToken is the value of the access_token field.
	AccessToken string
	// TokenType is the value of the token_type field.
	TokenType string
	// Expires is taken from expires_in, or from expires when expires_in is
	// absent; presumably a lifetime in seconds (confirm against the server).
	Expires int
	// IDToken is the value of the id_token field.
	IDToken string
	// RefreshToken is OPTIONAL.
	RefreshToken string
	// Scope is OPTIONAL if identical to the scope requested by the client,
	// otherwise REQUIRED.
	Scope string
	// RawBody is the unparsed response body, in case callers need some other
	// non-standard info from the token response.
	RawBody []byte
}
|
|
|
|
// AuthCodeRequest is the parsed form of an authorization request's query
// parameters, as produced by ParseAuthCodeRequest.
type AuthCodeRequest struct {
	// ResponseType is the response_type parameter.
	ResponseType string
	// ClientID is the client_id parameter.
	ClientID string
	// RedirectURL is the parsed redirect_uri parameter; nil when the
	// parameter was not supplied.
	RedirectURL *url.URL
	// Scope holds the individual values of the scope parameter.
	Scope []string
	// State is the state parameter.
	State string
}
|
|
|
|
func ParseAuthCodeRequest(q url.Values) (AuthCodeRequest, error) {
|
|
acr := AuthCodeRequest{
|
|
ResponseType: q.Get("response_type"),
|
|
ClientID: q.Get("client_id"),
|
|
State: q.Get("state"),
|
|
Scope: make([]string, 0),
|
|
}
|
|
|
|
qs := strings.TrimSpace(q.Get("scope"))
|
|
if qs != "" {
|
|
acr.Scope = strings.Split(qs, " ")
|
|
}
|
|
|
|
err := func() error {
|
|
if acr.ClientID == "" {
|
|
return NewError(ErrorInvalidRequest)
|
|
}
|
|
|
|
redirectURL := q.Get("redirect_uri")
|
|
if redirectURL != "" {
|
|
ru, err := url.Parse(redirectURL)
|
|
if err != nil {
|
|
return NewError(ErrorInvalidRequest)
|
|
}
|
|
acr.RedirectURL = ru
|
|
}
|
|
|
|
return nil
|
|
}()
|
|
|
|
return acr, err
|
|
}
|