-
-
Notifications
You must be signed in to change notification settings - Fork 183
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
dd56c69
commit 36940a0
Showing
5 changed files
with
198 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
import asyncio | ||
import logging | ||
from datetime import timedelta | ||
|
||
import dateutil.parser | ||
from botocore.compat import total_seconds | ||
from botocore.exceptions import ClientError, TokenRetrievalError | ||
from botocore.tokens import ( | ||
DeferredRefreshableToken, | ||
FrozenAuthToken, | ||
SSOTokenProvider, | ||
TokenProviderChain, | ||
_utc_now, | ||
) | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def create_token_resolver(session): | ||
providers = [ | ||
AioSSOTokenProvider(session), | ||
] | ||
return TokenProviderChain(providers=providers) | ||
|
||
|
||
class AioDeferredRefreshableToken(DeferredRefreshableToken): | ||
def __init__( | ||
self, method, refresh_using, time_fetcher=_utc_now | ||
): # noqa: E501, lgtm [py/missing-call-to-init] | ||
self._time_fetcher = time_fetcher | ||
self._refresh_using = refresh_using | ||
self.method = method | ||
|
||
# The frozen token is protected by this lock | ||
self._refresh_lock = asyncio.Lock() | ||
self._frozen_token = None | ||
self._next_refresh = None | ||
|
||
async def get_frozen_token(self): | ||
await self._refresh() | ||
return self._frozen_token | ||
|
||
async def _refresh(self): | ||
# If we don't need to refresh just return | ||
refresh_type = self._should_refresh() | ||
if not refresh_type: | ||
return None | ||
|
||
# Block for refresh if we're in the mandatory refresh window | ||
block_for_refresh = refresh_type == "mandatory" | ||
if block_for_refresh or not self._refresh_lock.locked(): | ||
async with self._refresh_lock: | ||
await self._protected_refresh() | ||
|
||
async def _protected_refresh(self): | ||
# This should only be called after acquiring the refresh lock | ||
# Another task may have already refreshed, double check refresh | ||
refresh_type = self._should_refresh() | ||
if not refresh_type: | ||
return None | ||
|
||
try: | ||
now = self._time_fetcher() | ||
self._next_refresh = now + timedelta(seconds=self._attempt_timeout) | ||
self._frozen_token = await self._refresh_using() | ||
except Exception: | ||
logger.warning( | ||
"Refreshing token failed during the %s refresh period.", | ||
refresh_type, | ||
exc_info=True, | ||
) | ||
if refresh_type == "mandatory": | ||
# This refresh was mandatory, error must be propagated back | ||
raise | ||
|
||
if self._is_expired(): | ||
# Fresh credentials should never be expired | ||
raise TokenRetrievalError( | ||
provider=self.method, | ||
error_msg="Token has expired and refresh failed", | ||
) | ||
|
||
|
||
class AioSSOTokenProvider(SSOTokenProvider): | ||
async def _attempt_create_token(self, token): | ||
response = await self._client.create_token( | ||
grantType=self._GRANT_TYPE, | ||
clientId=token["clientId"], | ||
clientSecret=token["clientSecret"], | ||
refreshToken=token["refreshToken"], | ||
) | ||
expires_in = timedelta(seconds=response["expiresIn"]) | ||
new_token = { | ||
"startUrl": self._sso_config["sso_start_url"], | ||
"region": self._sso_config["sso_region"], | ||
"accessToken": response["accessToken"], | ||
"expiresAt": self._now() + expires_in, | ||
# Cache the registration alongside the token | ||
"clientId": token["clientId"], | ||
"clientSecret": token["clientSecret"], | ||
"registrationExpiresAt": token["registrationExpiresAt"], | ||
} | ||
if "refreshToken" in response: | ||
new_token["refreshToken"] = response["refreshToken"] | ||
logger.info("SSO Token refresh succeeded") | ||
return new_token | ||
|
||
async def _refresh_access_token(self, token): | ||
keys = ( | ||
"refreshToken", | ||
"clientId", | ||
"clientSecret", | ||
"registrationExpiresAt", | ||
) | ||
missing_keys = [k for k in keys if k not in token] | ||
if missing_keys: | ||
msg = f"Unable to refresh SSO token: missing keys: {missing_keys}" | ||
logger.info(msg) | ||
return None | ||
|
||
expiry = dateutil.parser.parse(token["registrationExpiresAt"]) | ||
if total_seconds(expiry - self._now()) <= 0: | ||
logger.info(f"SSO token registration expired at {expiry}") | ||
return None | ||
|
||
try: | ||
return await self._attempt_create_token(token) | ||
except ClientError: | ||
logger.warning("SSO token refresh attempt failed", exc_info=True) | ||
return None | ||
|
||
async def _refresher(self): | ||
start_url = self._sso_config["sso_start_url"] | ||
session_name = self._sso_config["session_name"] | ||
logger.info(f"Loading cached SSO token for {session_name}") | ||
token_dict = self._token_loader(start_url, session_name=session_name) | ||
expiration = dateutil.parser.parse(token_dict["expiresAt"]) | ||
logger.debug(f"Cached SSO token expires at {expiration}") | ||
|
||
remaining = total_seconds(expiration - self._now()) | ||
if remaining < self._REFRESH_WINDOW: | ||
new_token_dict = await self._refresh_access_token(token_dict) | ||
if new_token_dict is not None: | ||
token_dict = new_token_dict | ||
expiration = token_dict["expiresAt"] | ||
self._token_loader.save_token( | ||
start_url, token_dict, session_name=session_name | ||
) | ||
|
||
return FrozenAuthToken( | ||
token_dict["accessToken"], expiration=expiration | ||
) | ||
|
||
def load_token(self): | ||
if self._sso_config is None: | ||
return None | ||
|
||
return AioDeferredRefreshableToken( | ||
self.METHOD, self._refresher, time_fetcher=self._now | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters