From 090677e090e0a252e04fdc95765595efafe40aa6 Mon Sep 17 00:00:00 2001 From: Konrad Holowinski Date: Mon, 24 Oct 2022 13:01:10 +0200 Subject: [PATCH 01/16] rebased to github.com/cloudentity/oauth2 --- amazon/amazon.go | 2 +- authhandler/authhandler.go | 2 +- authhandler/authhandler_test.go | 2 +- bitbucket/bitbucket.go | 2 +- cern/cern.go | 4 ++-- clientcredentials/clientcredentials.go | 6 ++--- clientcredentials/clientcredentials_test.go | 2 +- endpoints/endpoints.go | 2 +- endpoints/endpoints_test.go | 2 +- example_test.go | 2 +- facebook/facebook.go | 4 ++-- fitbit/fitbit.go | 4 ++-- foursquare/foursquare.go | 4 ++-- github/github.go | 4 ++-- gitlab/gitlab.go | 4 ++-- go.mod | 2 +- google/appengine.go | 2 +- google/appengine_gen1.go | 2 +- google/appengine_gen2_flex.go | 2 +- google/default.go | 4 ++-- google/doc.go | 4 ++-- google/downscope/downscoping.go | 2 +- google/downscope/downscoping_test.go | 2 +- google/downscope/tokenbroker_test.go | 6 ++--- google/error.go | 2 +- google/error_test.go | 2 +- google/example_test.go | 6 ++--- google/google.go | 6 ++--- google/internal/externalaccount/aws.go | 2 +- .../externalaccount/basecredentials.go | 2 +- .../externalaccount/basecredentials_test.go | 2 +- google/internal/externalaccount/clientauth.go | 2 +- .../externalaccount/clientauth_test.go | 2 +- .../internal/externalaccount/impersonate.go | 2 +- .../internal/externalaccount/sts_exchange.go | 2 +- .../externalaccount/sts_exchange_test.go | 2 +- .../internal/externalaccount/urlcredsource.go | 2 +- google/jwt.go | 6 ++--- google/jwt_test.go | 2 +- google/sdk.go | 2 +- heroku/heroku.go | 4 ++-- hipchat/hipchat.go | 6 ++--- instagram/instagram.go | 4 ++-- internal/oauth2.go | 24 +++++++++++++++++++ internal/token.go | 2 +- jira/jira.go | 2 +- jira/jira_test.go | 4 ++-- jws/jws.go | 4 ++-- jwt/example_test.go | 2 +- jwt/jwt.go | 6 ++--- jwt/jwt_test.go | 4 ++-- kakao/kakao.go | 4 ++-- linkedin/linkedin.go | 4 ++-- mailchimp/mailchimp.go | 4 ++-- mailru/mailru.go | 4 ++-- mediamath/mediamath.go | 4 ++-- microsoft/microsoft.go | 4 ++-- nokiahealth/nokiahealth.go | 2 +- oauth2.go | 6 ++--- oauth2_test.go | 2 +- odnoklassniki/odnoklassniki.go | 4 ++-- paypal/paypal.go | 4 ++-- slack/slack.go | 4 ++-- spotify/spotify.go | 4 ++-- stackoverflow/stackoverflow.go | 4 ++-- token.go | 2 +- transport.go | 2 +- twitch/twitch.go | 4 ++-- uber/uber.go | 4 ++-- vk/vk.go | 4 ++-- yahoo/yahoo.go | 4 ++-- yandex/yandex.go | 4 ++-- 72 files changed, 140 insertions(+), 116 deletions(-) diff --git a/amazon/amazon.go b/amazon/amazon.go index d21da11af..18e254e1f 100644 --- a/amazon/amazon.go +++ b/amazon/amazon.go @@ -6,7 +6,7 @@ package amazon import ( - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // Endpoint is Amazon's OAuth 2.0 endpoint. diff --git a/authhandler/authhandler.go b/authhandler/authhandler.go index 9bc6cd7bc..e60255ec9 100644 --- a/authhandler/authhandler.go +++ b/authhandler/authhandler.go @@ -10,7 +10,7 @@ import ( "context" "errors" - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) const ( diff --git a/authhandler/authhandler_test.go b/authhandler/authhandler_test.go index ad1980492..365c51be7 100644 --- a/authhandler/authhandler_test.go +++ b/authhandler/authhandler_test.go @@ -11,7 +11,7 @@ import ( "net/http/httptest" "testing" - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) func TestTokenExchange_Success(t *testing.T) { diff --git a/bitbucket/bitbucket.go b/bitbucket/bitbucket.go index 44af1f1a9..401c1ccb3 100644 --- a/bitbucket/bitbucket.go +++ b/bitbucket/bitbucket.go @@ -6,7 +6,7 @@ package bitbucket import ( - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // Endpoint is Bitbucket's OAuth 2.0 endpoint. diff --git a/cern/cern.go b/cern/cern.go index 8be718078..0364d10bb 100644 --- a/cern/cern.go +++ b/cern/cern.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package cern provides constants for using OAuth2 to access CERN services. -package cern // import "golang.org/x/oauth2/cern" +package cern // import "github.com/cloudentity/oauth2/cern" import ( - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // Endpoint is CERN's OAuth 2.0 endpoint. diff --git a/clientcredentials/clientcredentials.go b/clientcredentials/clientcredentials.go index 7a0b9ed10..55caa20bd 100644 --- a/clientcredentials/clientcredentials.go +++ b/clientcredentials/clientcredentials.go @@ -11,7 +11,7 @@ // server. // // See https://tools.ietf.org/html/rfc6749#section-4.4 -package clientcredentials // import "golang.org/x/oauth2/clientcredentials" +package clientcredentials // import "github.com/cloudentity/oauth2/clientcredentials" import ( "context" @@ -20,8 +20,8 @@ import ( "net/url" "strings" - "golang.org/x/oauth2" - "golang.org/x/oauth2/internal" + "github.com/cloudentity/oauth2" + "github.com/cloudentity/oauth2/internal" ) // Config describes a 2-legged OAuth2 flow, with both the diff --git a/clientcredentials/clientcredentials_test.go b/clientcredentials/clientcredentials_test.go index 02a1c89a8..3f8216ede 100644 --- a/clientcredentials/clientcredentials_test.go +++ b/clientcredentials/clientcredentials_test.go @@ -13,7 +13,7 @@ import ( "net/url" "testing" - "golang.org/x/oauth2/internal" + "github.com/cloudentity/oauth2/internal" ) func newConf(serverURL string) *Config { diff --git a/endpoints/endpoints.go b/endpoints/endpoints.go index 7cc37c876..2db328dec 100644 --- a/endpoints/endpoints.go +++ b/endpoints/endpoints.go @@ -8,7 +8,7 @@ package endpoints import ( "strings" - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // Amazon is the endpoint for Amazon. diff --git a/endpoints/endpoints_test.go b/endpoints/endpoints_test.go index 4ffa31429..92486678b 100644 --- a/endpoints/endpoints_test.go +++ b/endpoints/endpoints_test.go @@ -7,7 +7,7 @@ package endpoints import ( "testing" - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) func TestAWSCognitoEndpoint(t *testing.T) { diff --git a/example_test.go b/example_test.go index fc2f793b2..6fe828fc1 100644 --- a/example_test.go +++ b/example_test.go @@ -11,7 +11,7 @@ import ( "net/http" "time" - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) func ExampleConfig() { diff --git a/facebook/facebook.go b/facebook/facebook.go index b0054e387..baa452706 100644 --- a/facebook/facebook.go +++ b/facebook/facebook.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package facebook provides constants for using OAuth2 to access Facebook. -package facebook // import "golang.org/x/oauth2/facebook" +package facebook // import "github.com/cloudentity/oauth2/facebook" import ( - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // Endpoint is Facebook's OAuth 2.0 endpoint. diff --git a/fitbit/fitbit.go b/fitbit/fitbit.go index b31b82aca..9170a7bf2 100644 --- a/fitbit/fitbit.go +++ b/fitbit/fitbit.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package fitbit provides constants for using OAuth2 to access the Fitbit API. -package fitbit // import "golang.org/x/oauth2/fitbit" +package fitbit // import "github.com/cloudentity/oauth2/fitbit" import ( - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // Endpoint is the Fitbit API's OAuth 2.0 endpoint. diff --git a/foursquare/foursquare.go b/foursquare/foursquare.go index d2fa09902..7533cf154 100644 --- a/foursquare/foursquare.go +++ b/foursquare/foursquare.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package foursquare provides constants for using OAuth2 to access Foursquare. -package foursquare // import "golang.org/x/oauth2/foursquare" +package foursquare // import "github.com/cloudentity/oauth2/foursquare" import ( - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // Endpoint is Foursquare's OAuth 2.0 endpoint. diff --git a/github/github.go b/github/github.go index f2978015b..0b01897f4 100644 --- a/github/github.go +++ b/github/github.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package github provides constants for using OAuth2 to access Github. -package github // import "golang.org/x/oauth2/github" +package github // import "github.com/cloudentity/oauth2/github" import ( - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // Endpoint is Github's OAuth 2.0 endpoint. diff --git a/gitlab/gitlab.go b/gitlab/gitlab.go index 1231d75ac..3e8e5cb5e 100644 --- a/gitlab/gitlab.go +++ b/gitlab/gitlab.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package gitlab provides constants for using OAuth2 to access GitLab. -package gitlab // import "golang.org/x/oauth2/gitlab" +package gitlab // import "github.com/cloudentity/oauth2/gitlab" import ( - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // Endpoint is GitLab's OAuth 2.0 endpoint. diff --git a/go.mod b/go.mod index 955b627a1..31e167540 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module golang.org/x/oauth2 +module github.com/cloudentity/oauth2 go 1.17 diff --git a/google/appengine.go b/google/appengine.go index feb1157b1..971506468 100644 --- a/google/appengine.go +++ b/google/appengine.go @@ -8,7 +8,7 @@ import ( "context" "time" - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // Set at init time by appengine_gen1.go. If nil, we're not on App Engine standard first generation (<= Go 1.9) or App Engine flexible. diff --git a/google/appengine_gen1.go b/google/appengine_gen1.go index 16c6c6b90..0c77add1f 100644 --- a/google/appengine_gen1.go +++ b/google/appengine_gen1.go @@ -15,7 +15,7 @@ import ( "strings" "sync" - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" "google.golang.org/appengine" ) diff --git a/google/appengine_gen2_flex.go b/google/appengine_gen2_flex.go index a7e27b3d2..420eaa1ac 100644 --- a/google/appengine_gen2_flex.go +++ b/google/appengine_gen2_flex.go @@ -14,7 +14,7 @@ import ( "log" "sync" - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) var logOnce sync.Once // only spam about deprecation once diff --git a/google/default.go b/google/default.go index 7ed02cd41..57af55560 100644 --- a/google/default.go +++ b/google/default.go @@ -15,8 +15,8 @@ import ( "runtime" "cloud.google.com/go/compute/metadata" - "golang.org/x/oauth2" - "golang.org/x/oauth2/authhandler" + "github.com/cloudentity/oauth2" + "github.com/cloudentity/oauth2/authhandler" ) // Credentials holds Google credentials, including "Application Default Credentials". diff --git a/google/doc.go b/google/doc.go index b3e7bc85c..a25eaad26 100644 --- a/google/doc.go +++ b/google/doc.go @@ -17,7 +17,7 @@ // // # OAuth2 Configs // -// Two functions in this package return golang.org/x/oauth2.Config values from Google credential +// Two functions in this package return github.com/cloudentity/oauth2.Config values from Google credential // data. Google supports two JSON formats for OAuth2 credentials: one is handled by ConfigFromJSON, // the other by JWTConfigFromJSON. The returned Config can be used to obtain a TokenSource or // create an http.Client. @@ -81,4 +81,4 @@ // same as the one obtained from the oauth2.Config returned from ConfigFromJSON or // JWTConfigFromJSON, but the Credentials may contain additional information // that is useful is some circumstances. -package google // import "golang.org/x/oauth2/google" +package google // import "github.com/cloudentity/oauth2/google" diff --git a/google/downscope/downscoping.go b/google/downscope/downscoping.go index 3d4b5532d..6b5b19b70 100644 --- a/google/downscope/downscoping.go +++ b/google/downscope/downscoping.go @@ -44,7 +44,7 @@ import ( "net/url" "time" - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) var ( diff --git a/google/downscope/downscoping_test.go b/google/downscope/downscoping_test.go index d5adda19c..06c15c684 100644 --- a/google/downscope/downscoping_test.go +++ b/google/downscope/downscoping_test.go @@ -11,7 +11,7 @@ import ( "net/http/httptest" "testing" - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) var ( diff --git a/google/downscope/tokenbroker_test.go b/google/downscope/tokenbroker_test.go index cb168785f..25e7263cd 100644 --- a/google/downscope/tokenbroker_test.go +++ b/google/downscope/tokenbroker_test.go @@ -8,10 +8,10 @@ import ( "context" "fmt" - "golang.org/x/oauth2/google" + "github.com/cloudentity/oauth2/google" - "golang.org/x/oauth2" - "golang.org/x/oauth2/google/downscope" + "github.com/cloudentity/oauth2" + "github.com/cloudentity/oauth2/google/downscope" ) func ExampleNewTokenSource() { diff --git a/google/error.go b/google/error.go index d84dd0047..c0143d91c 100644 --- a/google/error.go +++ b/google/error.go @@ -7,7 +7,7 @@ package google import ( "errors" - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // AuthenticationError indicates there was an error in the authentication flow. diff --git a/google/error_test.go b/google/error_test.go index cd60e9118..4a9e18fb1 100644 --- a/google/error_test.go +++ b/google/error_test.go @@ -8,7 +8,7 @@ import ( "net/http" "testing" - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) func TestAuthenticationError_Temporary(t *testing.T) { diff --git a/google/example_test.go b/google/example_test.go index 3fc9cad3f..568caac7d 100644 --- a/google/example_test.go +++ b/google/example_test.go @@ -11,9 +11,9 @@ import ( "log" "net/http" - "golang.org/x/oauth2" - "golang.org/x/oauth2/google" - "golang.org/x/oauth2/jwt" + "github.com/cloudentity/oauth2" + "github.com/cloudentity/oauth2/google" + "github.com/cloudentity/oauth2/jwt" ) func ExampleDefaultClient() { diff --git a/google/google.go b/google/google.go index 8df0c493e..95829f331 100644 --- a/google/google.go +++ b/google/google.go @@ -14,9 +14,9 @@ import ( "time" "cloud.google.com/go/compute/metadata" - "golang.org/x/oauth2" - "golang.org/x/oauth2/google/internal/externalaccount" - "golang.org/x/oauth2/jwt" + "github.com/cloudentity/oauth2" + "github.com/cloudentity/oauth2/google/internal/externalaccount" + "github.com/cloudentity/oauth2/jwt" ) // Endpoint is Google's OAuth 2.0 default endpoint. diff --git a/google/internal/externalaccount/aws.go b/google/internal/externalaccount/aws.go index e917195d5..62fd5b327 100644 --- a/google/internal/externalaccount/aws.go +++ b/google/internal/externalaccount/aws.go @@ -23,7 +23,7 @@ import ( "strings" "time" - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) type awsSecurityCredentials struct { diff --git a/google/internal/externalaccount/basecredentials.go b/google/internal/externalaccount/basecredentials.go index 9fc35535e..e7ba68bd3 100644 --- a/google/internal/externalaccount/basecredentials.go +++ b/google/internal/externalaccount/basecredentials.go @@ -14,7 +14,7 @@ import ( "strings" "time" - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // now aliases time.Now for testing diff --git a/google/internal/externalaccount/basecredentials_test.go b/google/internal/externalaccount/basecredentials_test.go index 05e0127f0..207d4ffb3 100644 --- a/google/internal/externalaccount/basecredentials_test.go +++ b/google/internal/externalaccount/basecredentials_test.go @@ -13,7 +13,7 @@ import ( "testing" "time" - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) const ( diff --git a/google/internal/externalaccount/clientauth.go b/google/internal/externalaccount/clientauth.go index 99987ce29..fff0c44db 100644 --- a/google/internal/externalaccount/clientauth.go +++ b/google/internal/externalaccount/clientauth.go @@ -9,7 +9,7 @@ import ( "net/http" "net/url" - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // clientAuthentication represents an OAuth client ID and secret and the mechanism for passing these credentials as stated in rfc6749#2.3.1. diff --git a/google/internal/externalaccount/clientauth_test.go b/google/internal/externalaccount/clientauth_test.go index bfb339d06..bd9138bfa 100644 --- a/google/internal/externalaccount/clientauth_test.go +++ b/google/internal/externalaccount/clientauth_test.go @@ -10,7 +10,7 @@ import ( "reflect" "testing" - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) var clientID = "rbrgnognrhongo3bi4gb9ghg9g" diff --git a/google/internal/externalaccount/impersonate.go b/google/internal/externalaccount/impersonate.go index 54c8f209f..db97a7764 100644 --- a/google/internal/externalaccount/impersonate.go +++ b/google/internal/externalaccount/impersonate.go @@ -14,7 +14,7 @@ import ( "net/http" "time" - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // generateAccesstokenReq is used for service account impersonation diff --git a/google/internal/externalaccount/sts_exchange.go b/google/internal/externalaccount/sts_exchange.go index e6fcae5fc..a262f462a 100644 --- a/google/internal/externalaccount/sts_exchange.go +++ b/google/internal/externalaccount/sts_exchange.go @@ -15,7 +15,7 @@ import ( "strconv" "strings" - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // exchangeToken performs an oauth2 token exchange with the provided endpoint. diff --git a/google/internal/externalaccount/sts_exchange_test.go b/google/internal/externalaccount/sts_exchange_test.go index df4d5ff4e..747b0d69d 100644 --- a/google/internal/externalaccount/sts_exchange_test.go +++ b/google/internal/externalaccount/sts_exchange_test.go @@ -13,7 +13,7 @@ import ( "net/url" "testing" - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) var auth = clientAuthentication{ diff --git a/google/internal/externalaccount/urlcredsource.go b/google/internal/externalaccount/urlcredsource.go index 16dca6541..247548845 100644 --- a/google/internal/externalaccount/urlcredsource.go +++ b/google/internal/externalaccount/urlcredsource.go @@ -13,7 +13,7 @@ import ( "io/ioutil" "net/http" - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) type urlCredentialSource struct { diff --git a/google/jwt.go b/google/jwt.go index e89e6ae17..ede6136af 100644 --- a/google/jwt.go +++ b/google/jwt.go @@ -10,9 +10,9 @@ import ( "strings" "time" - "golang.org/x/oauth2" - "golang.org/x/oauth2/internal" - "golang.org/x/oauth2/jws" + "github.com/cloudentity/oauth2" + "github.com/cloudentity/oauth2/internal" + "github.com/cloudentity/oauth2/jws" ) // JWTAccessTokenSourceFromJSON uses a Google Developers service account JSON diff --git a/google/jwt_test.go b/google/jwt_test.go index 5890ae9a7..2dbb8f2b5 100644 --- a/google/jwt_test.go +++ b/google/jwt_test.go @@ -17,7 +17,7 @@ import ( "testing" "time" - "golang.org/x/oauth2/jws" + "github.com/cloudentity/oauth2/jws" ) var ( diff --git a/google/sdk.go b/google/sdk.go index 456224bc7..a6f0c0895 100644 --- a/google/sdk.go +++ b/google/sdk.go @@ -19,7 +19,7 @@ import ( "strings" "time" - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) type sdkCredentials struct { diff --git a/heroku/heroku.go b/heroku/heroku.go index 5b4fdb890..a42bac9d9 100644 --- a/heroku/heroku.go +++ b/heroku/heroku.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package heroku provides constants for using OAuth2 to access Heroku. -package heroku // import "golang.org/x/oauth2/heroku" +package heroku // import "github.com/cloudentity/oauth2/heroku" import ( - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // Endpoint is Heroku's OAuth 2.0 endpoint. diff --git a/hipchat/hipchat.go b/hipchat/hipchat.go index 594fe072c..8732cd70d 100644 --- a/hipchat/hipchat.go +++ b/hipchat/hipchat.go @@ -3,14 +3,14 @@ // license that can be found in the LICENSE file. // Package hipchat provides constants for using OAuth2 to access HipChat. -package hipchat // import "golang.org/x/oauth2/hipchat" +package hipchat // import "github.com/cloudentity/oauth2/hipchat" import ( "encoding/json" "errors" - "golang.org/x/oauth2" - "golang.org/x/oauth2/clientcredentials" + "github.com/cloudentity/oauth2" + "github.com/cloudentity/oauth2/clientcredentials" ) // Endpoint is HipChat's OAuth 2.0 endpoint. diff --git a/instagram/instagram.go b/instagram/instagram.go index 75a74ebb9..db96eb691 100644 --- a/instagram/instagram.go +++ b/instagram/instagram.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package instagram provides constants for using OAuth2 to access Instagram. -package instagram // import "golang.org/x/oauth2/instagram" +package instagram // import "github.com/cloudentity/oauth2/instagram" import ( - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // Endpoint is Instagram's OAuth 2.0 endpoint. diff --git a/internal/oauth2.go b/internal/oauth2.go index c0ab196cf..c9eb07090 100644 --- a/internal/oauth2.go +++ b/internal/oauth2.go @@ -35,3 +35,27 @@ func ParseKey(key []byte) (*rsa.PrivateKey, error) { } return parsed, nil } + +// ParsePublicKey converts the binary contents of a public key file +// to an *rsa.PrivateKey. It detects whether the private key is in a +// PEM container or not. If so, it extracts the the public key +// from PEM container before conversion. It only supports PEM +// containers with no passphrase. +func ParsePublicKey(key []byte) (*rsa.PublicKey, error) { + block, _ := pem.Decode(key) + if block != nil { + key = block.Bytes + } + parsedKey, err := x509.ParsePKIXPublicKey(key) + if err != nil { + parsedKey, err = x509.ParsePKCS1PublicKey(key) + if err != nil { + return nil, fmt.Errorf("private key should be a PEM or plain PKCS1 or PKCS8; parse error: %v", err) + } + } + parsed, ok := parsedKey.(*rsa.PublicKey) + if !ok { + return nil, errors.New("private key is invalid") + } + return parsed, nil +} diff --git a/internal/token.go b/internal/token.go index 355c38696..4a8bee700 100644 --- a/internal/token.go +++ b/internal/token.go @@ -102,7 +102,7 @@ func (e *expirationTime) UnmarshalJSON(b []byte) error { // Endpoint.AuthStyle. func RegisterBrokenAuthHeaderProvider(tokenURL string) {} -// AuthStyle is a copy of the golang.org/x/oauth2 package's AuthStyle type. +// AuthStyle is a copy of the github.com/cloudentity/oauth2 package's AuthStyle type. type AuthStyle int const ( diff --git a/jira/jira.go b/jira/jira.go index 814656e9e..fecccf1e9 100644 --- a/jira/jira.go +++ b/jira/jira.go @@ -19,7 +19,7 @@ import ( "strings" "time" - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // ClaimSet contains information about the JWT signature according diff --git a/jira/jira_test.go b/jira/jira_test.go index 07f6a6314..47d1d91f6 100644 --- a/jira/jira_test.go +++ b/jira/jira_test.go @@ -13,8 +13,8 @@ import ( "strings" "testing" - "golang.org/x/oauth2" - "golang.org/x/oauth2/jws" + "github.com/cloudentity/oauth2" + "github.com/cloudentity/oauth2/jws" ) func TestJWTFetch_JSONResponse(t *testing.T) { diff --git a/jws/jws.go b/jws/jws.go index 95015648b..00157a291 100644 --- a/jws/jws.go +++ b/jws/jws.go @@ -4,7 +4,7 @@ // Package jws provides a partial implementation // of JSON Web Signature encoding and decoding. -// It exists to support the golang.org/x/oauth2 package. +// It exists to support the github.com/cloudentity/oauth2 package. // // See RFC 7515. // @@ -12,7 +12,7 @@ // removed in the future. It exists for internal use only. // Please switch to another JWS package or copy this package into your own // source tree. -package jws // import "golang.org/x/oauth2/jws" +package jws // import "github.com/cloudentity/oauth2/jws" import ( "bytes" diff --git a/jwt/example_test.go b/jwt/example_test.go index 58503d80d..fe99c3fa8 100644 --- a/jwt/example_test.go +++ b/jwt/example_test.go @@ -7,7 +7,7 @@ package jwt_test import ( "context" - "golang.org/x/oauth2/jwt" + "github.com/cloudentity/oauth2/jwt" ) func ExampleJWTConfig() { diff --git a/jwt/jwt.go b/jwt/jwt.go index b2bf18298..9b4794edb 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go @@ -19,9 +19,9 @@ import ( "strings" "time" - "golang.org/x/oauth2" - "golang.org/x/oauth2/internal" - "golang.org/x/oauth2/jws" + "github.com/cloudentity/oauth2" + "github.com/cloudentity/oauth2/internal" + "github.com/cloudentity/oauth2/jws" ) var ( diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go index 9772dc520..f9e1913c5 100644 --- a/jwt/jwt_test.go +++ b/jwt/jwt_test.go @@ -15,8 +15,8 @@ import ( "strings" "testing" - "golang.org/x/oauth2" - "golang.org/x/oauth2/jws" + "github.com/cloudentity/oauth2" + "github.com/cloudentity/oauth2/jws" ) var dummyPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY----- diff --git a/kakao/kakao.go b/kakao/kakao.go index 6d211260c..f1a1c3a87 100644 --- a/kakao/kakao.go +++ b/kakao/kakao.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package kakao provides constants for using OAuth2 to access Kakao. -package kakao // import "golang.org/x/oauth2/kakao" +package kakao // import "github.com/cloudentity/oauth2/kakao" import ( - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // Endpoint is Kakao's OAuth 2.0 endpoint. diff --git a/linkedin/linkedin.go b/linkedin/linkedin.go index d3972771c..33af5f04b 100644 --- a/linkedin/linkedin.go +++ b/linkedin/linkedin.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package linkedin provides constants for using OAuth2 to access LinkedIn. -package linkedin // import "golang.org/x/oauth2/linkedin" +package linkedin // import "github.com/cloudentity/oauth2/linkedin" import ( - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // Endpoint is LinkedIn's OAuth 2.0 endpoint. diff --git a/mailchimp/mailchimp.go b/mailchimp/mailchimp.go index 647787ec6..208db920f 100644 --- a/mailchimp/mailchimp.go +++ b/mailchimp/mailchimp.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package mailchimp provides constants for using OAuth2 to access MailChimp. -package mailchimp // import "golang.org/x/oauth2/mailchimp" +package mailchimp // import "github.com/cloudentity/oauth2/mailchimp" import ( - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // Endpoint is MailChimp's OAuth 2.0 endpoint. diff --git a/mailru/mailru.go b/mailru/mailru.go index dddd9dd0f..f51dd297e 100644 --- a/mailru/mailru.go +++ b/mailru/mailru.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package mailru provides constants for using OAuth2 to access Mail.Ru. -package mailru // import "golang.org/x/oauth2/mailru" +package mailru // import "github.com/cloudentity/oauth2/mailru" import ( - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // Endpoint is Mail.Ru's OAuth 2.0 endpoint. diff --git a/mediamath/mediamath.go b/mediamath/mediamath.go index 3ebce5da1..e44c64f55 100644 --- a/mediamath/mediamath.go +++ b/mediamath/mediamath.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package mediamath provides constants for using OAuth2 to access MediaMath. -package mediamath // import "golang.org/x/oauth2/mediamath" +package mediamath // import "github.com/cloudentity/oauth2/mediamath" import ( - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // Endpoint is MediaMath's OAuth 2.0 endpoint for production. diff --git a/microsoft/microsoft.go b/microsoft/microsoft.go index 3ffbc57a6..5e13b612c 100644 --- a/microsoft/microsoft.go +++ b/microsoft/microsoft.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package microsoft provides constants for using OAuth2 to access Windows Live ID. -package microsoft // import "golang.org/x/oauth2/microsoft" +package microsoft // import "github.com/cloudentity/oauth2/microsoft" import ( - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // LiveConnectEndpoint is Windows's Live ID OAuth 2.0 endpoint. diff --git a/nokiahealth/nokiahealth.go b/nokiahealth/nokiahealth.go index c181ccd0f..e112b0fff 100644 --- a/nokiahealth/nokiahealth.go +++ b/nokiahealth/nokiahealth.go @@ -6,7 +6,7 @@ package nokiahealth import ( - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // Endpoint is Nokia Health Mate's OAuth 2.0 endpoint. diff --git a/oauth2.go b/oauth2.go index 291df5c83..01b202d52 100644 --- a/oauth2.go +++ b/oauth2.go @@ -6,7 +6,7 @@ // OAuth2 authorized and authenticated HTTP requests, // as specified in RFC 6749. // It can additionally grant authorization with Bearer JWT. -package oauth2 // import "golang.org/x/oauth2" +package oauth2 // import "github.com/cloudentity/oauth2" import ( "bytes" @@ -17,7 +17,7 @@ import ( "strings" "sync" - "golang.org/x/oauth2/internal" + "github.com/cloudentity/oauth2/internal" ) // NoContext is the default context you should supply if not using @@ -37,7 +37,7 @@ func RegisterBrokenAuthHeaderProvider(tokenURL string) {} // Config describes a typical 3-legged OAuth2 flow, with both the // client application information and the server's endpoint URLs. // For the client credentials 2-legged OAuth2 flow, see the clientcredentials -// package (https://golang.org/x/oauth2/clientcredentials). +// package (https://github.com/cloudentity/oauth2/clientcredentials). type Config struct { // ClientID is the application's ID. ClientID string diff --git a/oauth2_test.go b/oauth2_test.go index b7975e166..a95af6367 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -16,7 +16,7 @@ import ( "testing" "time" - "golang.org/x/oauth2/internal" + "github.com/cloudentity/oauth2/internal" ) type mockTransport struct { diff --git a/odnoklassniki/odnoklassniki.go b/odnoklassniki/odnoklassniki.go index c0d093ccc..cc79ce70c 100644 --- a/odnoklassniki/odnoklassniki.go +++ b/odnoklassniki/odnoklassniki.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package odnoklassniki provides constants for using OAuth2 to access Odnoklassniki. -package odnoklassniki // import "golang.org/x/oauth2/odnoklassniki" +package odnoklassniki // import "github.com/cloudentity/oauth2/odnoklassniki" import ( - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // Endpoint is Odnoklassniki's OAuth 2.0 endpoint. diff --git a/paypal/paypal.go b/paypal/paypal.go index 2e713c53c..31bebed31 100644 --- a/paypal/paypal.go +++ b/paypal/paypal.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package paypal provides constants for using OAuth2 to access PayPal. -package paypal // import "golang.org/x/oauth2/paypal" +package paypal // import "github.com/cloudentity/oauth2/paypal" import ( - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // Endpoint is PayPal's OAuth 2.0 endpoint in live (production) environment. diff --git a/slack/slack.go b/slack/slack.go index 593d2f607..a980ea35d 100644 --- a/slack/slack.go +++ b/slack/slack.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package slack provides constants for using OAuth2 to access Slack. -package slack // import "golang.org/x/oauth2/slack" +package slack // import "github.com/cloudentity/oauth2/slack" import ( - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // Endpoint is Slack's OAuth 2.0 endpoint. diff --git a/spotify/spotify.go b/spotify/spotify.go index c75416c00..c8d49a467 100644 --- a/spotify/spotify.go +++ b/spotify/spotify.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package spotify provides constants for using OAuth2 to access Spotify. -package spotify // import "golang.org/x/oauth2/spotify" +package spotify // import "github.com/cloudentity/oauth2/spotify" import ( - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // Endpoint is Spotify's OAuth 2.0 endpoint. diff --git a/stackoverflow/stackoverflow.go b/stackoverflow/stackoverflow.go index 82711f777..6bed97880 100644 --- a/stackoverflow/stackoverflow.go +++ b/stackoverflow/stackoverflow.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package stackoverflow provides constants for using OAuth2 to access Stack Overflow. -package stackoverflow // import "golang.org/x/oauth2/stackoverflow" +package stackoverflow // import "github.com/cloudentity/oauth2/stackoverflow" import ( - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // Endpoint is Stack Overflow's OAuth 2.0 endpoint. diff --git a/token.go b/token.go index 822720341..2dbb204c7 100644 --- a/token.go +++ b/token.go @@ -13,7 +13,7 @@ import ( "strings" "time" - "golang.org/x/oauth2/internal" + "github.com/cloudentity/oauth2/internal" ) // expiryDelta determines how earlier a token should be considered diff --git a/transport.go b/transport.go index 90657915f..0f86580f9 100644 --- a/transport.go +++ b/transport.go @@ -63,7 +63,7 @@ var cancelOnce sync.Once // Deprecated: use contexts for cancellation instead. func (t *Transport) CancelRequest(req *http.Request) { cancelOnce.Do(func() { - log.Printf("deprecated: golang.org/x/oauth2: Transport.CancelRequest no longer does anything; use contexts") + log.Printf("deprecated: github.com/cloudentity/oauth2: Transport.CancelRequest no longer does anything; use contexts") }) } diff --git a/twitch/twitch.go b/twitch/twitch.go index 0838e7c15..d825b5d5c 100644 --- a/twitch/twitch.go +++ b/twitch/twitch.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package twitch provides constants for using OAuth2 to access Twitch. -package twitch // import "golang.org/x/oauth2/twitch" +package twitch // import "github.com/cloudentity/oauth2/twitch" import ( - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // Endpoint is Twitch's OAuth 2.0 endpoint. diff --git a/uber/uber.go b/uber/uber.go index 5520a6455..b654784d5 100644 --- a/uber/uber.go +++ b/uber/uber.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package uber provides constants for using OAuth2 to access Uber. -package uber // import "golang.org/x/oauth2/uber" +package uber // import "github.com/cloudentity/oauth2/uber" import ( - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // Endpoint is Uber's OAuth 2.0 endpoint. diff --git a/vk/vk.go b/vk/vk.go index bd8e15948..54f013acb 100644 --- a/vk/vk.go +++ b/vk/vk.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package vk provides constants for using OAuth2 to access VK.com. -package vk // import "golang.org/x/oauth2/vk" +package vk // import "github.com/cloudentity/oauth2/vk" import ( - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // Endpoint is VK's OAuth 2.0 endpoint. diff --git a/yahoo/yahoo.go b/yahoo/yahoo.go index 9fa78a23c..6fe05f69b 100644 --- a/yahoo/yahoo.go +++ b/yahoo/yahoo.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package yahoo provides constants for using OAuth2 to access Yahoo. -package yahoo // import "golang.org/x/oauth2/yahoo" +package yahoo // import "github.com/cloudentity/oauth2/yahoo" import ( - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // Endpoint is Yahoo's OAuth 2.0 endpoint. diff --git a/yandex/yandex.go b/yandex/yandex.go index 5ebf666d2..d48f6f5dc 100644 --- a/yandex/yandex.go +++ b/yandex/yandex.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package yandex provides constants for using OAuth2 to access Yandex APIs. -package yandex // import "golang.org/x/oauth2/yandex" +package yandex // import "github.com/cloudentity/oauth2/yandex" import ( - "golang.org/x/oauth2" + "github.com/cloudentity/oauth2" ) // Endpoint is the Yandex OAuth 2.0 endpoint. From 1ba557a3208fafc825db314439b7f28c2160c576 Mon Sep 17 00:00:00 2001 From: Konrad Holowinski Date: Wed, 26 Oct 2022 11:07:00 +0200 Subject: [PATCH 02/16] remove parse public key --- internal/oauth2.go | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/internal/oauth2.go b/internal/oauth2.go index c9eb07090..c0ab196cf 100644 --- a/internal/oauth2.go +++ b/internal/oauth2.go @@ -35,27 +35,3 @@ func ParseKey(key []byte) (*rsa.PrivateKey, error) { } return parsed, nil } - -// ParsePublicKey converts the binary contents of a public key file -// to an *rsa.PrivateKey. It detects whether the private key is in a -// PEM container or not. If so, it extracts the the public key -// from PEM container before conversion. It only supports PEM -// containers with no passphrase. -func ParsePublicKey(key []byte) (*rsa.PublicKey, error) { - block, _ := pem.Decode(key) - if block != nil { - key = block.Bytes - } - parsedKey, err := x509.ParsePKIXPublicKey(key) - if err != nil { - parsedKey, err = x509.ParsePKCS1PublicKey(key) - if err != nil { - return nil, fmt.Errorf("private key should be a PEM or plain PKCS1 or PKCS8; parse error: %v", err) - } - } - parsed, ok := parsedKey.(*rsa.PublicKey) - if !ok { - return nil, errors.New("private key is invalid") - } - return parsed, nil -} From 727becb69774b33af25eaa825e324fdad93f6016 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Konrad=20Ho=C5=82owi=C5=84ski?= Date: Mon, 31 Oct 2022 18:39:57 +0100 Subject: [PATCH 03/16] Private_key_jwt support (#1) --- .github/workflows/go.yml | 20 ++++ README.md | 62 ++++++---- advancedauth/advancedauth.go | 42 +++++++ advancedauth/privatekeyjwt.go | 81 +++++++++++++ advancedauth/privatekeyjwt_test.go | 155 +++++++++++++++++++++++++ advancedauth/utils_test.go | 74 ++++++++++++ clientcredentials/clientcredentials.go | 17 +++ go.mod | 2 + go.sum | 4 + oauth2.go | 5 + 10 files changed, 439 insertions(+), 23 deletions(-) create mode 100644 .github/workflows/go.yml create mode 100644 advancedauth/advancedauth.go create mode 100644 advancedauth/privatekeyjwt.go create mode 100644 advancedauth/privatekeyjwt_test.go create mode 100644 advancedauth/utils_test.go diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml new file mode 100644 index 000000000..a1fb19045 --- /dev/null +++ b/.github/workflows/go.yml @@ -0,0 +1,20 @@ +name: Go + +on: [push] + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: 1.18 + + - name: Build + run: go build -v ./... + + - name: Test + run: go test -v ./... diff --git a/README.md b/README.md index 1473e1296..6417205f7 100644 --- a/README.md +++ b/README.md @@ -1,36 +1,52 @@ -# OAuth2 for Go +# OAuth2 for Go - extended with advanced authentication -[![Go Reference](https://pkg.go.dev/badge/golang.org/x/oauth2.svg)](https://pkg.go.dev/golang.org/x/oauth2) -[![Build Status](https://travis-ci.org/golang/oauth2.svg?branch=master)](https://travis-ci.org/golang/oauth2) +This repo is a drop-in replacement of `golang.org/x/oauth2` -oauth2 package contains a client implementation for OAuth 2.0 spec. +It extends the original library with additional authentication methods: + +- private_key_jwt ## Installation -~~~~ -go get golang.org/x/oauth2 -~~~~ +When using go modules you can run: + +`go mod edit -replace golang.org/x/oauth2 github.com/cloudentity/oauth2` + +## Usage + +When using any of the originally supported authentication methods, there's no need to change anything. +This library can be used as a drop-in replacement. + +For new authentication methods see: -Or you can manually git clone the repository to -`$(go env GOPATH)/src/golang.org/x/oauth2`. +### Private Key JWT -See pkg.go.dev for further documentation and examples. +#### Client credentials -* [pkg.go.dev/golang.org/x/oauth2](https://pkg.go.dev/golang.org/x/oauth2) -* [pkg.go.dev/golang.org/x/oauth2/google](https://pkg.go.dev/golang.org/x/oauth2/google) +```go +import ( + "context" + "time" -## Policy for new packages + "github.com/cloudentity/oauth2/advancedauth" + "github.com/cloudentity/oauth2/clientcredentials" +) +``` -We no longer accept new provider-specific packages in this repo if all -they do is add a single endpoint variable. If you just want to add a -single endpoint, add it to the -[pkg.go.dev/golang.org/x/oauth2/endpoints](https://pkg.go.dev/golang.org/x/oauth2/endpoints) -package. +```go + cfg := clientcredentials.Config{ + ClientID: "your client id", + PrivateKeyAuth: advancedauth.PrivateKeyAuth{ + Key: "your PEM encoded private key", + Alg: advancedauth.RS256, + Exp: 30 * time.Second, + }, + } -## Report Issues / Send Patches + token, err := cfg.Token(context.Background()) +``` -This repository uses Gerrit for code changes. To learn how to submit changes to -this repository, see https://golang.org/doc/contribute.html. +## Implementation -The main issue tracker for the oauth2 repository is located at -https://github.com/golang/oauth2/issues. +This fork tries to limit changes to the original codebase to the minimum. +All the new major changes are implemented in the `advancedauth` package. diff --git a/advancedauth/advancedauth.go b/advancedauth/advancedauth.go new file mode 100644 index 000000000..1f4b60c95 --- /dev/null +++ b/advancedauth/advancedauth.go @@ -0,0 +1,42 @@ +package advancedauth + +import ( + "net/url" + + "github.com/cloudentity/oauth2" +) + +type Algorithm string + +const ( + RS256 Algorithm = "RS256" + RS384 Algorithm = "RS384" + RS512 Algorithm = "RS512" + + ES256 Algorithm = "ES256" + ES384 Algorithm = "ES384" + ES512 Algorithm = "ES512" +) + +type Config struct { + AuthStyle oauth2.AuthStyle + ClientID string + PrivateKeyAuth PrivateKeyAuth + TokenURL string +} + +func ExtendUrlValues(v url.Values, c Config) error { + if c.AuthStyle == oauth2.AuthStylePrivateKeyJWT { + jwtVals, err := privateKeyJWTAssertionVals(c) + if err != nil { + return err + } + + for key, vals := range jwtVals { + for _, val := range vals { + v.Set(key, val) + } + } + } + return nil +} diff --git a/advancedauth/privatekeyjwt.go b/advancedauth/privatekeyjwt.go new file mode 100644 index 000000000..c5f658210 --- /dev/null +++ b/advancedauth/privatekeyjwt.go @@ -0,0 +1,81 @@ +package advancedauth + +import ( + "fmt" + "net/url" + "strings" + "time" + + "github.com/golang-jwt/jwt/v4" + "github.com/google/uuid" +) + +const privateKeyJWTAssertionType = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" + +type PrivateKeyAuth struct { + // Key is a PEM formatted private key used to sign client_assertion + Key string + // Algorithm used to sign the client_assertion (see JWS) - default RS256 + Algorithm Algorithm + // Exp defines how long client_assertion is valid for - default 30 seconds + Exp time.Duration +} + +func privateKeyJWTAssertionVals(c Config) (url.Values, error) { + var ( + err error + assertion string + id uuid.UUID + key interface{} + token *jwt.Token + exp time.Duration + alg Algorithm + ) + + if id, err = uuid.NewUUID(); err != nil { + return url.Values{}, err + } + jti := id.String() + + exp = c.PrivateKeyAuth.Exp + if exp == 0*time.Second { + exp = 30 * time.Second + } + + claims := &jwt.RegisteredClaims{ + Issuer: c.ClientID, + Subject: c.ClientID, + Audience: []string{strings.TrimSuffix(c.TokenURL, "/token")}, + ID: jti, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(exp)), + } + + alg = c.PrivateKeyAuth.Algorithm + if alg == "" { + alg = RS256 + } + + switch alg { + case RS256, RS384, RS512: + if key, err = jwt.ParseRSAPrivateKeyFromPEM([]byte(c.PrivateKeyAuth.Key)); err != nil { + return url.Values{}, fmt.Errorf("could not parse private key from PEM %s", alg) + } + case ES256, ES384, ES512: + if key, err = jwt.ParseECPrivateKeyFromPEM([]byte(c.PrivateKeyAuth.Key)); err != nil { + return url.Values{}, fmt.Errorf("could not parse private key from PEM %s", alg) + } + default: + return url.Values{}, fmt.Errorf("unsupported algorithm %s", alg) + } + + token = jwt.NewWithClaims(jwt.GetSigningMethod(string(alg)), claims) + + if assertion, err = token.SignedString(key); err != nil { + return url.Values{}, err + } + + return url.Values{ + "client_assertion": []string{assertion}, + "client_assertion_type": []string{privateKeyJWTAssertionType}, + }, nil +} diff --git a/advancedauth/privatekeyjwt_test.go b/advancedauth/privatekeyjwt_test.go new file mode 100644 index 000000000..a82e3d169 --- /dev/null +++ b/advancedauth/privatekeyjwt_test.go @@ -0,0 +1,155 @@ +package advancedauth_test + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/cloudentity/oauth2" + "github.com/cloudentity/oauth2/advancedauth" + "github.com/cloudentity/oauth2/clientcredentials" + "github.com/golang-jwt/jwt/v4" +) + +const ( + privateKey = `-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKBgQDYRpq7yP3IaRxFjD9i1VWAFMHgLikJgGQaScg5S9XS3INwYz+E +ZtXrg6++HKyHjqEUeKT+2IZHSJPhOHdKaxh7KCci31MXHtWSG8xMaikKWyLPXjmU +kqONQHOD7XvECqQ8KGkrZ5BTIkVa7KA6aXlYoc3zQpOfbf+wx3/57uuDQQIDAQAB +AoGADKfdCB4T07Vq5Rr23pazQSJ10eOBnT+5G9yzbb7lTUiAHISCRAIshHKZRxuw +cOJExMjmhs8u1F8H4EcIm/82WGsMegCLrS8Y1zW2goiNqIh4QBGHudgvmrXQFz+T +9euhREf4gq7npIHW/ahjCMeEc2Yom4wQC6QJ0bOUu/hiqm0CQQDzIEpFZQnYYMzn +99lk4Qnxh1l0UzTJNNKVidEXi3iHam2ztTkE5mIWlZKHvg5DHzOmvzPKYzFS2YS+ +0RACf2/PAkEA47pX1Qc8axoqTBSELA1i3ZKc+qs0mmh2FXcDB2OcpUH00sXLCjGO +r3d57vNRKUYfu7VAQliis8iq5+DPA4sP7wJBAOyLhxd7VZfbnqE2qKGYvcbrzCH8 +bogwx45Ml03UGcGO0Asfj8lvqRGWFwnQ5SlzKxraPrZzyeJ01c2dtHjpqksCQCj1 +G9Txnzk4FIFoczklEzH8q4UeA7D9trc3l3Ddxo+mZC0Aa/siXKJMX77NPjypIw30 +lGEaZfDl128q7LCbczsCQGIBBN0TAwxfYstKeD7g7GXG8yD10LlmB3FCBdQjoBaW +tfeljbt+hNJU/3NIvDhYujEfG2d9cmBZulMRY7gh40Y= +-----END RSA PRIVATE KEY-----` + + publicKey = `-----BEGIN PUBLIC KEY----- +MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDYRpq7yP3IaRxFjD9i1VWAFMHg +LikJgGQaScg5S9XS3INwYz+EZtXrg6++HKyHjqEUeKT+2IZHSJPhOHdKaxh7KCci +31MXHtWSG8xMaikKWyLPXjmUkqONQHOD7XvECqQ8KGkrZ5BTIkVa7KA6aXlYoc3z +QpOfbf+wx3/57uuDQQIDAQAB +-----END PUBLIC KEY-----` + + privateECDSAKey = `-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIMlmB8ys8+Sp4b0zSzghVD9q9GtljXTwI58f6sGJoFRQoAoGCCqGSM49 +AwEHoUQDQgAEO1sWioJjxNghnKRcH1eHMCTreC2FvVWVDgE2dqe84TeXtbkAUosr +9EdTaTI96qG8xnCEKg3QLnCRuJj54SqpSQ== +-----END EC PRIVATE KEY-----` + + publicECDSAKey = `-----BEGIN PUBLIC KEY----- +MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEO1sWioJjxNghnKRcH1eHMCTreC2F +vVWVDgE2dqe84TeXtbkAUosr9EdTaTI96qG8xnCEKg3QLnCRuJj54SqpSQ== +-----END PUBLIC KEY-----` +) + +func TestPrivateKeyJWT_ClientCredentials(t *testing.T) { + rsaPubKey, err := jwt.ParseRSAPublicKeyFromPEM([]byte(publicKey)) + if err != nil { + t.Error("could not parse rsa public key") + } + ecdsaPubKey, err := jwt.ParseECPublicKeyFromPEM([]byte(publicECDSAKey)) + if err != nil { + t.Error("could not parse ecdsa public key") + } + tcs := []struct { + title string + config clientcredentials.Config + publicKey interface{} + }{ + { + title: "RSA", + config: clientcredentials.Config{ + ClientID: "CLIENT_ID", + AuthStyle: oauth2.AuthStylePrivateKeyJWT, + PrivateKeyAuth: advancedauth.PrivateKeyAuth{ + Key: privateKey, + }, + Scopes: []string{"scope1", "scope2"}, + EndpointParams: url.Values{"audience": {"audience1"}}, + }, + publicKey: rsaPubKey, + }, + { + title: "ECDSA", + config: clientcredentials.Config{ + ClientID: "CLIENT_ID", + AuthStyle: oauth2.AuthStylePrivateKeyJWT, + PrivateKeyAuth: advancedauth.PrivateKeyAuth{ + Key: privateECDSAKey, + Algorithm: "ES256", + }, + Scopes: []string{"scope1", "scope2"}, + EndpointParams: url.Values{"audience": {"audience1"}}, + }, + publicKey: ecdsaPubKey, + }, + } + + for _, tc := range tcs { + tc := tc + t.Run(tc.title, func(tt *testing.T) { + var serverURL string + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expectURL(tt, r, "/token") + expectHeader(tt, r, "Authorization", "") + expectHeader(tt, r, "Content-Type", "application/x-www-form-urlencoded") + expectFormParam(tt, r, "client_id", "") + expectFormParam(tt, r, "client_secret", "") + expectFormParam(tt, r, "grant_type", "client_credentials") + expectFormParam(tt, r, "scope", "scope1 scope2") + expectFormParam(tt, r, "client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") + + assertion := r.FormValue("client_assertion") + claims := jwt.RegisteredClaims{} + token, err := jwt.ParseWithClaims(assertion, &claims, func(token *jwt.Token) (interface{}, error) { + return tc.publicKey, nil + }) + if err != nil { + tt.Errorf("could not parse assertion %+v", err) + } + if !token.Valid { + tt.Error("invalid assertion token") + } + + expectStringsEqual(tt, "CLIENT_ID", claims.Issuer) + expectStringsEqual(tt, "CLIENT_ID", claims.Subject) + + // uuid v4 like + expectTrue(tt, len(claims.ID) == 36) + + expectTrue(tt, time.Now().Unix() < claims.ExpiresAt.Unix()) + expectStringsEqual(tt, serverURL, claims.Audience[0]) + + w.Header().Set("Content-Type", "application/x-www-form-urlencoded") + _, err = w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&token_type=bearer")) + if err != nil { + tt.Errorf("could not write body") + } + })) + serverURL = ts.URL + defer ts.Close() + conf := tc.config + conf.TokenURL = serverURL + "/token" + tok, err := conf.Token(context.Background()) + if err != nil { + tt.Error(err) + } + expectAccessToken(tt, &oauth2.Token{ + AccessToken: "90d64460d14870c08c81352a05dedd3465940a7c", + TokenType: "bearer", + RefreshToken: "", + Expiry: time.Time{}, + }, tok) + }) + } + +} diff --git a/advancedauth/utils_test.go b/advancedauth/utils_test.go new file mode 100644 index 000000000..1b17afeca --- /dev/null +++ b/advancedauth/utils_test.go @@ -0,0 +1,74 @@ +package advancedauth_test + +import ( + "io/ioutil" + "net/http" + "testing" + + "github.com/cloudentity/oauth2" +) + +func expectHeader(t *testing.T, r *http.Request, header string, expected string) { + actual := r.Header.Get(header) + if actual != expected { + t.Errorf("Expected header %s to be %s, got %s", header, expected, actual) + } +} + +func expectURL(t *testing.T, r *http.Request, expected string) { + actual := r.URL.String() + if actual != expected { + t.Errorf("Expected url to be %s, got %s", expected, actual) + } +} + +func expectBody(t *testing.T, r *http.Request, expected string) { + body, err := ioutil.ReadAll(r.Body) + if err != nil { + r.Body.Close() + } + if err != nil { + t.Errorf("failed reading request body: %s.", err) + } + actual := string(body) + if actual != expected { + t.Errorf("Expected body to be %s, got %s", expected, actual) + } +} + +func expectAccessToken(t *testing.T, expected *oauth2.Token, actual *oauth2.Token) { + if !actual.Valid() { + t.Fatalf("token invalid. got: %+v", actual) + } + if actual.AccessToken != expected.AccessToken { + t.Errorf("Access token = %q; want %q", actual.AccessToken, expected.AccessToken) + } + if actual.TokenType != expected.TokenType { + t.Errorf("token type = %q; want %q", actual.TokenType, expected.TokenType) + } +} + +func expectFormParam(t *testing.T, r *http.Request, param string, expected string) { + actual := r.FormValue(param) + if actual != expected { + t.Errorf("Expected form param %s to be %s, got %s", param, expected, actual) + } +} + +func expectStringsEqual(t *testing.T, expected string, actual string) { + if actual != expected { + t.Errorf("Expected %s and %s to be equal", expected, actual) + } +} + +func expectStringNonEmpty(t *testing.T, actual string) { + if actual == "" { + t.Errorf("Expected not empty %s", actual) + } +} + +func expectTrue(t *testing.T, actual bool) { + if !actual { + t.Errorf("Expected true %t", actual) + } +} diff --git a/clientcredentials/clientcredentials.go b/clientcredentials/clientcredentials.go index 55caa20bd..fba08e403 100644 --- a/clientcredentials/clientcredentials.go +++ b/clientcredentials/clientcredentials.go @@ -21,6 +21,7 @@ import ( "strings" "github.com/cloudentity/oauth2" + "github.com/cloudentity/oauth2/advancedauth" "github.com/cloudentity/oauth2/internal" ) @@ -47,6 +48,10 @@ type Config struct { // client ID & client secret sent. The zero value means to // auto-detect. AuthStyle oauth2.AuthStyle + + // PrivateKeyAuth stores configuration options for private_key_jwt + // client authentication method described in OpenID Connect spec. + PrivateKeyAuth advancedauth.PrivateKeyAuth } // Token uses client credentials to retrieve a token. @@ -94,6 +99,18 @@ func (c *tokenSource) Token() (*oauth2.Token, error) { if len(c.conf.Scopes) > 0 { v.Set("scope", strings.Join(c.conf.Scopes, " ")) } + // not client_secret nor auto_detect + if c.conf.AuthStyle > 2 { + var err error + if err = advancedauth.ExtendUrlValues(v, advancedauth.Config{ + AuthStyle: c.conf.AuthStyle, + ClientID: c.conf.ClientID, + PrivateKeyAuth: c.conf.PrivateKeyAuth, + TokenURL: c.conf.TokenURL, + }); err != nil { + return nil, err + } + } for k, p := range c.conf.EndpointParams { // Allow grant_type to be overridden to allow interoperability with // non-compliant implementations. diff --git a/go.mod b/go.mod index 5ede3cc78..f7680d536 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,9 @@ go 1.17 require ( cloud.google.com/go/compute/metadata v0.2.0 + github.com/golang-jwt/jwt/v4 v4.4.2 github.com/google/go-cmp v0.5.8 + github.com/google/uuid v1.1.2 golang.org/x/net v0.1.0 google.golang.org/appengine v1.6.7 ) diff --git a/go.sum b/go.sum index 118975e24..25f58b344 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ cloud.google.com/go/compute/metadata v0.2.0 h1:nBbNSZyDpkNlo3DepaaLKVuO7ClyifSAmNloSCZrHnQ= cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= +github.com/golang-jwt/jwt/v4 v4.4.2 h1:rcc4lwaZgFMCZ5jxF9ABolDcIHdBytAFgqFPbSJQAYs= +github.com/golang-jwt/jwt/v4 v4.4.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= @@ -7,6 +9,8 @@ github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiu github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= +github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= diff --git a/oauth2.go b/oauth2.go index 01b202d52..4aaa4d067 100644 --- a/oauth2.go +++ b/oauth2.go @@ -97,6 +97,11 @@ const ( // using HTTP Basic Authorization. This is an optional style // described in the OAuth2 RFC 6749 section 2.3.1. AuthStyleInHeader AuthStyle = 2 + + // AuthStylePrivateKeyJWT sends a JWT assertion + // signed using the private key + // described in OpenID Connect Core + AuthStylePrivateKeyJWT AuthStyle = 3 ) var ( From 788633f0a473e36d5ddd354bf6477d9d1000c35f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Konrad=20Ho=C5=82owi=C5=84ski?= Date: Tue, 1 Nov 2022 10:43:33 +0100 Subject: [PATCH 04/16] TLS client authentication support (#2) --- README.md | 34 +++- advancedauth/advancedauth.go | 12 ++ advancedauth/tls.go | 54 ++++++ advancedauth/tls_test.go | 254 +++++++++++++++++++++++++ advancedauth/utils_test.go | 20 +- clientcredentials/clientcredentials.go | 14 +- oauth2.go | 3 + 7 files changed, 377 insertions(+), 14 deletions(-) create mode 100644 advancedauth/tls.go create mode 100644 advancedauth/tls_test.go diff --git a/README.md b/README.md index 6417205f7..5a40c5982 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,8 @@ This repo is a drop-in replacement of `golang.org/x/oauth2` It extends the original library with additional authentication methods: - private_key_jwt +- tls_client_auth +- self_signed_tls_client_auth ## Installation @@ -17,7 +19,7 @@ When using go modules you can run: When using any of the originally supported authentication methods, there's no need to change anything. This library can be used as a drop-in replacement. -For new authentication methods see: +For new authentication methods see the examples below: ### Private Key JWT @@ -36,6 +38,7 @@ import ( ```go cfg := clientcredentials.Config{ ClientID: "your client id", + AuthStyle: oauth2.AuthStylePrivateKeyJWT, PrivateKeyAuth: advancedauth.PrivateKeyAuth{ Key: "your PEM encoded private key", Alg: advancedauth.RS256, @@ -46,6 +49,35 @@ import ( token, err := cfg.Token(context.Background()) ``` +### TLS Auth + +Both `tls_client_auth` and `self_signed_tls_client_auth` are handled with `TLSAuth` + +#### Client credentials + +```go +import ( + "context" + "time" + + "github.com/cloudentity/oauth2/advancedauth" + "github.com/cloudentity/oauth2/clientcredentials" +) +``` + +```go + cfg := clientcredentials.Config{ + ClientID: "your client id", + AuthStyle: oauth2.AuthStyleTLS, + TLSAuth: advancedauth.TLSAuth{ + Key: "your certificate PEM encoded private key", + Certificate: "your PEM encoded TLS certificate", + }, + } + + token, err := cfg.Token(context.Background()) +``` + ## Implementation This fork tries to limit changes to the original codebase to the minimum. diff --git a/advancedauth/advancedauth.go b/advancedauth/advancedauth.go index 1f4b60c95..fd6f60d45 100644 --- a/advancedauth/advancedauth.go +++ b/advancedauth/advancedauth.go @@ -1,6 +1,7 @@ package advancedauth import ( + "context" "net/url" "github.com/cloudentity/oauth2" @@ -22,6 +23,7 @@ type Config struct { AuthStyle oauth2.AuthStyle ClientID string PrivateKeyAuth PrivateKeyAuth + TLSAuth TLSAuth TokenURL string } @@ -38,5 +40,15 @@ func ExtendUrlValues(v url.Values, c Config) error { } } } + if c.AuthStyle == oauth2.AuthStyleTLS { + v.Set("client_id", c.ClientID) + } return nil } + +func ExtendContext(ctx context.Context, c Config) (context.Context, error) { + if c.AuthStyle == oauth2.AuthStyleTLS { + return extendContextWithTLSClient(ctx, c) + } + return ctx, nil +} diff --git a/advancedauth/tls.go b/advancedauth/tls.go new file mode 100644 index 000000000..e9ccd8374 --- /dev/null +++ b/advancedauth/tls.go @@ -0,0 +1,54 @@ +package advancedauth + +import ( + "context" + "crypto/tls" + "errors" + "net/http" + + "github.com/cloudentity/oauth2" +) + +type TLSAuth struct { + // Key is the private key for client TLS certificate + Key string + // Certificate is the client TLS certificate + Certificate string +} + +func extendContextWithTLSClient(ctx context.Context, c Config) (context.Context, error) { + var ( + hc *http.Client + ok bool + cert tls.Certificate + err error + tr *http.Transport + ) + if ctx == nil { + ctx = context.Background() + } + + if ctx.Value(oauth2.HTTPClient) == nil { + hc = http.DefaultClient + } else if hc, ok = ctx.Value(oauth2.HTTPClient).(*http.Client); !ok { + return nil, errors.New("client of type *http.Client required in context") + } + + if cert, err = tls.X509KeyPair([]byte(c.TLSAuth.Certificate), []byte(c.TLSAuth.Key)); err != nil { + return nil, err + } + + if hc.Transport == nil { + tr = &http.Transport{} + } else if tr, ok = hc.Transport.(*http.Transport); !ok { + return nil, errors.New("transport of type *http.Transport required in context") + } + if tr.TLSClientConfig == nil { + tr.TLSClientConfig = &tls.Config{} + } + tr.TLSClientConfig.Certificates = []tls.Certificate{cert} + hc.Transport = tr + + return context.WithValue(ctx, oauth2.HTTPClient, hc), nil + +} diff --git a/advancedauth/tls_test.go b/advancedauth/tls_test.go new file mode 100644 index 000000000..1ac5d6674 --- /dev/null +++ b/advancedauth/tls_test.go @@ -0,0 +1,254 @@ +package advancedauth_test + +import ( + "context" + "crypto/tls" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/cloudentity/oauth2" + "github.com/cloudentity/oauth2/advancedauth" + "github.com/cloudentity/oauth2/clientcredentials" +) + +const ( + key = `-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC0uhESy4URdqwo +8Hbus5UjdxQom0zQj7jw4bcZ2Z4X0HLJbmbDZdwIaoOWfSjYu9VYPkE04/+KnBOh +XMpA8DfcyS+XVPPTAEFI7KH9RF7BTMjSxB32Huwz9hMHqiPxJx1R+dTSWSC61+GX +Dq+cLHGeQq4Cqxxf0nnGmgpnT26GtiG/QZzE0IdlxaK68BzFk3syNzVFE8Om6yzx +ET7L5/p6igFrj22enjbYimtcSuHM2k16n0MSipBL2v1scheifGN0P+po118IRuX2 +mU8WH5Z8eyInWf857sNEHuFoCkuegJFVkzkuzxZz/F+cT1Znfq0x17ssnL9SFDk4 +XpyNKTqPAgMBAAECggEAULaMq30zV8JNTxddtmuDnswut5fsLXUSnpnf4W6cOXyB +1040HO4f365aSFprZKg2tutOyeVNmkTsS3OabHgcKsG7PHXXUxPZFE2CZw8i1meJ +hP/LdcEHsokipJiq5qeWY6cVEkB16pxBhuorKa97qreS6WQsDut8MWNYZB1Iemaa +HjioQZ7SpUUUyr3XNuvoaPViymGou6DYLaIMg0zklOrfigu1Qb4XdtWtbdi3AWcr +dVNO/N8Y19pJGqpJZ0FlqT/G8es10prAJGPAy4O/RxsLEfOSlZHe1Oj5V63B5h6R +KPwzSRM03gqHG0qruhr2seQN2UvJSRJNz3a2q7siGQKBgQDsXvcohxXoVkv2yvq4 +D9QmQxU3/zHPZhnFNpZ9p3a4AHvmTFyTErTPrZn+QW/l9VvyKGctezR9/SMTLmsQ +dz8Pnbqoukp2Vo/zNK1HEf3Iy5/lVZtd4ErfFCKpWYkNEXX43RQ2qvNt/XkkuIIg +mijoKxBfiwKD8sGB2B8owHCi6wKBgQDDvCglc1yPQ3dzEcaMOoABKWdH7Q72Xgjr +rpmO5lATn6kvcwgAjf/EEIGSQVjoY3zhOZ4J/eV7G6NTg9sRVhcWtkt1UtVv1BwE +Cg4P6W7hCg8GF8Egh/dYtarx19juZkXk5HNSe0PEgrpbjzdxx0s/2HE1JwziVa3q +qJFV4gd17QKBgQCS81dlctZD46LGg9rro6uZPgtrDNTCxA8xdIaLCBneuy5MNx02 +smKG2r7qO3R92tSW8Fd1ByvTSBUOT8VwLzKdWso5K9gvShGkehNgI+dLdoyp31cA +PflORw5liqyR21Ekrw1qD03YC8XM9oiwDCdyb5N2Us31im6TcvGsPDfKkQKBgQCF +Ok0ZMKyP1xw29qJuUGNQZx4llvXYO6lWwkFDQwC+Wq6N3X5U4lJ04cdQBaq+gvk9 +VDp+EpNgeC9zaQxzgGW2z94MvZUJyRZIqY9oxTrzciVHwGN0ARgbCYyRkJnXq0Vn +xxe3zK8T0ueF6rWSfFR74Jct1qauaCM41gQWsQLjAQKBgGfnF99nLe1iI4AZgLIQ +nYgCV65/bmbgX5gkMbDMxZzZYNWg15YuB5Ir+cf20pCwO5EmoLpn7KGpEeED4+/z +2PZrF4bcjmEhYT5O2Y1Wn1oB84uug9c+ME7yiU30g1FttURZuLtzUxASFP2o0l7r +zbSntKWbvm2qk39YKulrEnoh +-----END PRIVATE KEY-----` + cert = `-----BEGIN CERTIFICATE----- +MIIDHTCCAgUCFE+Ha5QgryApfoCjSX564o0JoGYIMA0GCSqGSIb3DQEBCwUAMCcx +CzAJBgNVBAYTAlVTMRgwFgYDVQQDDA9FeGFtcGxlLVJvb3QtQ0EwIBcNMjIxMDMx +MTgxODQ0WhgPMjA3NzA4MDMxODE4NDRaMG0xCzAJBgNVBAYTAlVTMRIwEAYDVQQI +DAlZb3VyU3RhdGUxETAPBgNVBAcMCFlvdXJDaXR5MR0wGwYDVQQKDBRFeGFtcGxl +LUNlcnRpZmljYXRlczEYMBYGA1UEAwwPbG9jYWxob3N0LmxvY2FsMIIBIjANBgkq +hkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAtLoREsuFEXasKPB27rOVI3cUKJtM0I+4 +8OG3GdmeF9ByyW5mw2XcCGqDln0o2LvVWD5BNOP/ipwToVzKQPA33Mkvl1Tz0wBB +SOyh/URewUzI0sQd9h7sM/YTB6oj8ScdUfnU0lkgutfhlw6vnCxxnkKuAqscX9J5 +xpoKZ09uhrYhv0GcxNCHZcWiuvAcxZN7Mjc1RRPDpuss8RE+y+f6eooBa49tnp42 +2IprXErhzNpNep9DEoqQS9r9bHIXonxjdD/qaNdfCEbl9plPFh+WfHsiJ1n/Oe7D +RB7haApLnoCRVZM5Ls8Wc/xfnE9WZ36tMde7LJy/UhQ5OF6cjSk6jwIDAQABMA0G +CSqGSIb3DQEBCwUAA4IBAQCBeRGIRS2MljdbgExv5KEND4OhEj2kuuES1zzTQjgs +EO6G3RlFRU9dFz9WDsLSeegY/4Y8BwR6kA3IpmLVnfmn4odWHhLv+JCDo7TG+R6c +3JnHbLuimcMLnGVVdUzAxQz09bNxYhCqUEla/ji0GeSxg8j8ofxtE7qihODV5dQv +gx3Ef/WxZTy08hd8pKxA8dg/VzechNRngFpINXUnGsX699pSoPWfHQoyZprvWjE7 +QDac6VgTzy/KPfaf9vi3MiXJyjJOuGO3+SL1PhR712qRGg9Y+kccNUlL4OfrLJpm +qobZlvUYUfAYcyJVtjas3vPoQHVCcbq7hdbso5FrLyPK +-----END CERTIFICATE-----` +) + +func TestTLS_ClientCredentials(t *testing.T) { + tcs := []struct { + title string + config clientcredentials.Config + }{ + { + title: "TLS", + config: clientcredentials.Config{ + ClientID: "CLIENT_ID", + AuthStyle: oauth2.AuthStyleTLS, + TLSAuth: advancedauth.TLSAuth{ + Key: key, + Certificate: cert, + }, + Scopes: []string{"scope1", "scope2"}, + EndpointParams: url.Values{"audience": {"audience1"}}, + }, + }, + } + + for _, tc := range tcs { + tc := tc + t.Run(tc.title, func(tt *testing.T) { + var serverURL string + + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expectURL(tt, r, "/token") + expectHeader(tt, r, "Authorization", "") + expectHeader(tt, r, "Content-Type", "application/x-www-form-urlencoded") + expectFormParam(tt, r, "client_id", "CLIENT_ID") + expectFormParam(tt, r, "client_secret", "") + expectFormParam(tt, r, "grant_type", "client_credentials") + + cert := r.TLS.PeerCertificates[0] + expectStringsEqual(tt, "Example-Root-CA", cert.Issuer.CommonName) + + w.Header().Set("Content-Type", "application/x-www-form-urlencoded") + _, err := w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&token_type=bearer")) + if err != nil { + tt.Errorf("could not write body") + } + })) + + ts.TLS = &tls.Config{ + ClientAuth: tls.RequestClientCert, + } + + ts.StartTLS() + serverURL = ts.URL + defer ts.Close() + conf := tc.config + conf.TokenURL = serverURL + "/token" + + _, err := conf.Token(context.Background()) + // context.Background() will fail as the server cert is not trusted + // err == nil checks if there are no panics + if err == nil { + tt.Errorf("expected Token to fail with invalid server cert") + } + + client := ts.Client() + ctx := context.WithValue(context.Background(), oauth2.HTTPClient, client) + tok, err := conf.Token(ctx) + if err != nil { + tt.Error(err) + } + + expectAccessToken(tt, &oauth2.Token{ + AccessToken: "90d64460d14870c08c81352a05dedd3465940a7c", + TokenType: "bearer", + RefreshToken: "", + Expiry: time.Time{}, + }, tok) + }) + } + +} + +type fakeRoundTripper struct{} + +func (f *fakeRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { + return nil, nil +} + +func TestExtendContext(t *testing.T) { + + tcs := []struct { + title string + ctx context.Context + errorExpected bool + auth advancedauth.TLSAuth + assertTransport func(ttt *testing.T, t *http.Transport) + }{ + { + title: "background context", + ctx: context.Background(), + errorExpected: false, + auth: advancedauth.TLSAuth{ + Key: key, + Certificate: cert, + }, + }, + { + title: "invalid cert", + ctx: context.Background(), + errorExpected: true, + auth: advancedauth.TLSAuth{ + Key: key, + Certificate: "random", + }, + }, + { + title: "non *http.Client client", + ctx: context.WithValue(context.Background(), oauth2.HTTPClient, struct{}{}), + errorExpected: true, + auth: advancedauth.TLSAuth{ + Key: key, + Certificate: cert, + }, + }, + { + title: "non *http.Transport transport", + ctx: context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{ + Transport: &fakeRoundTripper{}, + }), + errorExpected: true, + auth: advancedauth.TLSAuth{ + Key: key, + Certificate: cert, + }, + }, + { + title: "no transport configured", + ctx: context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{}), + errorExpected: false, + auth: advancedauth.TLSAuth{ + Key: key, + Certificate: cert, + }, + }, + { + title: "configured transport", + ctx: context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{ + Transport: &http.Transport{ + IdleConnTimeout: 10 * time.Second, + }, + }), + errorExpected: false, + auth: advancedauth.TLSAuth{ + Key: key, + Certificate: cert, + }, + assertTransport: func(ttt *testing.T, tr *http.Transport) { + expectTrue(ttt, tr.IdleConnTimeout == 10*time.Second) + }, + }, + } + + for _, tc := range tcs { + tc := tc + t.Run(tc.title, func(tt *testing.T) { + config := advancedauth.Config{ + AuthStyle: oauth2.AuthStyleTLS, + ClientID: "random", + TLSAuth: tc.auth, + TokenURL: "random", + } + ctx, err := advancedauth.ExtendContext(tc.ctx, config) + if tc.errorExpected && err == nil { + tt.Errorf("expected error") + } else if !tc.errorExpected && err != nil { + tt.Fatalf("unexpected error %+v", err) + } else if !tc.errorExpected && err == nil { + c := ctx.Value(oauth2.HTTPClient) + expectTrue(tt, c != nil) + hc, ok := ctx.Value(oauth2.HTTPClient).(*http.Client) + expectTrue(tt, ok) + tr, ok := hc.Transport.(*http.Transport) + expectTrue(tt, ok) + certs := tr.TLSClientConfig.Certificates + expectTrue(tt, len(certs) == 1) + if tc.assertTransport != nil { + tc.assertTransport(tt, tr) + } + } + }) + } +} diff --git a/advancedauth/utils_test.go b/advancedauth/utils_test.go index 1b17afeca..c7a0db34f 100644 --- a/advancedauth/utils_test.go +++ b/advancedauth/utils_test.go @@ -11,14 +11,14 @@ import ( func expectHeader(t *testing.T, r *http.Request, header string, expected string) { actual := r.Header.Get(header) if actual != expected { - t.Errorf("Expected header %s to be %s, got %s", header, expected, actual) + t.Fatalf("Expected header %s to be %s, got %s", header, expected, actual) } } func expectURL(t *testing.T, r *http.Request, expected string) { actual := r.URL.String() if actual != expected { - t.Errorf("Expected url to be %s, got %s", expected, actual) + t.Fatalf("Expected url to be %s, got %s", expected, actual) } } @@ -28,11 +28,11 @@ func expectBody(t *testing.T, r *http.Request, expected string) { r.Body.Close() } if err != nil { - t.Errorf("failed reading request body: %s.", err) + t.Fatalf("failed reading request body: %s.", err) } actual := string(body) if actual != expected { - t.Errorf("Expected body to be %s, got %s", expected, actual) + t.Fatalf("Expected body to be %s, got %s", expected, actual) } } @@ -41,34 +41,34 @@ func expectAccessToken(t *testing.T, expected *oauth2.Token, actual *oauth2.Toke t.Fatalf("token invalid. got: %+v", actual) } if actual.AccessToken != expected.AccessToken { - t.Errorf("Access token = %q; want %q", actual.AccessToken, expected.AccessToken) + t.Fatalf("Access token = %q; want %q", actual.AccessToken, expected.AccessToken) } if actual.TokenType != expected.TokenType { - t.Errorf("token type = %q; want %q", actual.TokenType, expected.TokenType) + t.Fatalf("token type = %q; want %q", actual.TokenType, expected.TokenType) } } func expectFormParam(t *testing.T, r *http.Request, param string, expected string) { actual := r.FormValue(param) if actual != expected { - t.Errorf("Expected form param %s to be %s, got %s", param, expected, actual) + t.Fatalf("Expected form param %s to be %s, got %s", param, expected, actual) } } func expectStringsEqual(t *testing.T, expected string, actual string) { if actual != expected { - t.Errorf("Expected %s and %s to be equal", expected, actual) + t.Fatalf("Expected %s and %s to be equal", expected, actual) } } func expectStringNonEmpty(t *testing.T, actual string) { if actual == "" { - t.Errorf("Expected not empty %s", actual) + t.Fatalf("Expected not empty %s", actual) } } func expectTrue(t *testing.T, actual bool) { if !actual { - t.Errorf("Expected true %t", actual) + t.Fatalf("Expected true %t", actual) } } diff --git a/clientcredentials/clientcredentials.go b/clientcredentials/clientcredentials.go index fba08e403..4d8a9b46e 100644 --- a/clientcredentials/clientcredentials.go +++ b/clientcredentials/clientcredentials.go @@ -52,6 +52,10 @@ type Config struct { // PrivateKeyAuth stores configuration options for private_key_jwt // client authentication method described in OpenID Connect spec. PrivateKeyAuth advancedauth.PrivateKeyAuth + + // TLSAuth stores the configuration options for tls_client_auth and self_signed_tls_client_auth + // client authentication methods described in RFC 8705 + TLSAuth advancedauth.TLSAuth } // Token uses client credentials to retrieve a token. @@ -102,12 +106,17 @@ func (c *tokenSource) Token() (*oauth2.Token, error) { // not client_secret nor auto_detect if c.conf.AuthStyle > 2 { var err error - if err = advancedauth.ExtendUrlValues(v, advancedauth.Config{ + cfg := advancedauth.Config{ AuthStyle: c.conf.AuthStyle, ClientID: c.conf.ClientID, PrivateKeyAuth: c.conf.PrivateKeyAuth, + TLSAuth: c.conf.TLSAuth, TokenURL: c.conf.TokenURL, - }); err != nil { + } + if err = advancedauth.ExtendUrlValues(v, cfg); err != nil { + return nil, err + } + if c.ctx, err = advancedauth.ExtendContext(c.ctx, cfg); err != nil { return nil, err } } @@ -119,7 +128,6 @@ func (c *tokenSource) Token() (*oauth2.Token, error) { } v[k] = p } - tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v, internal.AuthStyle(c.conf.AuthStyle)) if err != nil { if rErr, ok := err.(*internal.RetrieveError); ok { diff --git a/oauth2.go b/oauth2.go index 4aaa4d067..38ec79e53 100644 --- a/oauth2.go +++ b/oauth2.go @@ -102,6 +102,9 @@ const ( // signed using the private key // described in OpenID Connect Core AuthStylePrivateKeyJWT AuthStyle = 3 + + // AuthStyleTLS + AuthStyleTLS AuthStyle = 4 ) var ( From 2828a74a1d80882a3bf3a485873c0eb2b5220ed0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Konrad=20Ho=C5=82owi=C5=84ski?= Date: Tue, 1 Nov 2022 12:25:16 +0100 Subject: [PATCH 05/16] Authorization code Exchange support for private_key_jwt and tls_client_auth (#3) --- README.md | 71 +++++++++++++-- advancedauth/advancedauth.go | 26 ++++-- advancedauth/privatekeyjwt_test.go | 114 +++++++++++++++++++++++-- advancedauth/tls.go | 10 +-- advancedauth/tls_test.go | 89 +++++++++++++++++-- clientcredentials/clientcredentials.go | 4 +- oauth2.go | 26 ++++++ 7 files changed, 309 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index 5a40c5982..c43193b40 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,7 @@ import ( "context" "time" + "github.com/cloudentity/oauth2" "github.com/cloudentity/oauth2/advancedauth" "github.com/cloudentity/oauth2/clientcredentials" ) @@ -39,16 +40,46 @@ import ( cfg := clientcredentials.Config{ ClientID: "your client id", AuthStyle: oauth2.AuthStylePrivateKeyJWT, - PrivateKeyAuth: advancedauth.PrivateKeyAuth{ - Key: "your PEM encoded private key", - Alg: advancedauth.RS256, - Exp: 30 * time.Second, - }, + PrivateKeyAuth: advancedauth.PrivateKeyAuth{ + Key: "your PEM encoded private key", + Algorithm: advancedauth.RS256, + Exp: 30 * time.Second, + }, } token, err := cfg.Token(context.Background()) ``` +#### Authorization code + +```go +import ( + "context" + "time" + + "github.com/cloudentity/oauth2" + "github.com/cloudentity/oauth2/advancedauth" +) +``` + +```go + + cfg := oauth2.Config{ + ClientID: "your client id", + Endpoint: oauth2.Endpoint{ + AuthStyle: oauth2.AuthStylePrivateKeyJWT, + }, + PrivateKeyAuth: advancedauth.PrivateKeyAuth{ + Key: "your PEM encoded private key", + Algorithm: advancedauth.RS256, + Exp: 30 * time.Second, + }, + Scopes: []string{"scope1", "scope2"}, + }, + + token, err := cfg.Exchange(context.Background(), "your authorization code") +``` + ### TLS Auth Both `tls_client_auth` and `self_signed_tls_client_auth` are handled with `TLSAuth` @@ -60,6 +91,7 @@ import ( "context" "time" + "github.com/cloudentity/oauth2" "github.com/cloudentity/oauth2/advancedauth" "github.com/cloudentity/oauth2/clientcredentials" ) @@ -78,6 +110,35 @@ import ( token, err := cfg.Token(context.Background()) ``` +#### Authorization code + +```go +import ( + "context" + "time" + + "github.com/cloudentity/oauth2" + "github.com/cloudentity/oauth2/advancedauth" +) +``` + +```go + + cfg := oauth2.Config{ + ClientID: "your client id", + Endpoint: oauth2.Endpoint{ + AuthStyle: oauth2.AuthStyleTLS, + }, + TLSAuth: advancedauth.TLSAuth{ + Key: "your certificate PEM encoded private key", + Certificate: "your PEM encoded TLS certificate", + }, + Scopes: []string{"scope1", "scope2"}, + }, + + token, err := cfg.Exchange(context.Background(), "your authorization code") +``` + ## Implementation This fork tries to limit changes to the original codebase to the minimum. diff --git a/advancedauth/advancedauth.go b/advancedauth/advancedauth.go index fd6f60d45..664834eaf 100644 --- a/advancedauth/advancedauth.go +++ b/advancedauth/advancedauth.go @@ -3,8 +3,6 @@ package advancedauth import ( "context" "net/url" - - "github.com/cloudentity/oauth2" ) type Algorithm string @@ -19,8 +17,20 @@ const ( ES512 Algorithm = "ES512" ) +type AuthStyle int + +const ( + // AuthStylePrivateKeyJWT sends a JWT assertion + // signed using the private key + // described in OpenID Connect Core + AuthStylePrivateKeyJWT AuthStyle = 3 + + // AuthStyleTLS + AuthStyleTLS AuthStyle = 4 +) + type Config struct { - AuthStyle oauth2.AuthStyle + AuthStyle AuthStyle ClientID string PrivateKeyAuth PrivateKeyAuth TLSAuth TLSAuth @@ -28,7 +38,7 @@ type Config struct { } func ExtendUrlValues(v url.Values, c Config) error { - if c.AuthStyle == oauth2.AuthStylePrivateKeyJWT { + if c.AuthStyle == AuthStylePrivateKeyJWT { jwtVals, err := privateKeyJWTAssertionVals(c) if err != nil { return err @@ -40,15 +50,15 @@ func ExtendUrlValues(v url.Values, c Config) error { } } } - if c.AuthStyle == oauth2.AuthStyleTLS { + if c.AuthStyle == AuthStyleTLS { v.Set("client_id", c.ClientID) } return nil } -func ExtendContext(ctx context.Context, c Config) (context.Context, error) { - if c.AuthStyle == oauth2.AuthStyleTLS { - return extendContextWithTLSClient(ctx, c) +func ExtendContext(ctx context.Context, httpClientContextKey interface{}, c Config) (context.Context, error) { + if c.AuthStyle == AuthStyleTLS { + return extendContextWithTLSClient(ctx, httpClientContextKey, c) } return ctx, nil } diff --git a/advancedauth/privatekeyjwt_test.go b/advancedauth/privatekeyjwt_test.go index a82e3d169..d75edf4d7 100644 --- a/advancedauth/privatekeyjwt_test.go +++ b/advancedauth/privatekeyjwt_test.go @@ -4,7 +4,6 @@ import ( "context" "net/http" "net/http/httptest" - "net/url" "testing" "time" @@ -72,8 +71,7 @@ func TestPrivateKeyJWT_ClientCredentials(t *testing.T) { PrivateKeyAuth: advancedauth.PrivateKeyAuth{ Key: privateKey, }, - Scopes: []string{"scope1", "scope2"}, - EndpointParams: url.Values{"audience": {"audience1"}}, + Scopes: []string{"scope1", "scope2"}, }, publicKey: rsaPubKey, }, @@ -86,8 +84,7 @@ func TestPrivateKeyJWT_ClientCredentials(t *testing.T) { Key: privateECDSAKey, Algorithm: "ES256", }, - Scopes: []string{"scope1", "scope2"}, - EndpointParams: url.Values{"audience": {"audience1"}}, + Scopes: []string{"scope1", "scope2"}, }, publicKey: ecdsaPubKey, }, @@ -153,3 +150,110 @@ func TestPrivateKeyJWT_ClientCredentials(t *testing.T) { } } + +func TestPrivateKeyJWT_Exchange(t *testing.T) { + rsaPubKey, err := jwt.ParseRSAPublicKeyFromPEM([]byte(publicKey)) + if err != nil { + t.Error("could not parse rsa public key") + } + ecdsaPubKey, err := jwt.ParseECPublicKeyFromPEM([]byte(publicECDSAKey)) + if err != nil { + t.Error("could not parse ecdsa public key") + } + tcs := []struct { + title string + config oauth2.Config + publicKey interface{} + }{ + { + title: "RSA", + config: oauth2.Config{ + ClientID: "CLIENT_ID", + Endpoint: oauth2.Endpoint{ + AuthStyle: oauth2.AuthStylePrivateKeyJWT, + }, + PrivateKeyAuth: advancedauth.PrivateKeyAuth{ + Key: privateKey, + }, + Scopes: []string{"scope1", "scope2"}, + }, + publicKey: rsaPubKey, + }, + { + title: "ECDSA", + config: oauth2.Config{ + ClientID: "CLIENT_ID", + Endpoint: oauth2.Endpoint{ + AuthStyle: oauth2.AuthStylePrivateKeyJWT, + }, + PrivateKeyAuth: advancedauth.PrivateKeyAuth{ + Key: privateECDSAKey, + Algorithm: "ES256", + }, + Scopes: []string{"scope1", "scope2"}, + }, + publicKey: ecdsaPubKey, + }, + } + + for _, tc := range tcs { + tc := tc + t.Run(tc.title, func(tt *testing.T) { + var serverURL string + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expectURL(tt, r, "/token") + expectHeader(tt, r, "Authorization", "") + expectHeader(tt, r, "Content-Type", "application/x-www-form-urlencoded") + expectFormParam(tt, r, "client_id", "") + expectFormParam(tt, r, "client_secret", "") + expectFormParam(tt, r, "code", "random") + expectFormParam(tt, r, "grant_type", "authorization_code") + expectFormParam(tt, r, "scope", "") + expectFormParam(tt, r, "client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") + + assertion := r.FormValue("client_assertion") + claims := jwt.RegisteredClaims{} + token, err := jwt.ParseWithClaims(assertion, &claims, func(token *jwt.Token) (interface{}, error) { + return tc.publicKey, nil + }) + if err != nil { + tt.Errorf("could not parse assertion %+v", err) + } + if !token.Valid { + tt.Error("invalid assertion token") + } + + expectStringsEqual(tt, "CLIENT_ID", claims.Issuer) + expectStringsEqual(tt, "CLIENT_ID", claims.Subject) + + // uuid v4 like + expectTrue(tt, len(claims.ID) == 36) + + expectTrue(tt, time.Now().Unix() < claims.ExpiresAt.Unix()) + expectStringsEqual(tt, serverURL, claims.Audience[0]) + + w.Header().Set("Content-Type", "application/x-www-form-urlencoded") + _, err = w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&token_type=bearer")) + if err != nil { + tt.Errorf("could not write body") + } + })) + serverURL = ts.URL + defer ts.Close() + conf := tc.config + conf.Endpoint.TokenURL = ts.URL + "/token" + tok, err := conf.Exchange(context.Background(), "random") + if err != nil { + tt.Error(err) + } + expectAccessToken(tt, &oauth2.Token{ + AccessToken: "90d64460d14870c08c81352a05dedd3465940a7c", + TokenType: "bearer", + RefreshToken: "", + Expiry: time.Time{}, + }, tok) + }) + } + +} diff --git a/advancedauth/tls.go b/advancedauth/tls.go index e9ccd8374..ccf6ebfd3 100644 --- a/advancedauth/tls.go +++ b/advancedauth/tls.go @@ -5,8 +5,6 @@ import ( "crypto/tls" "errors" "net/http" - - "github.com/cloudentity/oauth2" ) type TLSAuth struct { @@ -16,7 +14,7 @@ type TLSAuth struct { Certificate string } -func extendContextWithTLSClient(ctx context.Context, c Config) (context.Context, error) { +func extendContextWithTLSClient(ctx context.Context, httpClientContextKey interface{}, c Config) (context.Context, error) { var ( hc *http.Client ok bool @@ -28,9 +26,9 @@ func extendContextWithTLSClient(ctx context.Context, c Config) (context.Context, ctx = context.Background() } - if ctx.Value(oauth2.HTTPClient) == nil { + if ctx.Value(httpClientContextKey) == nil { hc = http.DefaultClient - } else if hc, ok = ctx.Value(oauth2.HTTPClient).(*http.Client); !ok { + } else if hc, ok = ctx.Value(httpClientContextKey).(*http.Client); !ok { return nil, errors.New("client of type *http.Client required in context") } @@ -49,6 +47,6 @@ func extendContextWithTLSClient(ctx context.Context, c Config) (context.Context, tr.TLSClientConfig.Certificates = []tls.Certificate{cert} hc.Transport = tr - return context.WithValue(ctx, oauth2.HTTPClient, hc), nil + return context.WithValue(ctx, httpClientContextKey, hc), nil } diff --git a/advancedauth/tls_test.go b/advancedauth/tls_test.go index 1ac5d6674..d81b48a4e 100644 --- a/advancedauth/tls_test.go +++ b/advancedauth/tls_test.go @@ -5,7 +5,6 @@ import ( "crypto/tls" "net/http" "net/http/httptest" - "net/url" "testing" "time" @@ -78,8 +77,7 @@ func TestTLS_ClientCredentials(t *testing.T) { Key: key, Certificate: cert, }, - Scopes: []string{"scope1", "scope2"}, - EndpointParams: url.Values{"audience": {"audience1"}}, + Scopes: []string{"scope1", "scope2"}, }, }, } @@ -95,6 +93,7 @@ func TestTLS_ClientCredentials(t *testing.T) { expectHeader(tt, r, "Content-Type", "application/x-www-form-urlencoded") expectFormParam(tt, r, "client_id", "CLIENT_ID") expectFormParam(tt, r, "client_secret", "") + expectFormParam(tt, r, "scope", "scope1 scope2") expectFormParam(tt, r, "grant_type", "client_credentials") cert := r.TLS.PeerCertificates[0] @@ -142,6 +141,86 @@ func TestTLS_ClientCredentials(t *testing.T) { } +func TestTLS_Exchange(t *testing.T) { + tcs := []struct { + title string + config oauth2.Config + }{ + { + title: "TLS", + config: oauth2.Config{ + ClientID: "CLIENT_ID", + Endpoint: oauth2.Endpoint{ + AuthStyle: oauth2.AuthStyleTLS, + }, + TLSAuth: advancedauth.TLSAuth{ + Key: key, + Certificate: cert, + }, + Scopes: []string{"scope1", "scope2"}, + }, + }, + } + + for _, tc := range tcs { + tc := tc + t.Run(tc.title, func(tt *testing.T) { + var serverURL string + + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expectURL(tt, r, "/token") + expectHeader(tt, r, "Authorization", "") + expectHeader(tt, r, "Content-Type", "application/x-www-form-urlencoded") + expectFormParam(tt, r, "client_id", "CLIENT_ID") + expectFormParam(tt, r, "client_secret", "") + expectFormParam(tt, r, "scope", "") + expectFormParam(tt, r, "grant_type", "authorization_code") + + cert := r.TLS.PeerCertificates[0] + expectStringsEqual(tt, "Example-Root-CA", cert.Issuer.CommonName) + + w.Header().Set("Content-Type", "application/x-www-form-urlencoded") + _, err := w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&token_type=bearer")) + if err != nil { + tt.Errorf("could not write body") + } + })) + + ts.TLS = &tls.Config{ + ClientAuth: tls.RequestClientCert, + } + + ts.StartTLS() + serverURL = ts.URL + defer ts.Close() + conf := tc.config + conf.Endpoint.TokenURL = serverURL + "/token" + + _, err := conf.Exchange(context.Background(), "random") + // context.Background() will fail as the server cert is not trusted + // err == nil checks if there are no panics + if err == nil { + tt.Errorf("expected Token to fail with invalid server cert") + } + + client := ts.Client() + ctx := context.WithValue(context.Background(), oauth2.HTTPClient, client) + tok, err := conf.Exchange(ctx, "random") + if err != nil { + tt.Error(err) + } + + expectAccessToken(tt, &oauth2.Token{ + AccessToken: "90d64460d14870c08c81352a05dedd3465940a7c", + TokenType: "bearer", + RefreshToken: "", + Expiry: time.Time{}, + }, tok) + }) + } + +} + type fakeRoundTripper struct{} func (f *fakeRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { @@ -226,12 +305,12 @@ func TestExtendContext(t *testing.T) { tc := tc t.Run(tc.title, func(tt *testing.T) { config := advancedauth.Config{ - AuthStyle: oauth2.AuthStyleTLS, + AuthStyle: advancedauth.AuthStyleTLS, ClientID: "random", TLSAuth: tc.auth, TokenURL: "random", } - ctx, err := advancedauth.ExtendContext(tc.ctx, config) + ctx, err := advancedauth.ExtendContext(tc.ctx, oauth2.HTTPClient, config) if tc.errorExpected && err == nil { tt.Errorf("expected error") } else if !tc.errorExpected && err != nil { diff --git a/clientcredentials/clientcredentials.go b/clientcredentials/clientcredentials.go index 4d8a9b46e..913cb5465 100644 --- a/clientcredentials/clientcredentials.go +++ b/clientcredentials/clientcredentials.go @@ -107,7 +107,7 @@ func (c *tokenSource) Token() (*oauth2.Token, error) { if c.conf.AuthStyle > 2 { var err error cfg := advancedauth.Config{ - AuthStyle: c.conf.AuthStyle, + AuthStyle: advancedauth.AuthStyle(c.conf.AuthStyle), ClientID: c.conf.ClientID, PrivateKeyAuth: c.conf.PrivateKeyAuth, TLSAuth: c.conf.TLSAuth, @@ -116,7 +116,7 @@ func (c *tokenSource) Token() (*oauth2.Token, error) { if err = advancedauth.ExtendUrlValues(v, cfg); err != nil { return nil, err } - if c.ctx, err = advancedauth.ExtendContext(c.ctx, cfg); err != nil { + if c.ctx, err = advancedauth.ExtendContext(c.ctx, oauth2.HTTPClient, cfg); err != nil { return nil, err } } diff --git a/oauth2.go b/oauth2.go index 38ec79e53..9d1cd1c73 100644 --- a/oauth2.go +++ b/oauth2.go @@ -17,6 +17,7 @@ import ( "strings" "sync" + "github.com/cloudentity/oauth2/advancedauth" "github.com/cloudentity/oauth2/internal" ) @@ -57,6 +58,14 @@ type Config struct { // Scope specifies optional requested permissions. Scopes []string + + // PrivateKeyAuth stores configuration options for private_key_jwt + // client authentication method described in OpenID Connect spec. + PrivateKeyAuth advancedauth.PrivateKeyAuth + + // TLSAuth stores the configuration options for tls_client_auth and self_signed_tls_client_auth + // client authentication methods described in RFC 8705 + TLSAuth advancedauth.TLSAuth } // A TokenSource is anything that can return a token. @@ -226,6 +235,23 @@ func (c *Config) Exchange(ctx context.Context, code string, opts ...AuthCodeOpti if c.RedirectURL != "" { v.Set("redirect_uri", c.RedirectURL) } + // not client_secret nor auto_detect + if c.Endpoint.AuthStyle > 2 { + var err error + cfg := advancedauth.Config{ + AuthStyle: advancedauth.AuthStyle(c.Endpoint.AuthStyle), + ClientID: c.ClientID, + PrivateKeyAuth: c.PrivateKeyAuth, + TLSAuth: c.TLSAuth, + TokenURL: c.Endpoint.TokenURL, + } + if err = advancedauth.ExtendUrlValues(v, cfg); err != nil { + return nil, err + } + if ctx, err = advancedauth.ExtendContext(ctx, HTTPClient, cfg); err != nil { + return nil, err + } + } for _, opt := range opts { opt.setValue(v) } From 9db4f897345cf4d0e308ffb3630ec89dc1f34995 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Konrad=20Ho=C5=82owi=C5=84ski?= Date: Wed, 2 Nov 2022 19:28:31 +0100 Subject: [PATCH 06/16] PKCE utils (#4) --- README.md | 56 +++++++++++++ advancedauth/pkce/pkce.go | 120 +++++++++++++++++++++++++++ advancedauth/pkce_test.go | 167 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 343 insertions(+) create mode 100644 advancedauth/pkce/pkce.go create mode 100644 advancedauth/pkce_test.go diff --git a/README.md b/README.md index c43193b40..9ea479033 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,8 @@ It extends the original library with additional authentication methods: - tls_client_auth - self_signed_tls_client_auth +Additionally, it also adds utility methods for easy use of PKCE. + ## Installation When using go modules you can run: @@ -139,6 +141,60 @@ import ( token, err := cfg.Exchange(context.Background(), "your authorization code") ``` +### PKCE + +```go +import ( + "context" + "time" + + "github.com/cloudentity/oauth2" + "github.com/cloudentity/oauth2/advancedauth/pkce" +) +``` + +Create `PKCE` with + +```go + p, err := pkce.New() +``` + +or, if you want to specify the code challenge method and verifier length + +```go + p, err := pkce.NewWithMethodVerifierLength(pkce.512, 84) +``` + +#### AuthCodeURL + +`PKCE` exposes few utility methods to ease creating `AuthCodeURL` + +You can use utility methods returning needed `AuthCodeOption`'s + +``` + url = conf.AuthCodeURL("state", p.AuthCodeURLOpts()...) +``` + +or, individual methods + +``` + url := conf.AuthCodeURL("state", p.ChallengeOpt(), p.MethodOpt()) +``` + +#### Exchange + +`PKCE` also exposes similar methods for `Exchange` + +```go + tok, err := conf.Exchange(context.Background(), "exchange-code", p.ExchangeOpts()...) +``` + +or, with individual methods + +```go + tok, err := conf.Exchange(context.Background(), "exchange-code", p.VerifierOpt(), p.MethodOpt()) +``` + ## Implementation This fork tries to limit changes to the original codebase to the minimum. diff --git a/advancedauth/pkce/pkce.go b/advancedauth/pkce/pkce.go new file mode 100644 index 000000000..bb44587b1 --- /dev/null +++ b/advancedauth/pkce/pkce.go @@ -0,0 +1,120 @@ +package pkce + +import ( + "crypto/rand" + "crypto/sha256" + "crypto/sha512" + "encoding/base64" + "errors" + "fmt" + "hash" + + "github.com/cloudentity/oauth2" +) + +type PKCE struct { + Method Method + Challenge string + Verifier string +} + +type Method string + +const ( + S256 Method = "S256" + S384 Method = "S384" + S512 Method = "S512" + // not recommended, use S256 + Plain Method = "plain" +) + +// https://www.rfc-editor.org/rfc/rfc7636#section-4.1 +const verifierDict = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~" + +func New() (PKCE, error) { + return NewWithMethodVerifierLength(S256, 64) +} + +func NewWithMethodVerifierLength(method Method, verifierLength int) (PKCE, error) { + var ( + verifier string + challenge string + err error + ) + if verifierLength < 43 || verifierLength > 128 { + // https://www.rfc-editor.org/rfc/rfc7636#section-4.1 + return PKCE{}, errors.New("verifier has to be between 43 and 128 chars long") + } + + if verifier, err = randomVerifer(verifierLength); err != nil { + return PKCE{}, err + } + if challenge, err = calculateChallenge(verifier, method); err != nil { + return PKCE{}, err + } + + return PKCE{ + Method: method, + Challenge: challenge, + Verifier: verifier, + }, nil +} + +func randomVerifer(length int) (string, error) { + bytes := make([]byte, length) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + + for i, b := range bytes { + bytes[i] = verifierDict[b%byte(len(verifierDict))] + } + return string(bytes), nil +} + +func calculateChallenge(verifier string, method Method) (string, error) { + var ( + hasher hash.Hash + ) + switch method { + case Plain: + return verifier, nil + case S256: + hasher = sha256.New() + case S384: + hasher = sha512.New384() + case S512: + hasher = sha512.New() + } + if hasher != nil { + if _, err := hasher.Write([]byte(verifier)); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(hasher.Sum(nil)), nil + } + return "", fmt.Errorf("invalid method %s", method) +} + +func (p *PKCE) ChallengeOpt() oauth2.AuthCodeOption { + return oauth2.SetAuthURLParam("code_challenge", p.Challenge) +} + +func (p *PKCE) MethodOpt() oauth2.AuthCodeOption { + return oauth2.SetAuthURLParam("code_challenge_method", string(p.Method)) +} + +func (p *PKCE) VerifierOpt() oauth2.AuthCodeOption { + return oauth2.SetAuthURLParam("code_verifier", p.Verifier) +} + +func (p *PKCE) AuthCodeURLOpts() []oauth2.AuthCodeOption { + return []oauth2.AuthCodeOption{ + p.ChallengeOpt(), p.MethodOpt(), + } +} + +func (p *PKCE) ExchangeOpts() []oauth2.AuthCodeOption { + return []oauth2.AuthCodeOption{ + p.VerifierOpt(), p.MethodOpt(), + } +} diff --git a/advancedauth/pkce_test.go b/advancedauth/pkce_test.go new file mode 100644 index 000000000..3c4f2599a --- /dev/null +++ b/advancedauth/pkce_test.go @@ -0,0 +1,167 @@ +package advancedauth_test + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/cloudentity/oauth2" + "github.com/cloudentity/oauth2/advancedauth/pkce" +) + +func TestPKCE_AuthorizationCodeFlow(t *testing.T) { + tcs := []struct { + title string + config oauth2.Config + publicKey interface{} + }{ + { + title: "pkce with client auth", + config: oauth2.Config{ + ClientID: "CLIENT_ID", + ClientSecret: "CLIENT_SECRET", + Endpoint: oauth2.Endpoint{ + AuthStyle: oauth2.AuthStyleInParams, + }, + Scopes: []string{"scope1", "scope2"}, + }, + }, + { + title: "pkce without client auth", + config: oauth2.Config{ + ClientID: "CLIENT_ID", + Endpoint: oauth2.Endpoint{ + AuthStyle: oauth2.AuthStyleInParams, + }, + Scopes: []string{"scope1", "scope2"}, + }, + }, + } + + for _, tc := range tcs { + tc := tc + t.Run(tc.title, func(tt *testing.T) { + + p, _ := pkce.New() + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expectURL(tt, r, "/token") + expectHeader(tt, r, "Content-Type", "application/x-www-form-urlencoded") + expectFormParam(tt, r, "client_id", "CLIENT_ID") + if tc.config.ClientSecret != "" { + expectFormParam(tt, r, "client_secret", "CLIENT_SECRET") + } + expectFormParam(tt, r, "code", "exchange-code") + expectFormParam(tt, r, "grant_type", "authorization_code") + expectFormParam(tt, r, "code_verifier", p.Verifier) + expectFormParam(tt, r, "code_challenge_method", string(p.Method)) + expectFormParam(tt, r, "scope", "") + + w.Header().Set("Content-Type", "application/x-www-form-urlencoded") + _, err := w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&token_type=bearer")) + if err != nil { + tt.Errorf("could not write body") + } + })) + defer ts.Close() + conf := tc.config + conf.Endpoint.TokenURL = ts.URL + "/token" + + expectedAuthCodeURL := fmt.Sprintf( + "?client_id=%s&code_challenge=%s&code_challenge_method=%s&response_type=code&scope=%s&state=state", + tc.config.ClientID, p.Challenge, p.Method, strings.Join(tc.config.Scopes, "+"), + ) + + url := conf.AuthCodeURL("state", p.ChallengeOpt(), p.MethodOpt()) + expectStringsEqual(tt, expectedAuthCodeURL, url) + + url = conf.AuthCodeURL("state", p.AuthCodeURLOpts()...) + expectStringsEqual(tt, expectedAuthCodeURL, url) + + tok, err := conf.Exchange(context.Background(), "exchange-code", p.ExchangeOpts()...) + if err != nil { + tt.Error(err) + } + expectAccessToken(tt, &oauth2.Token{ + AccessToken: "90d64460d14870c08c81352a05dedd3465940a7c", + TokenType: "bearer", + RefreshToken: "", + Expiry: time.Time{}, + }, tok) + }) + } + +} + +func TestPKCE(t *testing.T) { + tcs := []struct { + title string + method pkce.Method + verifierLen int + error bool + }{ + { + title: "simple", + }, + { + title: "plain", + method: pkce.S512, + verifierLen: 50, + }, + { + title: "S512", + method: pkce.S512, + }, + { + title: "invalid method", + method: "some random stuff", + verifierLen: 50, + error: true, + }, + { + title: "verifier too short", + method: "S256", + verifierLen: 20, + error: true, + }, + } + + for _, tc := range tcs { + tc := tc + t.Run(tc.title, func(tt *testing.T) { + var ( + p pkce.PKCE + err error + ) + + if tc.method != "" && tc.verifierLen != 0 { + p, err = pkce.NewWithMethodVerifierLength(tc.method, tc.verifierLen) + } else { + p, err = pkce.New() + expectStringsEqual(tt, "S256", string(p.Method)) + } + + if err != nil && !tc.error { + tt.Fatalf("could not generate PKCE, got %+v", err) + } else if err == nil && tc.error { + tt.Fatalf("expected error, got nil") + } else if err == nil { + if tc.verifierLen == 0 { + if len(p.Verifier) != 64 { + tt.Fatalf("expected verifier of length 64") + } + } else if len(p.Verifier) != tc.verifierLen { + tt.Fatalf("expected verifier of length %d", tc.verifierLen) + } + + if tc.method == pkce.Plain { + expectStringsEqual(tt, p.Verifier, p.Challenge) + } + } + }) + } +} From c332bdc615dcac2b698ac4a5e79124c0543052e1 Mon Sep 17 00:00:00 2001 From: Konrad Holowinski Date: Wed, 9 Nov 2022 15:55:16 +0100 Subject: [PATCH 07/16] fix module name --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index f7680d536..5c92bf05e 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/cloudentity/oauth2 +module golang.org/x/oauth2 go 1.17 From 1415fd246fd06e72f917e56b7feba7dca9a9d9dd Mon Sep 17 00:00:00 2001 From: Konrad Holowinski Date: Wed, 9 Nov 2022 16:00:29 +0100 Subject: [PATCH 08/16] set module name to golang.org/x/oauth2 --- README.md | 24 +++++++++---------- advancedauth/pkce/pkce.go | 2 +- advancedauth/pkce_test.go | 4 ++-- advancedauth/privatekeyjwt_test.go | 6 ++--- advancedauth/tls_test.go | 6 ++--- advancedauth/utils_test.go | 2 +- amazon/amazon.go | 2 +- authhandler/authhandler.go | 2 +- authhandler/authhandler_test.go | 2 +- bitbucket/bitbucket.go | 2 +- cern/cern.go | 4 ++-- clientcredentials/clientcredentials.go | 8 +++---- clientcredentials/clientcredentials_test.go | 2 +- endpoints/endpoints.go | 2 +- endpoints/endpoints_test.go | 2 +- example_test.go | 2 +- facebook/facebook.go | 4 ++-- fitbit/fitbit.go | 4 ++-- foursquare/foursquare.go | 4 ++-- github/github.go | 4 ++-- gitlab/gitlab.go | 4 ++-- google/appengine.go | 2 +- google/appengine_gen1.go | 2 +- google/appengine_gen2_flex.go | 2 +- google/default.go | 4 ++-- google/doc.go | 4 ++-- google/downscope/downscoping.go | 2 +- google/downscope/downscoping_test.go | 2 +- google/downscope/tokenbroker_test.go | 6 ++--- google/error.go | 2 +- google/error_test.go | 2 +- google/example_test.go | 6 ++--- google/google.go | 6 ++--- google/internal/externalaccount/aws.go | 2 +- .../externalaccount/basecredentials.go | 2 +- .../externalaccount/basecredentials_test.go | 2 +- google/internal/externalaccount/clientauth.go | 2 +- .../externalaccount/clientauth_test.go | 2 +- .../internal/externalaccount/impersonate.go | 2 +- .../internal/externalaccount/sts_exchange.go | 2 +- .../externalaccount/sts_exchange_test.go | 2 +- .../internal/externalaccount/urlcredsource.go | 2 +- google/jwt.go | 6 ++--- google/jwt_test.go | 2 +- google/sdk.go | 2 +- heroku/heroku.go | 4 ++-- hipchat/hipchat.go | 6 ++--- instagram/instagram.go | 4 ++-- internal/token.go | 2 +- jira/jira.go | 2 +- jira/jira_test.go | 4 ++-- jws/jws.go | 4 ++-- jwt/example_test.go | 2 +- jwt/jwt.go | 6 ++--- jwt/jwt_test.go | 4 ++-- kakao/kakao.go | 4 ++-- linkedin/linkedin.go | 4 ++-- mailchimp/mailchimp.go | 4 ++-- mailru/mailru.go | 4 ++-- mediamath/mediamath.go | 4 ++-- microsoft/microsoft.go | 4 ++-- nokiahealth/nokiahealth.go | 2 +- oauth2.go | 8 +++---- oauth2_test.go | 2 +- odnoklassniki/odnoklassniki.go | 4 ++-- paypal/paypal.go | 4 ++-- slack/slack.go | 4 ++-- spotify/spotify.go | 4 ++-- stackoverflow/stackoverflow.go | 4 ++-- token.go | 2 +- transport.go | 2 +- twitch/twitch.go | 4 ++-- uber/uber.go | 4 ++-- vk/vk.go | 4 ++-- yahoo/yahoo.go | 4 ++-- yandex/yandex.go | 4 ++-- 76 files changed, 139 insertions(+), 139 deletions(-) diff --git a/README.md b/README.md index 9ea479033..920a8c969 100644 --- a/README.md +++ b/README.md @@ -32,9 +32,9 @@ import ( "context" "time" - "github.com/cloudentity/oauth2" - "github.com/cloudentity/oauth2/advancedauth" - "github.com/cloudentity/oauth2/clientcredentials" + "golang.org/x/oauth2" + "golang.org/x/oauth2/advancedauth" + "golang.org/x/oauth2/clientcredentials" ) ``` @@ -59,8 +59,8 @@ import ( "context" "time" - "github.com/cloudentity/oauth2" - "github.com/cloudentity/oauth2/advancedauth" + "golang.org/x/oauth2" + "golang.org/x/oauth2/advancedauth" ) ``` @@ -93,9 +93,9 @@ import ( "context" "time" - "github.com/cloudentity/oauth2" - "github.com/cloudentity/oauth2/advancedauth" - "github.com/cloudentity/oauth2/clientcredentials" + "golang.org/x/oauth2" + "golang.org/x/oauth2/advancedauth" + "golang.org/x/oauth2/clientcredentials" ) ``` @@ -119,8 +119,8 @@ import ( "context" "time" - "github.com/cloudentity/oauth2" - "github.com/cloudentity/oauth2/advancedauth" + "golang.org/x/oauth2" + "golang.org/x/oauth2/advancedauth" ) ``` @@ -148,8 +148,8 @@ import ( "context" "time" - "github.com/cloudentity/oauth2" - "github.com/cloudentity/oauth2/advancedauth/pkce" + "golang.org/x/oauth2" + "golang.org/x/oauth2/advancedauth/pkce" ) ``` diff --git a/advancedauth/pkce/pkce.go b/advancedauth/pkce/pkce.go index bb44587b1..a735cfc33 100644 --- a/advancedauth/pkce/pkce.go +++ b/advancedauth/pkce/pkce.go @@ -9,7 +9,7 @@ import ( "fmt" "hash" - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) type PKCE struct { diff --git a/advancedauth/pkce_test.go b/advancedauth/pkce_test.go index 3c4f2599a..3f76e3156 100644 --- a/advancedauth/pkce_test.go +++ b/advancedauth/pkce_test.go @@ -9,8 +9,8 @@ import ( "testing" "time" - "github.com/cloudentity/oauth2" - "github.com/cloudentity/oauth2/advancedauth/pkce" + "golang.org/x/oauth2" + "golang.org/x/oauth2/advancedauth/pkce" ) func TestPKCE_AuthorizationCodeFlow(t *testing.T) { diff --git a/advancedauth/privatekeyjwt_test.go b/advancedauth/privatekeyjwt_test.go index d75edf4d7..b7c3c9dee 100644 --- a/advancedauth/privatekeyjwt_test.go +++ b/advancedauth/privatekeyjwt_test.go @@ -7,9 +7,9 @@ import ( "testing" "time" - "github.com/cloudentity/oauth2" - "github.com/cloudentity/oauth2/advancedauth" - "github.com/cloudentity/oauth2/clientcredentials" + "golang.org/x/oauth2" + "golang.org/x/oauth2/advancedauth" + "golang.org/x/oauth2/clientcredentials" "github.com/golang-jwt/jwt/v4" ) diff --git a/advancedauth/tls_test.go b/advancedauth/tls_test.go index d81b48a4e..08fe8892f 100644 --- a/advancedauth/tls_test.go +++ b/advancedauth/tls_test.go @@ -8,9 +8,9 @@ import ( "testing" "time" - "github.com/cloudentity/oauth2" - "github.com/cloudentity/oauth2/advancedauth" - "github.com/cloudentity/oauth2/clientcredentials" + "golang.org/x/oauth2" + "golang.org/x/oauth2/advancedauth" + "golang.org/x/oauth2/clientcredentials" ) const ( diff --git a/advancedauth/utils_test.go b/advancedauth/utils_test.go index c7a0db34f..171e86ff9 100644 --- a/advancedauth/utils_test.go +++ b/advancedauth/utils_test.go @@ -5,7 +5,7 @@ import ( "net/http" "testing" - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) func expectHeader(t *testing.T, r *http.Request, header string, expected string) { diff --git a/amazon/amazon.go b/amazon/amazon.go index 18e254e1f..d21da11af 100644 --- a/amazon/amazon.go +++ b/amazon/amazon.go @@ -6,7 +6,7 @@ package amazon import ( - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // Endpoint is Amazon's OAuth 2.0 endpoint. diff --git a/authhandler/authhandler.go b/authhandler/authhandler.go index e60255ec9..9bc6cd7bc 100644 --- a/authhandler/authhandler.go +++ b/authhandler/authhandler.go @@ -10,7 +10,7 @@ import ( "context" "errors" - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) const ( diff --git a/authhandler/authhandler_test.go b/authhandler/authhandler_test.go index 365c51be7..ad1980492 100644 --- a/authhandler/authhandler_test.go +++ b/authhandler/authhandler_test.go @@ -11,7 +11,7 @@ import ( "net/http/httptest" "testing" - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) func TestTokenExchange_Success(t *testing.T) { diff --git a/bitbucket/bitbucket.go b/bitbucket/bitbucket.go index 401c1ccb3..44af1f1a9 100644 --- a/bitbucket/bitbucket.go +++ b/bitbucket/bitbucket.go @@ -6,7 +6,7 @@ package bitbucket import ( - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // Endpoint is Bitbucket's OAuth 2.0 endpoint. diff --git a/cern/cern.go b/cern/cern.go index 0364d10bb..8be718078 100644 --- a/cern/cern.go +++ b/cern/cern.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package cern provides constants for using OAuth2 to access CERN services. -package cern // import "github.com/cloudentity/oauth2/cern" +package cern // import "golang.org/x/oauth2/cern" import ( - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // Endpoint is CERN's OAuth 2.0 endpoint. diff --git a/clientcredentials/clientcredentials.go b/clientcredentials/clientcredentials.go index 913cb5465..0b06c5da2 100644 --- a/clientcredentials/clientcredentials.go +++ b/clientcredentials/clientcredentials.go @@ -11,7 +11,7 @@ // server. // // See https://tools.ietf.org/html/rfc6749#section-4.4 -package clientcredentials // import "github.com/cloudentity/oauth2/clientcredentials" +package clientcredentials // import "golang.org/x/oauth2/clientcredentials" import ( "context" @@ -20,9 +20,9 @@ import ( "net/url" "strings" - "github.com/cloudentity/oauth2" - "github.com/cloudentity/oauth2/advancedauth" - "github.com/cloudentity/oauth2/internal" + "golang.org/x/oauth2" + "golang.org/x/oauth2/advancedauth" + "golang.org/x/oauth2/internal" ) // Config describes a 2-legged OAuth2 flow, with both the diff --git a/clientcredentials/clientcredentials_test.go b/clientcredentials/clientcredentials_test.go index 3f8216ede..02a1c89a8 100644 --- a/clientcredentials/clientcredentials_test.go +++ b/clientcredentials/clientcredentials_test.go @@ -13,7 +13,7 @@ import ( "net/url" "testing" - "github.com/cloudentity/oauth2/internal" + "golang.org/x/oauth2/internal" ) func newConf(serverURL string) *Config { diff --git a/endpoints/endpoints.go b/endpoints/endpoints.go index 2db328dec..7cc37c876 100644 --- a/endpoints/endpoints.go +++ b/endpoints/endpoints.go @@ -8,7 +8,7 @@ package endpoints import ( "strings" - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // Amazon is the endpoint for Amazon. diff --git a/endpoints/endpoints_test.go b/endpoints/endpoints_test.go index 92486678b..4ffa31429 100644 --- a/endpoints/endpoints_test.go +++ b/endpoints/endpoints_test.go @@ -7,7 +7,7 @@ package endpoints import ( "testing" - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) func TestAWSCognitoEndpoint(t *testing.T) { diff --git a/example_test.go b/example_test.go index 6fe828fc1..fc2f793b2 100644 --- a/example_test.go +++ b/example_test.go @@ -11,7 +11,7 @@ import ( "net/http" "time" - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) func ExampleConfig() { diff --git a/facebook/facebook.go b/facebook/facebook.go index baa452706..b0054e387 100644 --- a/facebook/facebook.go +++ b/facebook/facebook.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package facebook provides constants for using OAuth2 to access Facebook. -package facebook // import "github.com/cloudentity/oauth2/facebook" +package facebook // import "golang.org/x/oauth2/facebook" import ( - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // Endpoint is Facebook's OAuth 2.0 endpoint. diff --git a/fitbit/fitbit.go b/fitbit/fitbit.go index 9170a7bf2..b31b82aca 100644 --- a/fitbit/fitbit.go +++ b/fitbit/fitbit.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package fitbit provides constants for using OAuth2 to access the Fitbit API. -package fitbit // import "github.com/cloudentity/oauth2/fitbit" +package fitbit // import "golang.org/x/oauth2/fitbit" import ( - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // Endpoint is the Fitbit API's OAuth 2.0 endpoint. diff --git a/foursquare/foursquare.go b/foursquare/foursquare.go index 7533cf154..d2fa09902 100644 --- a/foursquare/foursquare.go +++ b/foursquare/foursquare.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package foursquare provides constants for using OAuth2 to access Foursquare. -package foursquare // import "github.com/cloudentity/oauth2/foursquare" +package foursquare // import "golang.org/x/oauth2/foursquare" import ( - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // Endpoint is Foursquare's OAuth 2.0 endpoint. diff --git a/github/github.go b/github/github.go index 0b01897f4..f2978015b 100644 --- a/github/github.go +++ b/github/github.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package github provides constants for using OAuth2 to access Github. -package github // import "github.com/cloudentity/oauth2/github" +package github // import "golang.org/x/oauth2/github" import ( - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // Endpoint is Github's OAuth 2.0 endpoint. diff --git a/gitlab/gitlab.go b/gitlab/gitlab.go index 3e8e5cb5e..1231d75ac 100644 --- a/gitlab/gitlab.go +++ b/gitlab/gitlab.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package gitlab provides constants for using OAuth2 to access GitLab. -package gitlab // import "github.com/cloudentity/oauth2/gitlab" +package gitlab // import "golang.org/x/oauth2/gitlab" import ( - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // Endpoint is GitLab's OAuth 2.0 endpoint. diff --git a/google/appengine.go b/google/appengine.go index 971506468..feb1157b1 100644 --- a/google/appengine.go +++ b/google/appengine.go @@ -8,7 +8,7 @@ import ( "context" "time" - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // Set at init time by appengine_gen1.go. If nil, we're not on App Engine standard first generation (<= Go 1.9) or App Engine flexible. diff --git a/google/appengine_gen1.go b/google/appengine_gen1.go index 0c77add1f..16c6c6b90 100644 --- a/google/appengine_gen1.go +++ b/google/appengine_gen1.go @@ -15,7 +15,7 @@ import ( "strings" "sync" - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" "google.golang.org/appengine" ) diff --git a/google/appengine_gen2_flex.go b/google/appengine_gen2_flex.go index 420eaa1ac..a7e27b3d2 100644 --- a/google/appengine_gen2_flex.go +++ b/google/appengine_gen2_flex.go @@ -14,7 +14,7 @@ import ( "log" "sync" - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) var logOnce sync.Once // only spam about deprecation once diff --git a/google/default.go b/google/default.go index 57af55560..7ed02cd41 100644 --- a/google/default.go +++ b/google/default.go @@ -15,8 +15,8 @@ import ( "runtime" "cloud.google.com/go/compute/metadata" - "github.com/cloudentity/oauth2" - "github.com/cloudentity/oauth2/authhandler" + "golang.org/x/oauth2" + "golang.org/x/oauth2/authhandler" ) // Credentials holds Google credentials, including "Application Default Credentials". diff --git a/google/doc.go b/google/doc.go index a25eaad26..b3e7bc85c 100644 --- a/google/doc.go +++ b/google/doc.go @@ -17,7 +17,7 @@ // // # OAuth2 Configs // -// Two functions in this package return github.com/cloudentity/oauth2.Config values from Google credential +// Two functions in this package return golang.org/x/oauth2.Config values from Google credential // data. Google supports two JSON formats for OAuth2 credentials: one is handled by ConfigFromJSON, // the other by JWTConfigFromJSON. The returned Config can be used to obtain a TokenSource or // create an http.Client. @@ -81,4 +81,4 @@ // same as the one obtained from the oauth2.Config returned from ConfigFromJSON or // JWTConfigFromJSON, but the Credentials may contain additional information // that is useful is some circumstances. -package google // import "github.com/cloudentity/oauth2/google" +package google // import "golang.org/x/oauth2/google" diff --git a/google/downscope/downscoping.go b/google/downscope/downscoping.go index 6b5b19b70..3d4b5532d 100644 --- a/google/downscope/downscoping.go +++ b/google/downscope/downscoping.go @@ -44,7 +44,7 @@ import ( "net/url" "time" - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) var ( diff --git a/google/downscope/downscoping_test.go b/google/downscope/downscoping_test.go index 06c15c684..d5adda19c 100644 --- a/google/downscope/downscoping_test.go +++ b/google/downscope/downscoping_test.go @@ -11,7 +11,7 @@ import ( "net/http/httptest" "testing" - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) var ( diff --git a/google/downscope/tokenbroker_test.go b/google/downscope/tokenbroker_test.go index 25e7263cd..cb168785f 100644 --- a/google/downscope/tokenbroker_test.go +++ b/google/downscope/tokenbroker_test.go @@ -8,10 +8,10 @@ import ( "context" "fmt" - "github.com/cloudentity/oauth2/google" + "golang.org/x/oauth2/google" - "github.com/cloudentity/oauth2" - "github.com/cloudentity/oauth2/google/downscope" + "golang.org/x/oauth2" + "golang.org/x/oauth2/google/downscope" ) func ExampleNewTokenSource() { diff --git a/google/error.go b/google/error.go index c0143d91c..d84dd0047 100644 --- a/google/error.go +++ b/google/error.go @@ -7,7 +7,7 @@ package google import ( "errors" - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // AuthenticationError indicates there was an error in the authentication flow. diff --git a/google/error_test.go b/google/error_test.go index 4a9e18fb1..cd60e9118 100644 --- a/google/error_test.go +++ b/google/error_test.go @@ -8,7 +8,7 @@ import ( "net/http" "testing" - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) func TestAuthenticationError_Temporary(t *testing.T) { diff --git a/google/example_test.go b/google/example_test.go index 568caac7d..3fc9cad3f 100644 --- a/google/example_test.go +++ b/google/example_test.go @@ -11,9 +11,9 @@ import ( "log" "net/http" - "github.com/cloudentity/oauth2" - "github.com/cloudentity/oauth2/google" - "github.com/cloudentity/oauth2/jwt" + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" + "golang.org/x/oauth2/jwt" ) func ExampleDefaultClient() { diff --git a/google/google.go b/google/google.go index 95829f331..8df0c493e 100644 --- a/google/google.go +++ b/google/google.go @@ -14,9 +14,9 @@ import ( "time" "cloud.google.com/go/compute/metadata" - "github.com/cloudentity/oauth2" - "github.com/cloudentity/oauth2/google/internal/externalaccount" - "github.com/cloudentity/oauth2/jwt" + "golang.org/x/oauth2" + "golang.org/x/oauth2/google/internal/externalaccount" + "golang.org/x/oauth2/jwt" ) // Endpoint is Google's OAuth 2.0 default endpoint. diff --git a/google/internal/externalaccount/aws.go b/google/internal/externalaccount/aws.go index 62fd5b327..e917195d5 100644 --- a/google/internal/externalaccount/aws.go +++ b/google/internal/externalaccount/aws.go @@ -23,7 +23,7 @@ import ( "strings" "time" - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) type awsSecurityCredentials struct { diff --git a/google/internal/externalaccount/basecredentials.go b/google/internal/externalaccount/basecredentials.go index e7ba68bd3..9fc35535e 100644 --- a/google/internal/externalaccount/basecredentials.go +++ b/google/internal/externalaccount/basecredentials.go @@ -14,7 +14,7 @@ import ( "strings" "time" - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // now aliases time.Now for testing diff --git a/google/internal/externalaccount/basecredentials_test.go b/google/internal/externalaccount/basecredentials_test.go index 207d4ffb3..05e0127f0 100644 --- a/google/internal/externalaccount/basecredentials_test.go +++ b/google/internal/externalaccount/basecredentials_test.go @@ -13,7 +13,7 @@ import ( "testing" "time" - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) const ( diff --git a/google/internal/externalaccount/clientauth.go b/google/internal/externalaccount/clientauth.go index fff0c44db..99987ce29 100644 --- a/google/internal/externalaccount/clientauth.go +++ b/google/internal/externalaccount/clientauth.go @@ -9,7 +9,7 @@ import ( "net/http" "net/url" - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // clientAuthentication represents an OAuth client ID and secret and the mechanism for passing these credentials as stated in rfc6749#2.3.1. diff --git a/google/internal/externalaccount/clientauth_test.go b/google/internal/externalaccount/clientauth_test.go index bd9138bfa..bfb339d06 100644 --- a/google/internal/externalaccount/clientauth_test.go +++ b/google/internal/externalaccount/clientauth_test.go @@ -10,7 +10,7 @@ import ( "reflect" "testing" - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) var clientID = "rbrgnognrhongo3bi4gb9ghg9g" diff --git a/google/internal/externalaccount/impersonate.go b/google/internal/externalaccount/impersonate.go index db97a7764..54c8f209f 100644 --- a/google/internal/externalaccount/impersonate.go +++ b/google/internal/externalaccount/impersonate.go @@ -14,7 +14,7 @@ import ( "net/http" "time" - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // generateAccesstokenReq is used for service account impersonation diff --git a/google/internal/externalaccount/sts_exchange.go b/google/internal/externalaccount/sts_exchange.go index a262f462a..e6fcae5fc 100644 --- a/google/internal/externalaccount/sts_exchange.go +++ b/google/internal/externalaccount/sts_exchange.go @@ -15,7 +15,7 @@ import ( "strconv" "strings" - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // exchangeToken performs an oauth2 token exchange with the provided endpoint. diff --git a/google/internal/externalaccount/sts_exchange_test.go b/google/internal/externalaccount/sts_exchange_test.go index 747b0d69d..df4d5ff4e 100644 --- a/google/internal/externalaccount/sts_exchange_test.go +++ b/google/internal/externalaccount/sts_exchange_test.go @@ -13,7 +13,7 @@ import ( "net/url" "testing" - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) var auth = clientAuthentication{ diff --git a/google/internal/externalaccount/urlcredsource.go b/google/internal/externalaccount/urlcredsource.go index 247548845..16dca6541 100644 --- a/google/internal/externalaccount/urlcredsource.go +++ b/google/internal/externalaccount/urlcredsource.go @@ -13,7 +13,7 @@ import ( "io/ioutil" "net/http" - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) type urlCredentialSource struct { diff --git a/google/jwt.go b/google/jwt.go index ede6136af..e89e6ae17 100644 --- a/google/jwt.go +++ b/google/jwt.go @@ -10,9 +10,9 @@ import ( "strings" "time" - "github.com/cloudentity/oauth2" - "github.com/cloudentity/oauth2/internal" - "github.com/cloudentity/oauth2/jws" + "golang.org/x/oauth2" + "golang.org/x/oauth2/internal" + "golang.org/x/oauth2/jws" ) // JWTAccessTokenSourceFromJSON uses a Google Developers service account JSON diff --git a/google/jwt_test.go b/google/jwt_test.go index 2dbb8f2b5..5890ae9a7 100644 --- a/google/jwt_test.go +++ b/google/jwt_test.go @@ -17,7 +17,7 @@ import ( "testing" "time" - "github.com/cloudentity/oauth2/jws" + "golang.org/x/oauth2/jws" ) var ( diff --git a/google/sdk.go b/google/sdk.go index a6f0c0895..456224bc7 100644 --- a/google/sdk.go +++ b/google/sdk.go @@ -19,7 +19,7 @@ import ( "strings" "time" - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) type sdkCredentials struct { diff --git a/heroku/heroku.go b/heroku/heroku.go index a42bac9d9..5b4fdb890 100644 --- a/heroku/heroku.go +++ b/heroku/heroku.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package heroku provides constants for using OAuth2 to access Heroku. -package heroku // import "github.com/cloudentity/oauth2/heroku" +package heroku // import "golang.org/x/oauth2/heroku" import ( - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // Endpoint is Heroku's OAuth 2.0 endpoint. diff --git a/hipchat/hipchat.go b/hipchat/hipchat.go index 8732cd70d..594fe072c 100644 --- a/hipchat/hipchat.go +++ b/hipchat/hipchat.go @@ -3,14 +3,14 @@ // license that can be found in the LICENSE file. // Package hipchat provides constants for using OAuth2 to access HipChat. -package hipchat // import "github.com/cloudentity/oauth2/hipchat" +package hipchat // import "golang.org/x/oauth2/hipchat" import ( "encoding/json" "errors" - "github.com/cloudentity/oauth2" - "github.com/cloudentity/oauth2/clientcredentials" + "golang.org/x/oauth2" + "golang.org/x/oauth2/clientcredentials" ) // Endpoint is HipChat's OAuth 2.0 endpoint. diff --git a/instagram/instagram.go b/instagram/instagram.go index db96eb691..75a74ebb9 100644 --- a/instagram/instagram.go +++ b/instagram/instagram.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package instagram provides constants for using OAuth2 to access Instagram. -package instagram // import "github.com/cloudentity/oauth2/instagram" +package instagram // import "golang.org/x/oauth2/instagram" import ( - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // Endpoint is Instagram's OAuth 2.0 endpoint. diff --git a/internal/token.go b/internal/token.go index 4a8bee700..355c38696 100644 --- a/internal/token.go +++ b/internal/token.go @@ -102,7 +102,7 @@ func (e *expirationTime) UnmarshalJSON(b []byte) error { // Endpoint.AuthStyle. func RegisterBrokenAuthHeaderProvider(tokenURL string) {} -// AuthStyle is a copy of the github.com/cloudentity/oauth2 package's AuthStyle type. +// AuthStyle is a copy of the golang.org/x/oauth2 package's AuthStyle type. type AuthStyle int const ( diff --git a/jira/jira.go b/jira/jira.go index fecccf1e9..814656e9e 100644 --- a/jira/jira.go +++ b/jira/jira.go @@ -19,7 +19,7 @@ import ( "strings" "time" - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // ClaimSet contains information about the JWT signature according diff --git a/jira/jira_test.go b/jira/jira_test.go index 47d1d91f6..07f6a6314 100644 --- a/jira/jira_test.go +++ b/jira/jira_test.go @@ -13,8 +13,8 @@ import ( "strings" "testing" - "github.com/cloudentity/oauth2" - "github.com/cloudentity/oauth2/jws" + "golang.org/x/oauth2" + "golang.org/x/oauth2/jws" ) func TestJWTFetch_JSONResponse(t *testing.T) { diff --git a/jws/jws.go b/jws/jws.go index 00157a291..95015648b 100644 --- a/jws/jws.go +++ b/jws/jws.go @@ -4,7 +4,7 @@ // Package jws provides a partial implementation // of JSON Web Signature encoding and decoding. -// It exists to support the github.com/cloudentity/oauth2 package. +// It exists to support the golang.org/x/oauth2 package. // // See RFC 7515. // @@ -12,7 +12,7 @@ // removed in the future. It exists for internal use only. // Please switch to another JWS package or copy this package into your own // source tree. -package jws // import "github.com/cloudentity/oauth2/jws" +package jws // import "golang.org/x/oauth2/jws" import ( "bytes" diff --git a/jwt/example_test.go b/jwt/example_test.go index fe99c3fa8..58503d80d 100644 --- a/jwt/example_test.go +++ b/jwt/example_test.go @@ -7,7 +7,7 @@ package jwt_test import ( "context" - "github.com/cloudentity/oauth2/jwt" + "golang.org/x/oauth2/jwt" ) func ExampleJWTConfig() { diff --git a/jwt/jwt.go b/jwt/jwt.go index 9b4794edb..b2bf18298 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go @@ -19,9 +19,9 @@ import ( "strings" "time" - "github.com/cloudentity/oauth2" - "github.com/cloudentity/oauth2/internal" - "github.com/cloudentity/oauth2/jws" + "golang.org/x/oauth2" + "golang.org/x/oauth2/internal" + "golang.org/x/oauth2/jws" ) var ( diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go index f9e1913c5..9772dc520 100644 --- a/jwt/jwt_test.go +++ b/jwt/jwt_test.go @@ -15,8 +15,8 @@ import ( "strings" "testing" - "github.com/cloudentity/oauth2" - "github.com/cloudentity/oauth2/jws" + "golang.org/x/oauth2" + "golang.org/x/oauth2/jws" ) var dummyPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY----- diff --git a/kakao/kakao.go b/kakao/kakao.go index f1a1c3a87..6d211260c 100644 --- a/kakao/kakao.go +++ b/kakao/kakao.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package kakao provides constants for using OAuth2 to access Kakao. -package kakao // import "github.com/cloudentity/oauth2/kakao" +package kakao // import "golang.org/x/oauth2/kakao" import ( - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // Endpoint is Kakao's OAuth 2.0 endpoint. diff --git a/linkedin/linkedin.go b/linkedin/linkedin.go index 33af5f04b..d3972771c 100644 --- a/linkedin/linkedin.go +++ b/linkedin/linkedin.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package linkedin provides constants for using OAuth2 to access LinkedIn. -package linkedin // import "github.com/cloudentity/oauth2/linkedin" +package linkedin // import "golang.org/x/oauth2/linkedin" import ( - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // Endpoint is LinkedIn's OAuth 2.0 endpoint. diff --git a/mailchimp/mailchimp.go b/mailchimp/mailchimp.go index 208db920f..647787ec6 100644 --- a/mailchimp/mailchimp.go +++ b/mailchimp/mailchimp.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package mailchimp provides constants for using OAuth2 to access MailChimp. -package mailchimp // import "github.com/cloudentity/oauth2/mailchimp" +package mailchimp // import "golang.org/x/oauth2/mailchimp" import ( - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // Endpoint is MailChimp's OAuth 2.0 endpoint. diff --git a/mailru/mailru.go b/mailru/mailru.go index f51dd297e..dddd9dd0f 100644 --- a/mailru/mailru.go +++ b/mailru/mailru.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package mailru provides constants for using OAuth2 to access Mail.Ru. -package mailru // import "github.com/cloudentity/oauth2/mailru" +package mailru // import "golang.org/x/oauth2/mailru" import ( - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // Endpoint is Mail.Ru's OAuth 2.0 endpoint. diff --git a/mediamath/mediamath.go b/mediamath/mediamath.go index e44c64f55..3ebce5da1 100644 --- a/mediamath/mediamath.go +++ b/mediamath/mediamath.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package mediamath provides constants for using OAuth2 to access MediaMath. -package mediamath // import "github.com/cloudentity/oauth2/mediamath" +package mediamath // import "golang.org/x/oauth2/mediamath" import ( - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // Endpoint is MediaMath's OAuth 2.0 endpoint for production. diff --git a/microsoft/microsoft.go b/microsoft/microsoft.go index 5e13b612c..3ffbc57a6 100644 --- a/microsoft/microsoft.go +++ b/microsoft/microsoft.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package microsoft provides constants for using OAuth2 to access Windows Live ID. -package microsoft // import "github.com/cloudentity/oauth2/microsoft" +package microsoft // import "golang.org/x/oauth2/microsoft" import ( - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // LiveConnectEndpoint is Windows's Live ID OAuth 2.0 endpoint. diff --git a/nokiahealth/nokiahealth.go b/nokiahealth/nokiahealth.go index e112b0fff..c181ccd0f 100644 --- a/nokiahealth/nokiahealth.go +++ b/nokiahealth/nokiahealth.go @@ -6,7 +6,7 @@ package nokiahealth import ( - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // Endpoint is Nokia Health Mate's OAuth 2.0 endpoint. diff --git a/oauth2.go b/oauth2.go index 9d1cd1c73..1081cf61e 100644 --- a/oauth2.go +++ b/oauth2.go @@ -6,7 +6,7 @@ // OAuth2 authorized and authenticated HTTP requests, // as specified in RFC 6749. // It can additionally grant authorization with Bearer JWT. -package oauth2 // import "github.com/cloudentity/oauth2" +package oauth2 // import "golang.org/x/oauth2" import ( "bytes" @@ -17,8 +17,8 @@ import ( "strings" "sync" - "github.com/cloudentity/oauth2/advancedauth" - "github.com/cloudentity/oauth2/internal" + "golang.org/x/oauth2/advancedauth" + "golang.org/x/oauth2/internal" ) // NoContext is the default context you should supply if not using @@ -38,7 +38,7 @@ func RegisterBrokenAuthHeaderProvider(tokenURL string) {} // Config describes a typical 3-legged OAuth2 flow, with both the // client application information and the server's endpoint URLs. // For the client credentials 2-legged OAuth2 flow, see the clientcredentials -// package (https://github.com/cloudentity/oauth2/clientcredentials). +// package (https://golang.org/x/oauth2/clientcredentials). type Config struct { // ClientID is the application's ID. ClientID string diff --git a/oauth2_test.go b/oauth2_test.go index a95af6367..b7975e166 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -16,7 +16,7 @@ import ( "testing" "time" - "github.com/cloudentity/oauth2/internal" + "golang.org/x/oauth2/internal" ) type mockTransport struct { diff --git a/odnoklassniki/odnoklassniki.go b/odnoklassniki/odnoklassniki.go index cc79ce70c..c0d093ccc 100644 --- a/odnoklassniki/odnoklassniki.go +++ b/odnoklassniki/odnoklassniki.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package odnoklassniki provides constants for using OAuth2 to access Odnoklassniki. -package odnoklassniki // import "github.com/cloudentity/oauth2/odnoklassniki" +package odnoklassniki // import "golang.org/x/oauth2/odnoklassniki" import ( - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // Endpoint is Odnoklassniki's OAuth 2.0 endpoint. diff --git a/paypal/paypal.go b/paypal/paypal.go index 31bebed31..2e713c53c 100644 --- a/paypal/paypal.go +++ b/paypal/paypal.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package paypal provides constants for using OAuth2 to access PayPal. -package paypal // import "github.com/cloudentity/oauth2/paypal" +package paypal // import "golang.org/x/oauth2/paypal" import ( - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // Endpoint is PayPal's OAuth 2.0 endpoint in live (production) environment. diff --git a/slack/slack.go b/slack/slack.go index a980ea35d..593d2f607 100644 --- a/slack/slack.go +++ b/slack/slack.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package slack provides constants for using OAuth2 to access Slack. -package slack // import "github.com/cloudentity/oauth2/slack" +package slack // import "golang.org/x/oauth2/slack" import ( - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // Endpoint is Slack's OAuth 2.0 endpoint. diff --git a/spotify/spotify.go b/spotify/spotify.go index c8d49a467..c75416c00 100644 --- a/spotify/spotify.go +++ b/spotify/spotify.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package spotify provides constants for using OAuth2 to access Spotify. -package spotify // import "github.com/cloudentity/oauth2/spotify" +package spotify // import "golang.org/x/oauth2/spotify" import ( - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // Endpoint is Spotify's OAuth 2.0 endpoint. diff --git a/stackoverflow/stackoverflow.go b/stackoverflow/stackoverflow.go index 6bed97880..82711f777 100644 --- a/stackoverflow/stackoverflow.go +++ b/stackoverflow/stackoverflow.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package stackoverflow provides constants for using OAuth2 to access Stack Overflow. -package stackoverflow // import "github.com/cloudentity/oauth2/stackoverflow" +package stackoverflow // import "golang.org/x/oauth2/stackoverflow" import ( - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // Endpoint is Stack Overflow's OAuth 2.0 endpoint. diff --git a/token.go b/token.go index 2dbb204c7..822720341 100644 --- a/token.go +++ b/token.go @@ -13,7 +13,7 @@ import ( "strings" "time" - "github.com/cloudentity/oauth2/internal" + "golang.org/x/oauth2/internal" ) // expiryDelta determines how earlier a token should be considered diff --git a/transport.go b/transport.go index 0f86580f9..90657915f 100644 --- a/transport.go +++ b/transport.go @@ -63,7 +63,7 @@ var cancelOnce sync.Once // Deprecated: use contexts for cancellation instead. func (t *Transport) CancelRequest(req *http.Request) { cancelOnce.Do(func() { - log.Printf("deprecated: github.com/cloudentity/oauth2: Transport.CancelRequest no longer does anything; use contexts") + log.Printf("deprecated: golang.org/x/oauth2: Transport.CancelRequest no longer does anything; use contexts") }) } diff --git a/twitch/twitch.go b/twitch/twitch.go index d825b5d5c..0838e7c15 100644 --- a/twitch/twitch.go +++ b/twitch/twitch.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package twitch provides constants for using OAuth2 to access Twitch. -package twitch // import "github.com/cloudentity/oauth2/twitch" +package twitch // import "golang.org/x/oauth2/twitch" import ( - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // Endpoint is Twitch's OAuth 2.0 endpoint. diff --git a/uber/uber.go b/uber/uber.go index b654784d5..5520a6455 100644 --- a/uber/uber.go +++ b/uber/uber.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package uber provides constants for using OAuth2 to access Uber. -package uber // import "github.com/cloudentity/oauth2/uber" +package uber // import "golang.org/x/oauth2/uber" import ( - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // Endpoint is Uber's OAuth 2.0 endpoint. diff --git a/vk/vk.go b/vk/vk.go index 54f013acb..bd8e15948 100644 --- a/vk/vk.go +++ b/vk/vk.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package vk provides constants for using OAuth2 to access VK.com. -package vk // import "github.com/cloudentity/oauth2/vk" +package vk // import "golang.org/x/oauth2/vk" import ( - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // Endpoint is VK's OAuth 2.0 endpoint. diff --git a/yahoo/yahoo.go b/yahoo/yahoo.go index 6fe05f69b..9fa78a23c 100644 --- a/yahoo/yahoo.go +++ b/yahoo/yahoo.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package yahoo provides constants for using OAuth2 to access Yahoo. -package yahoo // import "github.com/cloudentity/oauth2/yahoo" +package yahoo // import "golang.org/x/oauth2/yahoo" import ( - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // Endpoint is Yahoo's OAuth 2.0 endpoint. diff --git a/yandex/yandex.go b/yandex/yandex.go index d48f6f5dc..5ebf666d2 100644 --- a/yandex/yandex.go +++ b/yandex/yandex.go @@ -3,10 +3,10 @@ // license that can be found in the LICENSE file. // Package yandex provides constants for using OAuth2 to access Yandex APIs. -package yandex // import "github.com/cloudentity/oauth2/yandex" +package yandex // import "golang.org/x/oauth2/yandex" import ( - "github.com/cloudentity/oauth2" + "golang.org/x/oauth2" ) // Endpoint is the Yandex OAuth 2.0 endpoint. From d090050bbcfcaa4d3e96b134b4da01215d2260aa Mon Sep 17 00:00:00 2001 From: Konrad Holowinski Date: Wed, 7 Dec 2022 14:59:03 +0100 Subject: [PATCH 09/16] add token URL as additional audience --- advancedauth/privatekeyjwt.go | 2 +- advancedauth/privatekeyjwt_test.go | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/advancedauth/privatekeyjwt.go b/advancedauth/privatekeyjwt.go index c5f658210..bdb24db08 100644 --- a/advancedauth/privatekeyjwt.go +++ b/advancedauth/privatekeyjwt.go @@ -45,7 +45,7 @@ func privateKeyJWTAssertionVals(c Config) (url.Values, error) { claims := &jwt.RegisteredClaims{ Issuer: c.ClientID, Subject: c.ClientID, - Audience: []string{strings.TrimSuffix(c.TokenURL, "/token")}, + Audience: []string{c.TokenURL, strings.TrimSuffix(c.TokenURL, "/token")}, ID: jti, ExpiresAt: jwt.NewNumericDate(time.Now().Add(exp)), } diff --git a/advancedauth/privatekeyjwt_test.go b/advancedauth/privatekeyjwt_test.go index b7c3c9dee..8ee756b63 100644 --- a/advancedauth/privatekeyjwt_test.go +++ b/advancedauth/privatekeyjwt_test.go @@ -7,10 +7,10 @@ import ( "testing" "time" + "github.com/golang-jwt/jwt/v4" "golang.org/x/oauth2" "golang.org/x/oauth2/advancedauth" "golang.org/x/oauth2/clientcredentials" - "github.com/golang-jwt/jwt/v4" ) const ( @@ -124,7 +124,8 @@ func TestPrivateKeyJWT_ClientCredentials(t *testing.T) { expectTrue(tt, len(claims.ID) == 36) expectTrue(tt, time.Now().Unix() < claims.ExpiresAt.Unix()) - expectStringsEqual(tt, serverURL, claims.Audience[0]) + expectStringsEqual(tt, serverURL+"/token", claims.Audience[0]) + expectStringsEqual(tt, serverURL, claims.Audience[1]) w.Header().Set("Content-Type", "application/x-www-form-urlencoded") _, err = w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&token_type=bearer")) @@ -231,7 +232,8 @@ func TestPrivateKeyJWT_Exchange(t *testing.T) { expectTrue(tt, len(claims.ID) == 36) expectTrue(tt, time.Now().Unix() < claims.ExpiresAt.Unix()) - expectStringsEqual(tt, serverURL, claims.Audience[0]) + expectStringsEqual(tt, serverURL+"/token", claims.Audience[0]) + expectStringsEqual(tt, serverURL, claims.Audience[1]) w.Header().Set("Content-Type", "application/x-www-form-urlencoded") _, err = w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&token_type=bearer")) From 91d8250131624affffb9f1df6d266d34873b9208 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Konrad=20Ho=C5=82owi=C5=84ski?= Date: Thu, 2 Feb 2023 14:52:53 +0100 Subject: [PATCH 10/16] Drop TLS support (#5) --- README.md | 57 ++--- advancedauth/advancedauth.go | 9 - advancedauth/tls.go | 52 ---- advancedauth/tls_test.go | 333 ------------------------- clientcredentials/clientcredentials.go | 8 - oauth2.go | 8 - 6 files changed, 19 insertions(+), 448 deletions(-) delete mode 100644 advancedauth/tls.go delete mode 100644 advancedauth/tls_test.go diff --git a/README.md b/README.md index 920a8c969..028ae2a2e 100644 --- a/README.md +++ b/README.md @@ -5,8 +5,6 @@ This repo is a drop-in replacement of `golang.org/x/oauth2` It extends the original library with additional authentication methods: - private_key_jwt -- tls_client_auth -- self_signed_tls_client_auth Additionally, it also adds utility methods for easy use of PKCE. @@ -84,9 +82,14 @@ import ( ### TLS Auth -Both `tls_client_auth` and `self_signed_tls_client_auth` are handled with `TLSAuth` +If you want to use `tls_client_auth` or `self_signed_tls_client_auth` there is no dedicated +configuration for the client certificate and key. -#### Client credentials +You should create an appropriate `*http.Client` and pass it in the context. + +One thing this library does is that it adds an AuthStyle `AuthStyleTLS` which appropriately sends the `client_id` but skips the `client_secret`. + +Example: ```go import ( @@ -100,45 +103,23 @@ import ( ``` ```go - cfg := clientcredentials.Config{ - ClientID: "your client id", - AuthStyle: oauth2.AuthStyleTLS, - TLSAuth: advancedauth.TLSAuth{ - Key: "your certificate PEM encoded private key", - Certificate: "your PEM encoded TLS certificate", - }, - } - token, err := cfg.Token(context.Background()) -``` + // ... generate cert -#### Authorization code - -```go -import ( - "context" - "time" - - "golang.org/x/oauth2" - "golang.org/x/oauth2/advancedauth" -) -``` - -```go + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + Certificates: []tls.Certificate{cert}, + }, + }, + } - cfg := oauth2.Config{ + cfg := clientcredentials.Config{ ClientID: "your client id", - Endpoint: oauth2.Endpoint{ - AuthStyle: oauth2.AuthStyleTLS, - }, - TLSAuth: advancedauth.TLSAuth{ - Key: "your certificate PEM encoded private key", - Certificate: "your PEM encoded TLS certificate", - }, - Scopes: []string{"scope1", "scope2"}, - }, + AuthStyle: oauth2.AuthStyleTLS, + } - token, err := cfg.Exchange(context.Background(), "your authorization code") + token, err := cfg.Token(context.WithValue(context.Background(), oauth2.HTTPClient, client)) ``` ### PKCE diff --git a/advancedauth/advancedauth.go b/advancedauth/advancedauth.go index 664834eaf..07a00d669 100644 --- a/advancedauth/advancedauth.go +++ b/advancedauth/advancedauth.go @@ -1,7 +1,6 @@ package advancedauth import ( - "context" "net/url" ) @@ -33,7 +32,6 @@ type Config struct { AuthStyle AuthStyle ClientID string PrivateKeyAuth PrivateKeyAuth - TLSAuth TLSAuth TokenURL string } @@ -55,10 +53,3 @@ func ExtendUrlValues(v url.Values, c Config) error { } return nil } - -func ExtendContext(ctx context.Context, httpClientContextKey interface{}, c Config) (context.Context, error) { - if c.AuthStyle == AuthStyleTLS { - return extendContextWithTLSClient(ctx, httpClientContextKey, c) - } - return ctx, nil -} diff --git a/advancedauth/tls.go b/advancedauth/tls.go deleted file mode 100644 index ccf6ebfd3..000000000 --- a/advancedauth/tls.go +++ /dev/null @@ -1,52 +0,0 @@ -package advancedauth - -import ( - "context" - "crypto/tls" - "errors" - "net/http" -) - -type TLSAuth struct { - // Key is the private key for client TLS certificate - Key string - // Certificate is the client TLS certificate - Certificate string -} - -func extendContextWithTLSClient(ctx context.Context, httpClientContextKey interface{}, c Config) (context.Context, error) { - var ( - hc *http.Client - ok bool - cert tls.Certificate - err error - tr *http.Transport - ) - if ctx == nil { - ctx = context.Background() - } - - if ctx.Value(httpClientContextKey) == nil { - hc = http.DefaultClient - } else if hc, ok = ctx.Value(httpClientContextKey).(*http.Client); !ok { - return nil, errors.New("client of type *http.Client required in context") - } - - if cert, err = tls.X509KeyPair([]byte(c.TLSAuth.Certificate), []byte(c.TLSAuth.Key)); err != nil { - return nil, err - } - - if hc.Transport == nil { - tr = &http.Transport{} - } else if tr, ok = hc.Transport.(*http.Transport); !ok { - return nil, errors.New("transport of type *http.Transport required in context") - } - if tr.TLSClientConfig == nil { - tr.TLSClientConfig = &tls.Config{} - } - tr.TLSClientConfig.Certificates = []tls.Certificate{cert} - hc.Transport = tr - - return context.WithValue(ctx, httpClientContextKey, hc), nil - -} diff --git a/advancedauth/tls_test.go b/advancedauth/tls_test.go deleted file mode 100644 index 08fe8892f..000000000 --- a/advancedauth/tls_test.go +++ /dev/null @@ -1,333 +0,0 @@ -package advancedauth_test - -import ( - "context" - "crypto/tls" - "net/http" - "net/http/httptest" - "testing" - "time" - - "golang.org/x/oauth2" - "golang.org/x/oauth2/advancedauth" - "golang.org/x/oauth2/clientcredentials" -) - -const ( - key = `-----BEGIN PRIVATE KEY----- -MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC0uhESy4URdqwo -8Hbus5UjdxQom0zQj7jw4bcZ2Z4X0HLJbmbDZdwIaoOWfSjYu9VYPkE04/+KnBOh -XMpA8DfcyS+XVPPTAEFI7KH9RF7BTMjSxB32Huwz9hMHqiPxJx1R+dTSWSC61+GX -Dq+cLHGeQq4Cqxxf0nnGmgpnT26GtiG/QZzE0IdlxaK68BzFk3syNzVFE8Om6yzx -ET7L5/p6igFrj22enjbYimtcSuHM2k16n0MSipBL2v1scheifGN0P+po118IRuX2 -mU8WH5Z8eyInWf857sNEHuFoCkuegJFVkzkuzxZz/F+cT1Znfq0x17ssnL9SFDk4 -XpyNKTqPAgMBAAECggEAULaMq30zV8JNTxddtmuDnswut5fsLXUSnpnf4W6cOXyB -1040HO4f365aSFprZKg2tutOyeVNmkTsS3OabHgcKsG7PHXXUxPZFE2CZw8i1meJ -hP/LdcEHsokipJiq5qeWY6cVEkB16pxBhuorKa97qreS6WQsDut8MWNYZB1Iemaa -HjioQZ7SpUUUyr3XNuvoaPViymGou6DYLaIMg0zklOrfigu1Qb4XdtWtbdi3AWcr -dVNO/N8Y19pJGqpJZ0FlqT/G8es10prAJGPAy4O/RxsLEfOSlZHe1Oj5V63B5h6R -KPwzSRM03gqHG0qruhr2seQN2UvJSRJNz3a2q7siGQKBgQDsXvcohxXoVkv2yvq4 -D9QmQxU3/zHPZhnFNpZ9p3a4AHvmTFyTErTPrZn+QW/l9VvyKGctezR9/SMTLmsQ -dz8Pnbqoukp2Vo/zNK1HEf3Iy5/lVZtd4ErfFCKpWYkNEXX43RQ2qvNt/XkkuIIg -mijoKxBfiwKD8sGB2B8owHCi6wKBgQDDvCglc1yPQ3dzEcaMOoABKWdH7Q72Xgjr -rpmO5lATn6kvcwgAjf/EEIGSQVjoY3zhOZ4J/eV7G6NTg9sRVhcWtkt1UtVv1BwE -Cg4P6W7hCg8GF8Egh/dYtarx19juZkXk5HNSe0PEgrpbjzdxx0s/2HE1JwziVa3q -qJFV4gd17QKBgQCS81dlctZD46LGg9rro6uZPgtrDNTCxA8xdIaLCBneuy5MNx02 -smKG2r7qO3R92tSW8Fd1ByvTSBUOT8VwLzKdWso5K9gvShGkehNgI+dLdoyp31cA -PflORw5liqyR21Ekrw1qD03YC8XM9oiwDCdyb5N2Us31im6TcvGsPDfKkQKBgQCF -Ok0ZMKyP1xw29qJuUGNQZx4llvXYO6lWwkFDQwC+Wq6N3X5U4lJ04cdQBaq+gvk9 -VDp+EpNgeC9zaQxzgGW2z94MvZUJyRZIqY9oxTrzciVHwGN0ARgbCYyRkJnXq0Vn -xxe3zK8T0ueF6rWSfFR74Jct1qauaCM41gQWsQLjAQKBgGfnF99nLe1iI4AZgLIQ -nYgCV65/bmbgX5gkMbDMxZzZYNWg15YuB5Ir+cf20pCwO5EmoLpn7KGpEeED4+/z -2PZrF4bcjmEhYT5O2Y1Wn1oB84uug9c+ME7yiU30g1FttURZuLtzUxASFP2o0l7r -zbSntKWbvm2qk39YKulrEnoh ------END PRIVATE KEY-----` - cert = `-----BEGIN CERTIFICATE----- -MIIDHTCCAgUCFE+Ha5QgryApfoCjSX564o0JoGYIMA0GCSqGSIb3DQEBCwUAMCcx -CzAJBgNVBAYTAlVTMRgwFgYDVQQDDA9FeGFtcGxlLVJvb3QtQ0EwIBcNMjIxMDMx -MTgxODQ0WhgPMjA3NzA4MDMxODE4NDRaMG0xCzAJBgNVBAYTAlVTMRIwEAYDVQQI -DAlZb3VyU3RhdGUxETAPBgNVBAcMCFlvdXJDaXR5MR0wGwYDVQQKDBRFeGFtcGxl -LUNlcnRpZmljYXRlczEYMBYGA1UEAwwPbG9jYWxob3N0LmxvY2FsMIIBIjANBgkq -hkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAtLoREsuFEXasKPB27rOVI3cUKJtM0I+4 -8OG3GdmeF9ByyW5mw2XcCGqDln0o2LvVWD5BNOP/ipwToVzKQPA33Mkvl1Tz0wBB -SOyh/URewUzI0sQd9h7sM/YTB6oj8ScdUfnU0lkgutfhlw6vnCxxnkKuAqscX9J5 -xpoKZ09uhrYhv0GcxNCHZcWiuvAcxZN7Mjc1RRPDpuss8RE+y+f6eooBa49tnp42 -2IprXErhzNpNep9DEoqQS9r9bHIXonxjdD/qaNdfCEbl9plPFh+WfHsiJ1n/Oe7D -RB7haApLnoCRVZM5Ls8Wc/xfnE9WZ36tMde7LJy/UhQ5OF6cjSk6jwIDAQABMA0G -CSqGSIb3DQEBCwUAA4IBAQCBeRGIRS2MljdbgExv5KEND4OhEj2kuuES1zzTQjgs -EO6G3RlFRU9dFz9WDsLSeegY/4Y8BwR6kA3IpmLVnfmn4odWHhLv+JCDo7TG+R6c -3JnHbLuimcMLnGVVdUzAxQz09bNxYhCqUEla/ji0GeSxg8j8ofxtE7qihODV5dQv -gx3Ef/WxZTy08hd8pKxA8dg/VzechNRngFpINXUnGsX699pSoPWfHQoyZprvWjE7 -QDac6VgTzy/KPfaf9vi3MiXJyjJOuGO3+SL1PhR712qRGg9Y+kccNUlL4OfrLJpm -qobZlvUYUfAYcyJVtjas3vPoQHVCcbq7hdbso5FrLyPK ------END CERTIFICATE-----` -) - -func TestTLS_ClientCredentials(t *testing.T) { - tcs := []struct { - title string - config clientcredentials.Config - }{ - { - title: "TLS", - config: clientcredentials.Config{ - ClientID: "CLIENT_ID", - AuthStyle: oauth2.AuthStyleTLS, - TLSAuth: advancedauth.TLSAuth{ - Key: key, - Certificate: cert, - }, - Scopes: []string{"scope1", "scope2"}, - }, - }, - } - - for _, tc := range tcs { - tc := tc - t.Run(tc.title, func(tt *testing.T) { - var serverURL string - - ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - expectURL(tt, r, "/token") - expectHeader(tt, r, "Authorization", "") - expectHeader(tt, r, "Content-Type", "application/x-www-form-urlencoded") - expectFormParam(tt, r, "client_id", "CLIENT_ID") - expectFormParam(tt, r, "client_secret", "") - expectFormParam(tt, r, "scope", "scope1 scope2") - expectFormParam(tt, r, "grant_type", "client_credentials") - - cert := r.TLS.PeerCertificates[0] - expectStringsEqual(tt, "Example-Root-CA", cert.Issuer.CommonName) - - w.Header().Set("Content-Type", "application/x-www-form-urlencoded") - _, err := w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&token_type=bearer")) - if err != nil { - tt.Errorf("could not write body") - } - })) - - ts.TLS = &tls.Config{ - ClientAuth: tls.RequestClientCert, - } - - ts.StartTLS() - serverURL = ts.URL - defer ts.Close() - conf := tc.config - conf.TokenURL = serverURL + "/token" - - _, err := conf.Token(context.Background()) - // context.Background() will fail as the server cert is not trusted - // err == nil checks if there are no panics - if err == nil { - tt.Errorf("expected Token to fail with invalid server cert") - } - - client := ts.Client() - ctx := context.WithValue(context.Background(), oauth2.HTTPClient, client) - tok, err := conf.Token(ctx) - if err != nil { - tt.Error(err) - } - - expectAccessToken(tt, &oauth2.Token{ - AccessToken: "90d64460d14870c08c81352a05dedd3465940a7c", - TokenType: "bearer", - RefreshToken: "", - Expiry: time.Time{}, - }, tok) - }) - } - -} - -func TestTLS_Exchange(t *testing.T) { - tcs := []struct { - title string - config oauth2.Config - }{ - { - title: "TLS", - config: oauth2.Config{ - ClientID: "CLIENT_ID", - Endpoint: oauth2.Endpoint{ - AuthStyle: oauth2.AuthStyleTLS, - }, - TLSAuth: advancedauth.TLSAuth{ - Key: key, - Certificate: cert, - }, - Scopes: []string{"scope1", "scope2"}, - }, - }, - } - - for _, tc := range tcs { - tc := tc - t.Run(tc.title, func(tt *testing.T) { - var serverURL string - - ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - expectURL(tt, r, "/token") - expectHeader(tt, r, "Authorization", "") - expectHeader(tt, r, "Content-Type", "application/x-www-form-urlencoded") - expectFormParam(tt, r, "client_id", "CLIENT_ID") - expectFormParam(tt, r, "client_secret", "") - expectFormParam(tt, r, "scope", "") - expectFormParam(tt, r, "grant_type", "authorization_code") - - cert := r.TLS.PeerCertificates[0] - expectStringsEqual(tt, "Example-Root-CA", cert.Issuer.CommonName) - - w.Header().Set("Content-Type", "application/x-www-form-urlencoded") - _, err := w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&token_type=bearer")) - if err != nil { - tt.Errorf("could not write body") - } - })) - - ts.TLS = &tls.Config{ - ClientAuth: tls.RequestClientCert, - } - - ts.StartTLS() - serverURL = ts.URL - defer ts.Close() - conf := tc.config - conf.Endpoint.TokenURL = serverURL + "/token" - - _, err := conf.Exchange(context.Background(), "random") - // context.Background() will fail as the server cert is not trusted - // err == nil checks if there are no panics - if err == nil { - tt.Errorf("expected Token to fail with invalid server cert") - } - - client := ts.Client() - ctx := context.WithValue(context.Background(), oauth2.HTTPClient, client) - tok, err := conf.Exchange(ctx, "random") - if err != nil { - tt.Error(err) - } - - expectAccessToken(tt, &oauth2.Token{ - AccessToken: "90d64460d14870c08c81352a05dedd3465940a7c", - TokenType: "bearer", - RefreshToken: "", - Expiry: time.Time{}, - }, tok) - }) - } - -} - -type fakeRoundTripper struct{} - -func (f *fakeRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { - return nil, nil -} - -func TestExtendContext(t *testing.T) { - - tcs := []struct { - title string - ctx context.Context - errorExpected bool - auth advancedauth.TLSAuth - assertTransport func(ttt *testing.T, t *http.Transport) - }{ - { - title: "background context", - ctx: context.Background(), - errorExpected: false, - auth: advancedauth.TLSAuth{ - Key: key, - Certificate: cert, - }, - }, - { - title: "invalid cert", - ctx: context.Background(), - errorExpected: true, - auth: advancedauth.TLSAuth{ - Key: key, - Certificate: "random", - }, - }, - { - title: "non *http.Client client", - ctx: context.WithValue(context.Background(), oauth2.HTTPClient, struct{}{}), - errorExpected: true, - auth: advancedauth.TLSAuth{ - Key: key, - Certificate: cert, - }, - }, - { - title: "non *http.Transport transport", - ctx: context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{ - Transport: &fakeRoundTripper{}, - }), - errorExpected: true, - auth: advancedauth.TLSAuth{ - Key: key, - Certificate: cert, - }, - }, - { - title: "no transport configured", - ctx: context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{}), - errorExpected: false, - auth: advancedauth.TLSAuth{ - Key: key, - Certificate: cert, - }, - }, - { - title: "configured transport", - ctx: context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{ - Transport: &http.Transport{ - IdleConnTimeout: 10 * time.Second, - }, - }), - errorExpected: false, - auth: advancedauth.TLSAuth{ - Key: key, - Certificate: cert, - }, - assertTransport: func(ttt *testing.T, tr *http.Transport) { - expectTrue(ttt, tr.IdleConnTimeout == 10*time.Second) - }, - }, - } - - for _, tc := range tcs { - tc := tc - t.Run(tc.title, func(tt *testing.T) { - config := advancedauth.Config{ - AuthStyle: advancedauth.AuthStyleTLS, - ClientID: "random", - TLSAuth: tc.auth, - TokenURL: "random", - } - ctx, err := advancedauth.ExtendContext(tc.ctx, oauth2.HTTPClient, config) - if tc.errorExpected && err == nil { - tt.Errorf("expected error") - } else if !tc.errorExpected && err != nil { - tt.Fatalf("unexpected error %+v", err) - } else if !tc.errorExpected && err == nil { - c := ctx.Value(oauth2.HTTPClient) - expectTrue(tt, c != nil) - hc, ok := ctx.Value(oauth2.HTTPClient).(*http.Client) - expectTrue(tt, ok) - tr, ok := hc.Transport.(*http.Transport) - expectTrue(tt, ok) - certs := tr.TLSClientConfig.Certificates - expectTrue(tt, len(certs) == 1) - if tc.assertTransport != nil { - tc.assertTransport(tt, tr) - } - } - }) - } -} diff --git a/clientcredentials/clientcredentials.go b/clientcredentials/clientcredentials.go index 0b06c5da2..eda6ec0ca 100644 --- a/clientcredentials/clientcredentials.go +++ b/clientcredentials/clientcredentials.go @@ -52,10 +52,6 @@ type Config struct { // PrivateKeyAuth stores configuration options for private_key_jwt // client authentication method described in OpenID Connect spec. PrivateKeyAuth advancedauth.PrivateKeyAuth - - // TLSAuth stores the configuration options for tls_client_auth and self_signed_tls_client_auth - // client authentication methods described in RFC 8705 - TLSAuth advancedauth.TLSAuth } // Token uses client credentials to retrieve a token. @@ -110,15 +106,11 @@ func (c *tokenSource) Token() (*oauth2.Token, error) { AuthStyle: advancedauth.AuthStyle(c.conf.AuthStyle), ClientID: c.conf.ClientID, PrivateKeyAuth: c.conf.PrivateKeyAuth, - TLSAuth: c.conf.TLSAuth, TokenURL: c.conf.TokenURL, } if err = advancedauth.ExtendUrlValues(v, cfg); err != nil { return nil, err } - if c.ctx, err = advancedauth.ExtendContext(c.ctx, oauth2.HTTPClient, cfg); err != nil { - return nil, err - } } for k, p := range c.conf.EndpointParams { // Allow grant_type to be overridden to allow interoperability with diff --git a/oauth2.go b/oauth2.go index 1081cf61e..5cb7ca175 100644 --- a/oauth2.go +++ b/oauth2.go @@ -62,10 +62,6 @@ type Config struct { // PrivateKeyAuth stores configuration options for private_key_jwt // client authentication method described in OpenID Connect spec. PrivateKeyAuth advancedauth.PrivateKeyAuth - - // TLSAuth stores the configuration options for tls_client_auth and self_signed_tls_client_auth - // client authentication methods described in RFC 8705 - TLSAuth advancedauth.TLSAuth } // A TokenSource is anything that can return a token. @@ -242,15 +238,11 @@ func (c *Config) Exchange(ctx context.Context, code string, opts ...AuthCodeOpti AuthStyle: advancedauth.AuthStyle(c.Endpoint.AuthStyle), ClientID: c.ClientID, PrivateKeyAuth: c.PrivateKeyAuth, - TLSAuth: c.TLSAuth, TokenURL: c.Endpoint.TokenURL, } if err = advancedauth.ExtendUrlValues(v, cfg); err != nil { return nil, err } - if ctx, err = advancedauth.ExtendContext(ctx, HTTPClient, cfg); err != nil { - return nil, err - } } for _, opt := range opts { opt.setValue(v) From 50fec225c7fec7d106f688539c2d4eeb0b4a1cab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Konrad=20Ho=C5=82owi=C5=84ski?= Date: Fri, 24 Feb 2023 13:19:56 +0100 Subject: [PATCH 11/16] support custom audience for private_key_jwt (#7) --- advancedauth/privatekeyjwt.go | 10 +++++++++- advancedauth/privatekeyjwt_test.go | 24 ++++++++++++++++++++++-- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/advancedauth/privatekeyjwt.go b/advancedauth/privatekeyjwt.go index bdb24db08..61146d1b1 100644 --- a/advancedauth/privatekeyjwt.go +++ b/advancedauth/privatekeyjwt.go @@ -19,6 +19,8 @@ type PrivateKeyAuth struct { Algorithm Algorithm // Exp defines how long client_assertion is valid for - default 30 seconds Exp time.Duration + // Audience holds the intended recipients of the client_assertion + Audience []string } func privateKeyJWTAssertionVals(c Config) (url.Values, error) { @@ -42,10 +44,16 @@ func privateKeyJWTAssertionVals(c Config) (url.Values, error) { exp = 30 * time.Second } + audience := []string{c.TokenURL, strings.TrimSuffix(c.TokenURL, "/token")} + + if len(c.PrivateKeyAuth.Audience) > 0 { + audience = c.PrivateKeyAuth.Audience + } + claims := &jwt.RegisteredClaims{ Issuer: c.ClientID, Subject: c.ClientID, - Audience: []string{c.TokenURL, strings.TrimSuffix(c.TokenURL, "/token")}, + Audience: audience, ID: jti, ExpiresAt: jwt.NewNumericDate(time.Now().Add(exp)), } diff --git a/advancedauth/privatekeyjwt_test.go b/advancedauth/privatekeyjwt_test.go index 8ee756b63..4379df4d5 100644 --- a/advancedauth/privatekeyjwt_test.go +++ b/advancedauth/privatekeyjwt_test.go @@ -88,6 +88,20 @@ func TestPrivateKeyJWT_ClientCredentials(t *testing.T) { }, publicKey: ecdsaPubKey, }, + { + title: "ECDSA with custom audience", + config: clientcredentials.Config{ + ClientID: "CLIENT_ID", + AuthStyle: oauth2.AuthStylePrivateKeyJWT, + PrivateKeyAuth: advancedauth.PrivateKeyAuth{ + Key: privateECDSAKey, + Algorithm: "ES256", + Audience: []string{"https://example.com/audience"}, + }, + Scopes: []string{"scope1", "scope2"}, + }, + publicKey: ecdsaPubKey, + }, } for _, tc := range tcs { @@ -124,8 +138,14 @@ func TestPrivateKeyJWT_ClientCredentials(t *testing.T) { expectTrue(tt, len(claims.ID) == 36) expectTrue(tt, time.Now().Unix() < claims.ExpiresAt.Unix()) - expectStringsEqual(tt, serverURL+"/token", claims.Audience[0]) - expectStringsEqual(tt, serverURL, claims.Audience[1]) + + if len(tc.config.PrivateKeyAuth.Audience) > 0 { + expectTrue(tt, len(claims.Audience) == 1) + expectStringsEqual(tt, tc.config.PrivateKeyAuth.Audience[0], claims.Audience[0]) + } else { + expectStringsEqual(tt, serverURL+"/token", claims.Audience[0]) + expectStringsEqual(tt, serverURL, claims.Audience[1]) + } w.Header().Set("Content-Type", "application/x-www-form-urlencoded") _, err = w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&token_type=bearer")) From e3afa875b9939ac27ad74b32244ecfa9005e9e1b Mon Sep 17 00:00:00 2001 From: Mateusz Bilski Date: Wed, 16 Aug 2023 13:11:31 +0200 Subject: [PATCH 12/16] Sync fork (#8) * Sync * Sync --- google/appengine_gen1.go | 1 - google/appengine_gen2_flex.go | 1 - google/default.go | 37 +- google/doc.go | 65 ++- google/google.go | 15 +- google/internal/externalaccount/aws.go | 105 +++- google/internal/externalaccount/aws_test.go | 531 ++++++++++++++++-- .../externalaccount/basecredentials.go | 36 +- .../externalaccount/basecredentials_test.go | 135 ----- .../externalaccount/impersonate_test.go | 4 +- oauth2.go | 31 + token.go | 10 + 12 files changed, 701 insertions(+), 270 deletions(-) diff --git a/google/appengine_gen1.go b/google/appengine_gen1.go index 16c6c6b90..e61587945 100644 --- a/google/appengine_gen1.go +++ b/google/appengine_gen1.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build appengine -// +build appengine // This file applies to App Engine first generation runtimes (<= Go 1.9). diff --git a/google/appengine_gen2_flex.go b/google/appengine_gen2_flex.go index a7e27b3d2..9c79aa0a0 100644 --- a/google/appengine_gen2_flex.go +++ b/google/appengine_gen2_flex.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !appengine -// +build !appengine // This file applies to App Engine second generation runtimes (>= Go 1.11) and App Engine flexible. diff --git a/google/default.go b/google/default.go index 7ed02cd41..2cf71f0f9 100644 --- a/google/default.go +++ b/google/default.go @@ -8,17 +8,19 @@ import ( "context" "encoding/json" "fmt" - "io/ioutil" "net/http" "os" "path/filepath" "runtime" + "time" "cloud.google.com/go/compute/metadata" "golang.org/x/oauth2" "golang.org/x/oauth2/authhandler" ) +const adcSetupURL = "https://cloud.google.com/docs/authentication/external/set-up-adc" + // Credentials holds Google credentials, including "Application Default Credentials". // For more details, see: // https://developers.google.com/accounts/docs/application-default-credentials @@ -62,6 +64,18 @@ type CredentialsParams struct { // PKCE is used to support PKCE flow. Optional for 3LO flow. PKCE *authhandler.PKCEParams + + // The OAuth2 TokenURL default override. This value overrides the default TokenURL, + // unless explicitly specified by the credentials config file. Optional. + TokenURL string + + // EarlyTokenRefresh is the amount of time before a token expires that a new + // token will be preemptively fetched. If unset the default value is 10 + // seconds. + // + // Note: This option is currently only respected when using credentials + // fetched from the GCE metadata server. + EarlyTokenRefresh time.Duration } func (params CredentialsParams) deepCopy() CredentialsParams { @@ -127,17 +141,15 @@ func FindDefaultCredentialsWithParams(ctx context.Context, params CredentialsPar // Second, try a well-known file. filename := wellKnownFile() - if creds, err := readCredentialsFile(ctx, filename, params); err == nil { - return creds, nil - } else if !os.IsNotExist(err) { - return nil, fmt.Errorf("google: error getting credentials using well-known file (%v): %v", filename, err) + if b, err := os.ReadFile(filename); err == nil { + return CredentialsFromJSONWithParams(ctx, b, params) } // Third, if we're on a Google App Engine standard first generation runtime (<= Go 1.9) // use those credentials. App Engine standard second generation runtimes (>= Go 1.11) // and App Engine flexible use ComputeTokenSource and the metadata server. if appengineTokenFunc != nil { - return &DefaultCredentials{ + return &Credentials{ ProjectID: appengineAppIDFunc(ctx), TokenSource: AppEngineTokenSource(ctx, params.Scopes...), }, nil @@ -147,15 +159,14 @@ func FindDefaultCredentialsWithParams(ctx context.Context, params CredentialsPar // or App Engine flexible, use the metadata server. if metadata.OnGCE() { id, _ := metadata.ProjectID() - return &DefaultCredentials{ + return &Credentials{ ProjectID: id, - TokenSource: ComputeTokenSource("", params.Scopes...), + TokenSource: computeTokenSource("", params.EarlyTokenRefresh, params.Scopes...), }, nil } // None are found; return helpful error. - const url = "https://developers.google.com/accounts/docs/application-default-credentials" - return nil, fmt.Errorf("google: could not find default credentials. See %v for more information.", url) + return nil, fmt.Errorf("google: could not find default credentials. See %v for more information", adcSetupURL) } // FindDefaultCredentials invokes FindDefaultCredentialsWithParams with the specified scopes. @@ -194,7 +205,7 @@ func CredentialsFromJSONWithParams(ctx context.Context, jsonData []byte, params return nil, err } ts = newErrWrappingTokenSource(ts) - return &DefaultCredentials{ + return &Credentials{ ProjectID: f.ProjectID, TokenSource: ts, JSON: jsonData, @@ -216,8 +227,8 @@ func wellKnownFile() string { return filepath.Join(guessUnixHomeDir(), ".config", "gcloud", f) } -func readCredentialsFile(ctx context.Context, filename string, params CredentialsParams) (*DefaultCredentials, error) { - b, err := ioutil.ReadFile(filename) +func readCredentialsFile(ctx context.Context, filename string, params CredentialsParams) (*Credentials, error) { + b, err := os.ReadFile(filename) if err != nil { return nil, err } diff --git a/google/doc.go b/google/doc.go index b3e7bc85c..ca717634a 100644 --- a/google/doc.go +++ b/google/doc.go @@ -26,7 +26,7 @@ // // Using workload identity federation, your application can access Google Cloud // resources from Amazon Web Services (AWS), Microsoft Azure or any identity -// provider that supports OpenID Connect (OIDC). +// provider that supports OpenID Connect (OIDC) or SAML 2.0. // Traditionally, applications running outside Google Cloud have used service // account keys to access Google Cloud resources. Using identity federation, // you can allow your workload to impersonate a service account. @@ -36,26 +36,75 @@ // Follow the detailed instructions on how to configure Workload Identity Federation // in various platforms: // -// Amazon Web Services (AWS): https://cloud.google.com/iam/docs/access-resources-aws -// Microsoft Azure: https://cloud.google.com/iam/docs/access-resources-azure -// OIDC identity provider: https://cloud.google.com/iam/docs/access-resources-oidc +// Amazon Web Services (AWS): https://cloud.google.com/iam/docs/workload-identity-federation-with-other-clouds#aws +// Microsoft Azure: https://cloud.google.com/iam/docs/workload-identity-federation-with-other-clouds#azure +// OIDC identity provider: https://cloud.google.com/iam/docs/workload-identity-federation-with-other-providers#oidc +// SAML 2.0 identity provider: https://cloud.google.com/iam/docs/workload-identity-federation-with-other-providers#saml // // For OIDC and SAML providers, the library can retrieve tokens in three ways: // from a local file location (file-sourced credentials), from a server // (URL-sourced credentials), or from a local executable (executable-sourced // credentials). // For file-sourced credentials, a background process needs to be continuously -// refreshing the file location with a new OIDC token prior to expiration. +// refreshing the file location with a new OIDC/SAML token prior to expiration. // For tokens with one hour lifetimes, the token needs to be updated in the file // every hour. The token can be stored directly as plain text or in JSON format. // For URL-sourced credentials, a local server needs to host a GET endpoint to -// return the OIDC token. The response can be in plain text or JSON. +// return the OIDC/SAML token. The response can be in plain text or JSON. // Additional required request headers can also be specified. // For executable-sourced credentials, an application needs to be available to -// output the OIDC token and other information in a JSON format. +// output the OIDC/SAML token and other information in a JSON format. // For more information on how these work (and how to implement // executable-sourced credentials), please check out: -// https://cloud.google.com/iam/docs/using-workload-identity-federation#oidc +// https://cloud.google.com/iam/docs/workload-identity-federation-with-other-providers#create_a_credential_configuration +// +// Note that this library does not perform any validation on the token_url, token_info_url, +// or service_account_impersonation_url fields of the credential configuration. +// It is not recommended to use a credential configuration that you did not generate with +// the gcloud CLI unless you verify that the URL fields point to a googleapis.com domain. +// +// # Workforce Identity Federation +// +// Workforce identity federation lets you use an external identity provider (IdP) to +// authenticate and authorize a workforce—a group of users, such as employees, partners, +// and contractors—using IAM, so that the users can access Google Cloud services. +// Workforce identity federation extends Google Cloud's identity capabilities to support +// syncless, attribute-based single sign on. +// +// With workforce identity federation, your workforce can access Google Cloud resources +// using an external identity provider (IdP) that supports OpenID Connect (OIDC) or +// SAML 2.0 such as Azure Active Directory (Azure AD), Active Directory Federation +// Services (AD FS), Okta, and others. +// +// Follow the detailed instructions on how to configure Workload Identity Federation +// in various platforms: +// +// Azure AD: https://cloud.google.com/iam/docs/workforce-sign-in-azure-ad +// Okta: https://cloud.google.com/iam/docs/workforce-sign-in-okta +// OIDC identity provider: https://cloud.google.com/iam/docs/configuring-workforce-identity-federation#oidc +// SAML 2.0 identity provider: https://cloud.google.com/iam/docs/configuring-workforce-identity-federation#saml +// +// For workforce identity federation, the library can retrieve tokens in three ways: +// from a local file location (file-sourced credentials), from a server +// (URL-sourced credentials), or from a local executable (executable-sourced +// credentials). +// For file-sourced credentials, a background process needs to be continuously +// refreshing the file location with a new OIDC/SAML token prior to expiration. +// For tokens with one hour lifetimes, the token needs to be updated in the file +// every hour. The token can be stored directly as plain text or in JSON format. +// For URL-sourced credentials, a local server needs to host a GET endpoint to +// return the OIDC/SAML token. The response can be in plain text or JSON. +// Additional required request headers can also be specified. +// For executable-sourced credentials, an application needs to be available to +// output the OIDC/SAML token and other information in a JSON format. +// For more information on how these work (and how to implement +// executable-sourced credentials), please check out: +// https://cloud.google.com/iam/docs/workforce-obtaining-short-lived-credentials#generate_a_configuration_file_for_non-interactive_sign-in +// +// Note that this library does not perform any validation on the token_url, token_info_url, +// or service_account_impersonation_url fields of the credential configuration. +// It is not recommended to use a credential configuration that you did not generate with +// the gcloud CLI unless you verify that the URL fields point to a googleapis.com domain. // // # Credentials // diff --git a/google/google.go b/google/google.go index 8df0c493e..cc1223889 100644 --- a/google/google.go +++ b/google/google.go @@ -26,6 +26,9 @@ var Endpoint = oauth2.Endpoint{ AuthStyle: oauth2.AuthStyleInParams, } +// MTLSTokenURL is Google's OAuth 2.0 default mTLS endpoint. +const MTLSTokenURL = "https://oauth2.mtls.googleapis.com/token" + // JWTTokenURL is Google's OAuth 2.0 token URL to use with the JWT flow. const JWTTokenURL = "https://oauth2.googleapis.com/token" @@ -172,7 +175,11 @@ func (f *credentialsFile) tokenSource(ctx context.Context, params CredentialsPar cfg.Endpoint.AuthURL = Endpoint.AuthURL } if cfg.Endpoint.TokenURL == "" { - cfg.Endpoint.TokenURL = Endpoint.TokenURL + if params.TokenURL != "" { + cfg.Endpoint.TokenURL = params.TokenURL + } else { + cfg.Endpoint.TokenURL = Endpoint.TokenURL + } } tok := &oauth2.Token{RefreshToken: f.RefreshToken} return cfg.TokenSource(ctx, tok), nil @@ -224,7 +231,11 @@ func (f *credentialsFile) tokenSource(ctx context.Context, params CredentialsPar // Further information about retrieving access tokens from the GCE metadata // server can be found at https://cloud.google.com/compute/docs/authentication. func ComputeTokenSource(account string, scope ...string) oauth2.TokenSource { - return oauth2.ReuseTokenSource(nil, computeSource{account: account, scopes: scope}) + return computeTokenSource(account, 0, scope...) +} + +func computeTokenSource(account string, earlyExpiry time.Duration, scope ...string) oauth2.TokenSource { + return oauth2.ReuseTokenSourceWithExpiry(nil, computeSource{account: account, scopes: scope}, earlyExpiry) } type computeSource struct { diff --git a/google/internal/externalaccount/aws.go b/google/internal/externalaccount/aws.go index e917195d5..2bf3202b2 100644 --- a/google/internal/externalaccount/aws.go +++ b/google/internal/externalaccount/aws.go @@ -62,6 +62,13 @@ const ( // The AWS authorization header name for the auto-generated date. awsDateHeader = "x-amz-date" + // Supported AWS configuration environment variables. + awsAccessKeyId = "AWS_ACCESS_KEY_ID" + awsDefaultRegion = "AWS_DEFAULT_REGION" + awsRegion = "AWS_REGION" + awsSecretAccessKey = "AWS_SECRET_ACCESS_KEY" + awsSessionToken = "AWS_SESSION_TOKEN" + awsTimeFormatLong = "20060102T150405Z" awsTimeFormatShort = "20060102" ) @@ -267,6 +274,49 @@ type awsRequest struct { Headers []awsRequestHeader `json:"headers"` } +func (cs awsCredentialSource) validateMetadataServers() error { + if err := cs.validateMetadataServer(cs.RegionURL, "region_url"); err != nil { + return err + } + if err := cs.validateMetadataServer(cs.CredVerificationURL, "url"); err != nil { + return err + } + return cs.validateMetadataServer(cs.IMDSv2SessionTokenURL, "imdsv2_session_token_url") +} + +var validHostnames []string = []string{"169.254.169.254", "fd00:ec2::254"} + +func (cs awsCredentialSource) isValidMetadataServer(metadataUrl string) bool { + if metadataUrl == "" { + // Zero value means use default, which is valid. + return true + } + + u, err := url.Parse(metadataUrl) + if err != nil { + // Unparseable URL means invalid + return false + } + + for _, validHostname := range validHostnames { + if u.Hostname() == validHostname { + // If it's one of the valid hostnames, everything is good + return true + } + } + + // hostname not found in our allowlist, so not valid + return false +} + +func (cs awsCredentialSource) validateMetadataServer(metadataUrl, urlName string) error { + if !cs.isValidMetadataServer(metadataUrl) { + return fmt.Errorf("oauth2/google: invalid hostname %s for %s", metadataUrl, urlName) + } + + return nil +} + func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, error) { if cs.client == nil { cs.client = oauth2.NewClient(cs.ctx, nil) @@ -274,16 +324,33 @@ func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, erro return cs.client.Do(req.WithContext(cs.ctx)) } +func canRetrieveRegionFromEnvironment() bool { + // The AWS region can be provided through AWS_REGION or AWS_DEFAULT_REGION. Only one is + // required. + return getenv(awsRegion) != "" || getenv(awsDefaultRegion) != "" +} + +func canRetrieveSecurityCredentialFromEnvironment() bool { + // Check if both AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are available. + return getenv(awsAccessKeyId) != "" && getenv(awsSecretAccessKey) != "" +} + +func shouldUseMetadataServer() bool { + return !canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment() +} + func (cs awsCredentialSource) subjectToken() (string, error) { if cs.requestSigner == nil { - awsSessionToken, err := cs.getAWSSessionToken() - if err != nil { - return "", err - } - headers := make(map[string]string) - if awsSessionToken != "" { - headers[awsIMDSv2SessionTokenHeader] = awsSessionToken + if shouldUseMetadataServer() { + awsSessionToken, err := cs.getAWSSessionToken() + if err != nil { + return "", err + } + + if awsSessionToken != "" { + headers[awsIMDSv2SessionTokenHeader] = awsSessionToken + } } awsSecurityCredentials, err := cs.getSecurityCredentials(headers) @@ -389,11 +456,11 @@ func (cs *awsCredentialSource) getAWSSessionToken() (string, error) { } func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, error) { - if envAwsRegion := getenv("AWS_REGION"); envAwsRegion != "" { - return envAwsRegion, nil - } - if envAwsRegion := getenv("AWS_DEFAULT_REGION"); envAwsRegion != "" { - return envAwsRegion, nil + if canRetrieveRegionFromEnvironment() { + if envAwsRegion := getenv(awsRegion); envAwsRegion != "" { + return envAwsRegion, nil + } + return getenv("AWS_DEFAULT_REGION"), nil } if cs.RegionURL == "" { @@ -434,14 +501,12 @@ func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, err } func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string) (result awsSecurityCredentials, err error) { - if accessKeyID := getenv("AWS_ACCESS_KEY_ID"); accessKeyID != "" { - if secretAccessKey := getenv("AWS_SECRET_ACCESS_KEY"); secretAccessKey != "" { - return awsSecurityCredentials{ - AccessKeyID: accessKeyID, - SecretAccessKey: secretAccessKey, - SecurityToken: getenv("AWS_SESSION_TOKEN"), - }, nil - } + if canRetrieveSecurityCredentialFromEnvironment() { + return awsSecurityCredentials{ + AccessKeyID: getenv(awsAccessKeyId), + SecretAccessKey: getenv(awsSecretAccessKey), + SecurityToken: getenv(awsSessionToken), + }, nil } roleName, err := cs.getMetadataRoleName(headers) diff --git a/google/internal/externalaccount/aws_test.go b/google/internal/externalaccount/aws_test.go index 093438925..058b00424 100644 --- a/google/internal/externalaccount/aws_test.go +++ b/google/internal/externalaccount/aws_test.go @@ -474,6 +474,38 @@ func createDefaultAwsTestServer() *testAwsServer { ) } +func createDefaultAwsTestServerWithImdsv2(t *testing.T) *testAwsServer { + validateSessionTokenHeaders := func(r *http.Request) { + if r.URL.Path == "/latest/api/token" { + headerValue := r.Header.Get(awsIMDSv2SessionTtlHeader) + if headerValue != awsIMDSv2SessionTtl { + t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTtlHeader, headerValue, awsIMDSv2SessionTtl) + } + } else { + headerValue := r.Header.Get(awsIMDSv2SessionTokenHeader) + if headerValue != "sessiontoken" { + t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTokenHeader, headerValue, "sessiontoken") + } + } + } + + return createAwsTestServer( + "/latest/meta-data/iam/security-credentials", + "/latest/meta-data/placement/availability-zone", + "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "/latest/api/token", + "gcp-aws-role", + "us-east-2b", + map[string]string{ + "SecretAccessKey": secretAccessKey, + "AccessKeyId": accessKeyID, + "Token": securityToken, + }, + "sessiontoken", + validateSessionTokenHeaders, + ) +} + func (server *testAwsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch p := r.URL.Path; p { case server.url: @@ -553,16 +585,25 @@ func getExpectedSubjectToken(url, region, accessKeyID, secretAccessKey, security func TestAWSCredential_BasicRequest(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() - getenv = setEnvironment(map[string]string{}) oldNow := now - defer func() { now = oldNow }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() + getenv = setEnvironment(map[string]string{}) now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -588,46 +629,27 @@ func TestAWSCredential_BasicRequest(t *testing.T) { } func TestAWSCredential_IMDSv2(t *testing.T) { - validateSessionTokenHeaders := func(r *http.Request) { - if r.URL.Path == "/latest/api/token" { - headerValue := r.Header.Get(awsIMDSv2SessionTtlHeader) - if headerValue != awsIMDSv2SessionTtl { - t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTtlHeader, headerValue, awsIMDSv2SessionTtl) - } - } else { - headerValue := r.Header.Get(awsIMDSv2SessionTokenHeader) - if headerValue != "sessiontoken" { - t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTokenHeader, headerValue, "sessiontoken") - } - } - } - - server := createAwsTestServer( - "/latest/meta-data/iam/security-credentials", - "/latest/meta-data/placement/availability-zone", - "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", - "/latest/api/token", - "gcp-aws-role", - "us-east-2b", - map[string]string{ - "SecretAccessKey": secretAccessKey, - "AccessKeyId": accessKeyID, - "Token": securityToken, - }, - "sessiontoken", - validateSessionTokenHeaders, - ) + server := createDefaultAwsTestServerWithImdsv2(t) ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() - getenv = setEnvironment(map[string]string{}) oldNow := now - defer func() { now = oldNow }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() + getenv = setEnvironment(map[string]string{}) now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -655,17 +677,26 @@ func TestAWSCredential_IMDSv2(t *testing.T) { func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } delete(server.Credentials, "Token") tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() - getenv = setEnvironment(map[string]string{}) oldNow := now - defer func() { now = oldNow }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() + getenv = setEnvironment(map[string]string{}) now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -693,20 +724,29 @@ func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) { func TestAWSCredential_BasicRequestWithEnv(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldNow := now + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{ "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", "AWS_REGION": "us-west-1", }) - oldNow := now - defer func() { now = oldNow }() now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -734,20 +774,29 @@ func TestAWSCredential_BasicRequestWithEnv(t *testing.T) { func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldNow := now + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{ "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", - "AWS_DEFAULT_REGION": "us-west-1", + "AWS_REGION": "us-west-1", }) - oldNow := now - defer func() { now = oldNow }() now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -774,21 +823,30 @@ func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) { func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldNow := now + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{ "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", "AWS_REGION": "us-west-1", "AWS_DEFAULT_REGION": "us-east-1", }) - oldNow := now - defer func() { now = oldNow }() now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -815,16 +873,25 @@ func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) { func TestAWSCredential_RequestWithBadVersion(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource.EnvironmentID = "aws3" oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{}) + validHostnames = []string{tsURL.Hostname()} - _, err := tfc.parse(context.Background()) + _, err = tfc.parse(context.Background()) if err == nil { t.Fatalf("parse() should have failed") } @@ -836,14 +903,23 @@ func TestAWSCredential_RequestWithBadVersion(t *testing.T) { func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource.RegionURL = "" oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{}) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -863,14 +939,23 @@ func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) { func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } server.WriteRegion = notFound tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{}) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -890,6 +975,10 @@ func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) { func TestAWSCredential_RequestWithMissingCredential(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("{}")) } @@ -898,8 +987,13 @@ func TestAWSCredential_RequestWithMissingCredential(t *testing.T) { tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{}) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -919,6 +1013,10 @@ func TestAWSCredential_RequestWithMissingCredential(t *testing.T) { func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(`{"AccessKeyId":"FOOBARBAS"}`)) } @@ -927,8 +1025,13 @@ func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) { tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{}) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -948,14 +1051,23 @@ func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) { func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource.URL = "" oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{}) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -975,14 +1087,23 @@ func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) { func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } server.WriteRolename = notFound tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{}) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -1002,14 +1123,23 @@ func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) { func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } server.WriteSecurityCredentials = notFound tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{}) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -1025,3 +1155,290 @@ func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) { t.Errorf("subjectToken = %q, want %q", got, want) } } + +func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing.T) { + server := createDefaultAwsTestServer() + ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } + + metadataTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("Metadata server should not have been called.") + })) + + tfc := testFileConfig + tfc.CredentialSource = server.getCredentialSource(ts.URL) + tfc.CredentialSource.IMDSv2SessionTokenURL = metadataTs.URL + + oldGetenv := getenv + oldNow := now + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() + getenv = setEnvironment(map[string]string{ + "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", + "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + "AWS_REGION": "us-west-1", + }) + now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} + + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() + if err != nil { + t.Fatalf("retrieveSubjectToken() failed: %v", err) + } + + expected := getExpectedSubjectToken( + "https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "us-west-1", + "AKIDEXAMPLE", + "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + "", + ) + + if got, want := out, expected; !reflect.DeepEqual(got, want) { + t.Errorf("subjectToken = \n%q\n want \n%q", got, want) + } +} + +func TestAWSCredential_ShouldCallMetadataEndpointWhenNoRegion(t *testing.T) { + server := createDefaultAwsTestServerWithImdsv2(t) + ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } + + tfc := testFileConfig + tfc.CredentialSource = server.getCredentialSource(ts.URL) + + oldGetenv := getenv + oldNow := now + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() + getenv = setEnvironment(map[string]string{ + "AWS_ACCESS_KEY_ID": accessKeyID, + "AWS_SECRET_ACCESS_KEY": secretAccessKey, + }) + now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} + + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() + if err != nil { + t.Fatalf("retrieveSubjectToken() failed: %v", err) + } + + expected := getExpectedSubjectToken( + "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "us-east-2", + accessKeyID, + secretAccessKey, + "", + ) + + if got, want := out, expected; !reflect.DeepEqual(got, want) { + t.Errorf("subjectToken = \n%q\n want \n%q", got, want) + } +} + +func TestAWSCredential_ShouldCallMetadataEndpointWhenNoAccessKey(t *testing.T) { + server := createDefaultAwsTestServerWithImdsv2(t) + ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } + + tfc := testFileConfig + tfc.CredentialSource = server.getCredentialSource(ts.URL) + + oldGetenv := getenv + oldNow := now + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() + getenv = setEnvironment(map[string]string{ + "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + "AWS_REGION": "us-west-1", + }) + now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} + + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() + if err != nil { + t.Fatalf("retrieveSubjectToken() failed: %v", err) + } + + expected := getExpectedSubjectToken( + "https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "us-west-1", + accessKeyID, + secretAccessKey, + securityToken, + ) + + if got, want := out, expected; !reflect.DeepEqual(got, want) { + t.Errorf("subjectToken = \n%q\n want \n%q", got, want) + } +} + +func TestAWSCredential_ShouldCallMetadataEndpointWhenNoSecretAccessKey(t *testing.T) { + server := createDefaultAwsTestServerWithImdsv2(t) + ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } + + tfc := testFileConfig + tfc.CredentialSource = server.getCredentialSource(ts.URL) + + oldGetenv := getenv + oldNow := now + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() + getenv = setEnvironment(map[string]string{ + "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", + "AWS_REGION": "us-west-1", + }) + now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} + + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() + if err != nil { + t.Fatalf("retrieveSubjectToken() failed: %v", err) + } + + expected := getExpectedSubjectToken( + "https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "us-west-1", + accessKeyID, + secretAccessKey, + securityToken, + ) + + if got, want := out, expected; !reflect.DeepEqual(got, want) { + t.Errorf("subjectToken = \n%q\n want \n%q", got, want) + } +} + +func TestAWSCredential_Validations(t *testing.T) { + var metadataServerValidityTests = []struct { + name string + credSource CredentialSource + errText string + }{ + { + name: "No Metadata Server URLs", + credSource: CredentialSource{ + EnvironmentID: "aws1", + RegionURL: "", + URL: "", + IMDSv2SessionTokenURL: "", + }, + }, { + name: "IPv4 Metadata Server URLs", + credSource: CredentialSource{ + EnvironmentID: "aws1", + RegionURL: "http://169.254.169.254/latest/meta-data/placement/availability-zone", + URL: "http://169.254.169.254/latest/meta-data/iam/security-credentials", + IMDSv2SessionTokenURL: "http://169.254.169.254/latest/api/token", + }, + }, { + name: "IPv6 Metadata Server URLs", + credSource: CredentialSource{ + EnvironmentID: "aws1", + RegionURL: "http://[fd00:ec2::254]/latest/meta-data/placement/availability-zone", + URL: "http://[fd00:ec2::254]/latest/meta-data/iam/security-credentials", + IMDSv2SessionTokenURL: "http://[fd00:ec2::254]/latest/api/token", + }, + }, { + name: "Faulty RegionURL", + credSource: CredentialSource{ + EnvironmentID: "aws1", + RegionURL: "http://abc.com/latest/meta-data/placement/availability-zone", + URL: "http://169.254.169.254/latest/meta-data/iam/security-credentials", + IMDSv2SessionTokenURL: "http://169.254.169.254/latest/api/token", + }, + errText: "oauth2/google: invalid hostname http://abc.com/latest/meta-data/placement/availability-zone for region_url", + }, { + name: "Faulty CredVerificationURL", + credSource: CredentialSource{ + EnvironmentID: "aws1", + RegionURL: "http://169.254.169.254/latest/meta-data/placement/availability-zone", + URL: "http://abc.com/latest/meta-data/iam/security-credentials", + IMDSv2SessionTokenURL: "http://169.254.169.254/latest/api/token", + }, + errText: "oauth2/google: invalid hostname http://abc.com/latest/meta-data/iam/security-credentials for url", + }, { + name: "Faulty IMDSv2SessionTokenURL", + credSource: CredentialSource{ + EnvironmentID: "aws1", + RegionURL: "http://169.254.169.254/latest/meta-data/placement/availability-zone", + URL: "http://169.254.169.254/latest/meta-data/iam/security-credentials", + IMDSv2SessionTokenURL: "http://abc.com/latest/api/token", + }, + errText: "oauth2/google: invalid hostname http://abc.com/latest/api/token for imdsv2_session_token_url", + }, + } + + for _, tt := range metadataServerValidityTests { + t.Run(tt.name, func(t *testing.T) { + tfc := testFileConfig + tfc.CredentialSource = tt.credSource + + oldGetenv := getenv + defer func() { getenv = oldGetenv }() + getenv = setEnvironment(map[string]string{}) + + _, err := tfc.parse(context.Background()) + if err != nil { + if tt.errText == "" { + t.Errorf("Didn't expect an error, but got %v", err) + } else if tt.errText != err.Error() { + t.Errorf("Expected %v, but got %v", tt.errText, err) + } + } else { + if tt.errText != "" { + t.Errorf("Expected error %v, but got none", tt.errText) + } + } + }) + } +} diff --git a/google/internal/externalaccount/basecredentials.go b/google/internal/externalaccount/basecredentials.go index 9fc35535e..dcd252a61 100644 --- a/google/internal/externalaccount/basecredentials.go +++ b/google/internal/externalaccount/basecredentials.go @@ -67,22 +67,6 @@ type Config struct { // that include all elements in a given list, in that order. var ( - validTokenURLPatterns = []*regexp.Regexp{ - // The complicated part in the middle matches any number of characters that - // aren't period, spaces, or slashes. - regexp.MustCompile(`(?i)^[^\.\s\/\\]+\.sts\.googleapis\.com$`), - regexp.MustCompile(`(?i)^sts\.googleapis\.com$`), - regexp.MustCompile(`(?i)^sts\.[^\.\s\/\\]+\.googleapis\.com$`), - regexp.MustCompile(`(?i)^[^\.\s\/\\]+-sts\.googleapis\.com$`), - regexp.MustCompile(`(?i)^sts-[^\.\s\/\\]+\.p\.googleapis\.com$`), - } - validImpersonateURLPatterns = []*regexp.Regexp{ - regexp.MustCompile(`^[^\.\s\/\\]+\.iamcredentials\.googleapis\.com$`), - regexp.MustCompile(`^iamcredentials\.googleapis\.com$`), - regexp.MustCompile(`^iamcredentials\.[^\.\s\/\\]+\.googleapis\.com$`), - regexp.MustCompile(`^[^\.\s\/\\]+-iamcredentials\.googleapis\.com$`), - regexp.MustCompile(`^iamcredentials-[^\.\s\/\\]+\.p\.googleapis\.com$`), - } validWorkforceAudiencePattern *regexp.Regexp = regexp.MustCompile(`//iam\.googleapis\.com/locations/[^/]+/workforcePools/`) ) @@ -110,25 +94,13 @@ func validateWorkforceAudience(input string) bool { // TokenSource Returns an external account TokenSource struct. This is to be called by package google to construct a google.Credentials. func (c *Config) TokenSource(ctx context.Context) (oauth2.TokenSource, error) { - return c.tokenSource(ctx, validTokenURLPatterns, validImpersonateURLPatterns, "https") + return c.tokenSource(ctx, "https") } // tokenSource is a private function that's directly called by some of the tests, // because the unit test URLs are mocked, and would otherwise fail the // validity check. -func (c *Config) tokenSource(ctx context.Context, tokenURLValidPats []*regexp.Regexp, impersonateURLValidPats []*regexp.Regexp, scheme string) (oauth2.TokenSource, error) { - valid := validateURL(c.TokenURL, tokenURLValidPats, scheme) - if !valid { - return nil, fmt.Errorf("oauth2/google: invalid TokenURL provided while constructing tokenSource") - } - - if c.ServiceAccountImpersonationURL != "" { - valid := validateURL(c.ServiceAccountImpersonationURL, impersonateURLValidPats, scheme) - if !valid { - return nil, fmt.Errorf("oauth2/google: invalid ServiceAccountImpersonationURL provided while constructing tokenSource") - } - } - +func (c *Config) tokenSource(ctx context.Context, scheme string) (oauth2.TokenSource, error) { if c.WorkforcePoolUserProject != "" { valid := validateWorkforceAudience(c.Audience) if !valid { @@ -213,6 +185,10 @@ func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) { awsCredSource.IMDSv2SessionTokenURL = c.CredentialSource.IMDSv2SessionTokenURL } + if err := awsCredSource.validateMetadataServers(); err != nil { + return nil, err + } + return awsCredSource, nil } } else if c.CredentialSource.File != "" { diff --git a/google/internal/externalaccount/basecredentials_test.go b/google/internal/externalaccount/basecredentials_test.go index 05e0127f0..bf6be321c 100644 --- a/google/internal/externalaccount/basecredentials_test.go +++ b/google/internal/externalaccount/basecredentials_test.go @@ -9,7 +9,6 @@ import ( "io/ioutil" "net/http" "net/http/httptest" - "strings" "testing" "time" @@ -208,140 +207,6 @@ func TestNonworkforceWithWorkforcePoolUserProject(t *testing.T) { } } -func TestValidateURLTokenURL(t *testing.T) { - var urlValidityTests = []struct { - tokURL string - expectSuccess bool - }{ - {"https://east.sts.googleapis.com", true}, - {"https://sts.googleapis.com", true}, - {"https://sts.asfeasfesef.googleapis.com", true}, - {"https://us-east-1-sts.googleapis.com", true}, - {"https://sts.googleapis.com/your/path/here", true}, - {"https://.sts.googleapis.com", false}, - {"https://badsts.googleapis.com", false}, - {"https://sts.asfe.asfesef.googleapis.com", false}, - {"https://sts..googleapis.com", false}, - {"https://-sts.googleapis.com", false}, - {"https://us-ea.st-1-sts.googleapis.com", false}, - {"https://sts.googleapis.com.evil.com/whatever/path", false}, - {"https://us-eas\\t-1.sts.googleapis.com", false}, - {"https:/us-ea/st-1.sts.googleapis.com", false}, - {"https:/us-east 1.sts.googleapis.com", false}, - {"https://", false}, - {"http://us-east-1.sts.googleapis.com", false}, - {"https://us-east-1.sts.googleapis.comevil.com", false}, - {"https://sts-xyz.p.googleapis.com", true}, - {"https://sts.pgoogleapis.com", false}, - {"https://p.googleapis.com", false}, - {"https://sts.p.com", false}, - {"http://sts.p.googleapis.com", false}, - {"https://xyz-sts.p.googleapis.com", false}, - {"https://sts-xyz.123.p.googleapis.com", false}, - {"https://sts-xyz.p1.googleapis.com", false}, - {"https://sts-xyz.p.foo.com", false}, - {"https://sts-xyz.p.foo.googleapis.com", false}, - } - ctx := context.Background() - for _, tt := range urlValidityTests { - t.Run(" "+tt.tokURL, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability. - config := testConfig - config.TokenURL = tt.tokURL - _, err := config.TokenSource(ctx) - - if tt.expectSuccess && err != nil { - t.Errorf("got %v but want nil", err) - } else if !tt.expectSuccess && err == nil { - t.Errorf("got nil but expected an error") - } - }) - } - for _, el := range urlValidityTests { - el.tokURL = strings.ToUpper(el.tokURL) - } - for _, tt := range urlValidityTests { - t.Run(" "+tt.tokURL, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability. - config := testConfig - config.TokenURL = tt.tokURL - _, err := config.TokenSource(ctx) - - if tt.expectSuccess && err != nil { - t.Errorf("got %v but want nil", err) - } else if !tt.expectSuccess && err == nil { - t.Errorf("got nil but expected an error") - } - }) - } -} - -func TestValidateURLImpersonateURL(t *testing.T) { - var urlValidityTests = []struct { - impURL string - expectSuccess bool - }{ - {"https://east.iamcredentials.googleapis.com", true}, - {"https://iamcredentials.googleapis.com", true}, - {"https://iamcredentials.asfeasfesef.googleapis.com", true}, - {"https://us-east-1-iamcredentials.googleapis.com", true}, - {"https://iamcredentials.googleapis.com/your/path/here", true}, - {"https://.iamcredentials.googleapis.com", false}, - {"https://badiamcredentials.googleapis.com", false}, - {"https://iamcredentials.asfe.asfesef.googleapis.com", false}, - {"https://iamcredentials..googleapis.com", false}, - {"https://-iamcredentials.googleapis.com", false}, - {"https://us-ea.st-1-iamcredentials.googleapis.com", false}, - {"https://iamcredentials.googleapis.com.evil.com/whatever/path", false}, - {"https://us-eas\\t-1.iamcredentials.googleapis.com", false}, - {"https:/us-ea/st-1.iamcredentials.googleapis.com", false}, - {"https:/us-east 1.iamcredentials.googleapis.com", false}, - {"https://", false}, - {"http://us-east-1.iamcredentials.googleapis.com", false}, - {"https://us-east-1.iamcredentials.googleapis.comevil.com", false}, - {"https://iamcredentials-xyz.p.googleapis.com", true}, - {"https://iamcredentials.pgoogleapis.com", false}, - {"https://p.googleapis.com", false}, - {"https://iamcredentials.p.com", false}, - {"http://iamcredentials.p.googleapis.com", false}, - {"https://xyz-iamcredentials.p.googleapis.com", false}, - {"https://iamcredentials-xyz.123.p.googleapis.com", false}, - {"https://iamcredentials-xyz.p1.googleapis.com", false}, - {"https://iamcredentials-xyz.p.foo.com", false}, - {"https://iamcredentials-xyz.p.foo.googleapis.com", false}, - } - ctx := context.Background() - for _, tt := range urlValidityTests { - t.Run(" "+tt.impURL, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability. - config := testConfig - config.TokenURL = "https://sts.googleapis.com" // Setting the most basic acceptable tokenURL - config.ServiceAccountImpersonationURL = tt.impURL - _, err := config.TokenSource(ctx) - - if tt.expectSuccess && err != nil { - t.Errorf("got %v but want nil", err) - } else if !tt.expectSuccess && err == nil { - t.Errorf("got nil but expected an error") - } - }) - } - for _, el := range urlValidityTests { - el.impURL = strings.ToUpper(el.impURL) - } - for _, tt := range urlValidityTests { - t.Run(" "+tt.impURL, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability. - config := testConfig - config.TokenURL = "https://sts.googleapis.com" // Setting the most basic acceptable tokenURL - config.ServiceAccountImpersonationURL = tt.impURL - _, err := config.TokenSource(ctx) - - if tt.expectSuccess && err != nil { - t.Errorf("got %v but want nil", err) - } else if !tt.expectSuccess && err == nil { - t.Errorf("got nil but expected an error") - } - }) - } -} - func TestWorkforcePoolCreation(t *testing.T) { var audienceValidatyTests = []struct { audience string diff --git a/google/internal/externalaccount/impersonate_test.go b/google/internal/externalaccount/impersonate_test.go index 17e2f6d72..8c7f6a9a7 100644 --- a/google/internal/externalaccount/impersonate_test.go +++ b/google/internal/externalaccount/impersonate_test.go @@ -9,7 +9,6 @@ import ( "io/ioutil" "net/http" "net/http/httptest" - "regexp" "testing" ) @@ -114,8 +113,7 @@ func TestImpersonation(t *testing.T) { defer targetServer.Close() testImpersonateConfig.TokenURL = targetServer.URL - allURLs := regexp.MustCompile(".+") - ourTS, err := testImpersonateConfig.tokenSource(context.Background(), []*regexp.Regexp{allURLs}, []*regexp.Regexp{allURLs}, "http") + ourTS, err := testImpersonateConfig.tokenSource(context.Background(), "http") if err != nil { t.Fatalf("Failed to create TokenSource: %v", err) } diff --git a/oauth2.go b/oauth2.go index 5cb7ca175..0fa61292d 100644 --- a/oauth2.go +++ b/oauth2.go @@ -16,6 +16,7 @@ import ( "net/url" "strings" "sync" + "time" "golang.org/x/oauth2/advancedauth" "golang.org/x/oauth2/internal" @@ -316,6 +317,8 @@ type reuseTokenSource struct { mu sync.Mutex // guards t t *Token + + expiryDelta time.Duration } // Token returns the current token if it's still valid, else will @@ -331,6 +334,7 @@ func (s *reuseTokenSource) Token() (*Token, error) { if err != nil { return nil, err } + t.expiryDelta = s.expiryDelta s.t = t return t, nil } @@ -405,3 +409,30 @@ func ReuseTokenSource(t *Token, src TokenSource) TokenSource { new: src, } } + +// ReuseTokenSource returns a TokenSource that acts in the same manner as the +// TokenSource returned by ReuseTokenSource, except the expiry buffer is +// configurable. The expiration time of a token is calculated as +// t.Expiry.Add(-earlyExpiry). +func ReuseTokenSourceWithExpiry(t *Token, src TokenSource, earlyExpiry time.Duration) TokenSource { + // Don't wrap a reuseTokenSource in itself. That would work, + // but cause an unnecessary number of mutex operations. + // Just build the equivalent one. + if rt, ok := src.(*reuseTokenSource); ok { + if t == nil { + // Just use it directly, but set the expiryDelta to earlyExpiry, + // so the behavior matches what the user expects. + rt.expiryDelta = earlyExpiry + return rt + } + src = rt.new + } + if t != nil { + t.expiryDelta = earlyExpiry + } + return &reuseTokenSource{ + t: t, + new: src, + expiryDelta: earlyExpiry, + } +} diff --git a/token.go b/token.go index 822720341..55a1fedca 100644 --- a/token.go +++ b/token.go @@ -21,6 +21,11 @@ import ( // expirations due to client-server time mismatches. const expiryDelta = 10 * time.Second +// defaultExpiryDelta determines how earlier a token should be considered +// expired than its actual expiration time. It is used to avoid late +// expirations due to client-server time mismatches. +const defaultExpiryDelta = 10 * time.Second + // Token represents the credentials used to authorize // the requests to access protected resources on the OAuth 2.0 // provider's backend. @@ -52,6 +57,11 @@ type Token struct { // raw optionally contains extra metadata from the server // when updating a token. raw interface{} + + // expiryDelta is used to calculate when a token is considered + // expired, by subtracting from Expiry. If zero, defaultExpiryDelta + // is used. + expiryDelta time.Duration } // Type returns t.TokenType if non-empty, else "Bearer". From 82e1fc748f8d74ef0b8108a20ed0c127455ee632 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Konrad=20Ho=C5=82owi=C5=84ski?= Date: Tue, 5 Sep 2023 13:15:33 +0200 Subject: [PATCH 13/16] Singleflight token retrieval (#9) --- go.mod | 2 ++ go.sum | 4 ++++ internal/token.go | 39 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+) diff --git a/go.mod b/go.mod index 5c92bf05e..7f991198c 100644 --- a/go.mod +++ b/go.mod @@ -13,5 +13,7 @@ require ( require ( github.com/golang/protobuf v1.5.2 // indirect + github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect + golang.org/x/sync v0.3.0 // indirect google.golang.org/protobuf v1.28.0 // indirect ) diff --git a/go.sum b/go.sum index 25f58b344..01574f372 100644 --- a/go.sum +++ b/go.sum @@ -11,6 +11,8 @@ github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4= +github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= @@ -23,6 +25,8 @@ golang.org/x/net v0.1.0 h1:hZ/3BUoy5aId7sCpA/Tc5lt8DkFgdVS2onTpJsZ/fl0= golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/internal/token.go b/internal/token.go index 355c38696..7675d7693 100644 --- a/internal/token.go +++ b/internal/token.go @@ -20,7 +20,9 @@ import ( "sync" "time" + "github.com/mitchellh/hashstructure/v2" "golang.org/x/net/context/ctxhttp" + "golang.org/x/sync/singleflight" ) // Token represents the credentials used to authorize @@ -185,7 +187,44 @@ func cloneURLValues(v url.Values) url.Values { return v2 } +var tokenFetchGroup singleflight.Group + func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle) (*Token, error) { + // singleflight wrapper over the actual implementation `doRetrieveToken` + // this function makes sure that token endpoint is called only once at the same time with + // the same credentials and params + var ( + hashStr = struct { + clientID string + clientSecret string + tokenURL string + v url.Values + authStyle AuthStyle + }{ + clientID: clientID, + clientSecret: clientSecret, + tokenURL: tokenURL, + v: v, + authStyle: authStyle, + } + hash uint64 + token interface{} + err error + ) + + if hash, err = hashstructure.Hash(hashStr, hashstructure.FormatV2, nil); err != nil { + return nil, err + } + + if token, err, _ = tokenFetchGroup.Do(strconv.FormatUint(hash, 10), func() (interface{}, error) { + return doRetrieveToken(ctx, clientID, clientSecret, tokenURL, v, authStyle) + }); err != nil { + return nil, err + } + return token.(*Token), nil +} + +func doRetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle) (*Token, error) { needsAuthStyleProbe := authStyle == 0 if needsAuthStyleProbe { if style, ok := lookupAuthStyle(tokenURL); ok { From 43b2ba1f23dcee52057ecfebeed0715d8892805e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Konrad=20Ho=C5=82owi=C5=84ski?= Date: Tue, 5 Sep 2023 16:46:13 +0200 Subject: [PATCH 14/16] Fix hashing calculation for singleglight (#10) --- internal/token.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/internal/token.go b/internal/token.go index 7675d7693..49e526410 100644 --- a/internal/token.go +++ b/internal/token.go @@ -195,17 +195,17 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, // the same credentials and params var ( hashStr = struct { - clientID string - clientSecret string - tokenURL string - v url.Values - authStyle AuthStyle + ClientID string + ClientSecret string + TokenURL string + V url.Values + AuthStyle AuthStyle }{ - clientID: clientID, - clientSecret: clientSecret, - tokenURL: tokenURL, - v: v, - authStyle: authStyle, + ClientID: clientID, + ClientSecret: clientSecret, + TokenURL: tokenURL, + V: v, + AuthStyle: authStyle, } hash uint64 token interface{} From 9e1678c8994e7125355a826a5168d956c467fbc1 Mon Sep 17 00:00:00 2001 From: Mateusz Bilski Date: Tue, 16 Jul 2024 15:47:25 +0200 Subject: [PATCH 15/16] Add device auth url to endpoint (#11) --- oauth2.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/oauth2.go b/oauth2.go index 0fa61292d..ab299bbb7 100644 --- a/oauth2.go +++ b/oauth2.go @@ -76,8 +76,9 @@ type TokenSource interface { // Endpoint represents an OAuth 2.0 provider's authorization and token // endpoint URLs. type Endpoint struct { - AuthURL string - TokenURL string + AuthURL string + DeviceAuthURL string + TokenURL string // AuthStyle optionally specifies how the endpoint wants the // client ID & client secret sent. The zero value means to @@ -298,7 +299,6 @@ func (tf *tokenRefresher) Token() (*Token, error) { "grant_type": {"refresh_token"}, "refresh_token": {tf.refreshToken}, }) - if err != nil { return nil, err } From 93277f96cd8fc46110aab459c93234179e9c0405 Mon Sep 17 00:00:00 2001 From: Mateusz Bilski Date: Tue, 12 Nov 2024 17:32:54 +0100 Subject: [PATCH 16/16] Sync (#12) --- google/default.go | 15 +++++++++++++++ internal/token.go | 7 +++++-- token.go | 6 ++++++ 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/google/default.go b/google/default.go index 2cf71f0f9..818513009 100644 --- a/google/default.go +++ b/google/default.go @@ -37,6 +37,21 @@ type Credentials struct { // environment and not with a credentials file, e.g. when code is // running on Google Cloud Platform. JSON []byte + + // UniverseDomainProvider returns the default service domain for a given + // Cloud universe. Optional. + // + // On GCE, UniverseDomainProvider should return the universe domain value + // from Google Compute Engine (GCE)'s metadata server. See also [The attached service + // account](https://cloud.google.com/docs/authentication/application-default-credentials#attached-sa). + // If the GCE metadata server returns a 404 error, the default universe + // domain value should be returned. If the GCE metadata server returns an + // error other than 404, the error should be returned. + UniverseDomainProvider func() (string, error) +} + +func (c *Credentials) GetUniverseDomain() (string, error) { + return "", nil } // DefaultCredentials is the old name of Credentials. diff --git a/internal/token.go b/internal/token.go index 49e526410..30e6d14cc 100644 --- a/internal/token.go +++ b/internal/token.go @@ -324,8 +324,11 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) { } type RetrieveError struct { - Response *http.Response - Body []byte + Response *http.Response + Body []byte + ErrorCode string + ErrorDescription string + ErrorURI string } func (r *RetrieveError) Error() string { diff --git a/token.go b/token.go index 55a1fedca..ebdceb862 100644 --- a/token.go +++ b/token.go @@ -181,6 +181,12 @@ type RetrieveError struct { // Body is the body that was consumed by reading Response.Body. // It may be truncated. Body []byte + // ErrorCode is RFC 6749's 'error' parameter. + ErrorCode string + // ErrorDescription is RFC 6749's 'error_description' parameter. + ErrorDescription string + // ErrorURI is RFC 6749's 'error_uri' parameter. + ErrorURI string } func (r *RetrieveError) Error() string {