Skip to content

Commit

Permalink
Support token refresh
Browse files Browse the repository at this point in the history
  • Loading branch information
jakob-keller committed Feb 22, 2023
1 parent dd56c69 commit 36940a0
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 4 deletions.
6 changes: 3 additions & 3 deletions aiobotocore/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@
parse,
resolve_imds_endpoint_mode,
)
from botocore.tokens import SSOTokenProvider
from dateutil.tz import tzutc

from aiobotocore.config import AioConfig
from aiobotocore.tokens import AioSSOTokenProvider
from aiobotocore.utils import (
AioContainerMetadataFetcher,
AioInstanceMetadataFetcher,
Expand Down Expand Up @@ -192,7 +192,7 @@ def _create_sso_provider(self, profile_name):
profile_name=profile_name,
cache=self._cache,
token_cache=self._sso_token_cache,
token_provider=SSOTokenProvider(
token_provider=AioSSOTokenProvider(
self._session,
cache=self._sso_token_cache,
profile_name=profile_name,
Expand Down Expand Up @@ -1022,7 +1022,7 @@ async def _get_credentials(self):
async with self._client_creator('sso', config=config) as client:
if self._token_provider:
initial_token_data = self._token_provider.load_token()
token = initial_token_data.get_frozen_token().token
token = (await initial_token_data.get_frozen_token()).token
else:
token = self._token_loader(self._start_url)['accessToken']

Expand Down
4 changes: 4 additions & 0 deletions aiobotocore/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .credentials import AioCredentials, create_credential_resolver
from .hooks import AioHierarchicalEmitter
from .parsers import AioResponseParserFactory
from .tokens import create_token_resolver
from .utils import AioIMDSRegionProvider


Expand Down Expand Up @@ -47,6 +48,9 @@ def __init__(
session_vars, event_hooks, include_builtin_handlers, profile
)

def _create_token_resolver(self):
return create_token_resolver(self)

def _create_credential_resolver(self):
return create_credential_resolver(
self, region_name=self._last_client_region_used
Expand Down
2 changes: 1 addition & 1 deletion aiobotocore/signers.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ async def get_auth_instance(
if cls.REQUIRES_TOKEN is True:
frozen_token = None
if self._auth_token is not None:
frozen_token = self._auth_token.get_frozen_token()
frozen_token = await self._auth_token.get_frozen_token()
auth = cls(frozen_token)
return auth

Expand Down
160 changes: 160 additions & 0 deletions aiobotocore/tokens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import asyncio
import logging
from datetime import timedelta

import dateutil.parser
from botocore.compat import total_seconds
from botocore.exceptions import ClientError, TokenRetrievalError
from botocore.tokens import (
DeferredRefreshableToken,
FrozenAuthToken,
SSOTokenProvider,
TokenProviderChain,
_utc_now,
)

logger = logging.getLogger(__name__)


def create_token_resolver(session):
providers = [
AioSSOTokenProvider(session),
]
return TokenProviderChain(providers=providers)


class AioDeferredRefreshableToken(DeferredRefreshableToken):
def __init__(
self, method, refresh_using, time_fetcher=_utc_now
): # noqa: E501, lgtm [py/missing-call-to-init]
self._time_fetcher = time_fetcher
self._refresh_using = refresh_using
self.method = method

# The frozen token is protected by this lock
self._refresh_lock = asyncio.Lock()
self._frozen_token = None
self._next_refresh = None

async def get_frozen_token(self):
await self._refresh()
return self._frozen_token

async def _refresh(self):
# If we don't need to refresh just return
refresh_type = self._should_refresh()
if not refresh_type:
return None

# Block for refresh if we're in the mandatory refresh window
block_for_refresh = refresh_type == "mandatory"
if block_for_refresh or not self._refresh_lock.locked():
async with self._refresh_lock:
await self._protected_refresh()

async def _protected_refresh(self):
# This should only be called after acquiring the refresh lock
# Another task may have already refreshed, double check refresh
refresh_type = self._should_refresh()
if not refresh_type:
return None

try:
now = self._time_fetcher()
self._next_refresh = now + timedelta(seconds=self._attempt_timeout)
self._frozen_token = await self._refresh_using()
except Exception:
logger.warning(
"Refreshing token failed during the %s refresh period.",
refresh_type,
exc_info=True,
)
if refresh_type == "mandatory":
# This refresh was mandatory, error must be propagated back
raise

if self._is_expired():
# Fresh credentials should never be expired
raise TokenRetrievalError(
provider=self.method,
error_msg="Token has expired and refresh failed",
)


class AioSSOTokenProvider(SSOTokenProvider):
async def _attempt_create_token(self, token):
response = await self._client.create_token(
grantType=self._GRANT_TYPE,
clientId=token["clientId"],
clientSecret=token["clientSecret"],
refreshToken=token["refreshToken"],
)
expires_in = timedelta(seconds=response["expiresIn"])
new_token = {
"startUrl": self._sso_config["sso_start_url"],
"region": self._sso_config["sso_region"],
"accessToken": response["accessToken"],
"expiresAt": self._now() + expires_in,
# Cache the registration alongside the token
"clientId": token["clientId"],
"clientSecret": token["clientSecret"],
"registrationExpiresAt": token["registrationExpiresAt"],
}
if "refreshToken" in response:
new_token["refreshToken"] = response["refreshToken"]
logger.info("SSO Token refresh succeeded")
return new_token

async def _refresh_access_token(self, token):
keys = (
"refreshToken",
"clientId",
"clientSecret",
"registrationExpiresAt",
)
missing_keys = [k for k in keys if k not in token]
if missing_keys:
msg = f"Unable to refresh SSO token: missing keys: {missing_keys}"
logger.info(msg)
return None

expiry = dateutil.parser.parse(token["registrationExpiresAt"])
if total_seconds(expiry - self._now()) <= 0:
logger.info(f"SSO token registration expired at {expiry}")
return None

try:
return await self._attempt_create_token(token)
except ClientError:
logger.warning("SSO token refresh attempt failed", exc_info=True)
return None

async def _refresher(self):
start_url = self._sso_config["sso_start_url"]
session_name = self._sso_config["session_name"]
logger.info(f"Loading cached SSO token for {session_name}")
token_dict = self._token_loader(start_url, session_name=session_name)
expiration = dateutil.parser.parse(token_dict["expiresAt"])
logger.debug(f"Cached SSO token expires at {expiration}")

remaining = total_seconds(expiration - self._now())
if remaining < self._REFRESH_WINDOW:
new_token_dict = await self._refresh_access_token(token_dict)
if new_token_dict is not None:
token_dict = new_token_dict
expiration = token_dict["expiresAt"]
self._token_loader.save_token(
start_url, token_dict, session_name=session_name
)

return FrozenAuthToken(
token_dict["accessToken"], expiration=expiration
)

def load_token(self):
if self._sso_config is None:
return None

return AioDeferredRefreshableToken(
self.METHOD, self._refresher, time_fetcher=self._now
)
30 changes: 30 additions & 0 deletions tests/test_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@
generate_presigned_post,
generate_presigned_url,
)
from botocore.tokens import (
DeferredRefreshableToken,
SSOTokenProvider,
create_token_resolver,
)
from botocore.utils import (
ContainerMetadataFetcher,
IMDSFetcher,
Expand Down Expand Up @@ -400,6 +405,9 @@
'bb8f7f3cc4d9ff9551f0875604747c4bb5030ff6'
},
Session.create_client: {'8b1bd136aba5d0e519816aca7354b3d1e2dee7ec'},
Session._create_token_resolver: {
'142df7a219db0dd9c96fd81dc9e84a764a2fe5fb'
},
Session._create_credential_resolver: {
'87e98d201c72d06f7fbdb4ebee2dce1c09de0fb2'
},
Expand Down Expand Up @@ -433,6 +441,28 @@
generate_presigned_post: {'1b48275e09e9c1f872a1d16e74d7e40f34cfaf90'},
add_generate_db_auth_token: {'f61014e6fac4b5c7ee7ac2d2bec15fb16fa9fbe5'},
generate_db_auth_token: {'1f37e1e5982d8528841ce6b79f229b3e23a18959'},
# tokens.py
create_token_resolver: {'b287f4879235a4292592a49b201d2b0bc2dbf401'},
DeferredRefreshableToken.__init__: {
'199254ed7e211119bdebf285c5d9a9789f6dc540'
},
DeferredRefreshableToken.get_frozen_token: {
'846a689a25550c63d2a460555dc27148abdcc992'
},
DeferredRefreshableToken._refresh: {
'92af1e549b5719caa246a81493823a37a684d017'
},
DeferredRefreshableToken._protected_refresh: {
'bd5c1911626e420005e0e60d583a73c68925f4b6'
},
SSOTokenProvider._attempt_create_token: {
'9cf7b75618a253d585819485e5da641cef129d46'
},
SSOTokenProvider._refresh_access_token: {
'cb179d1f262e41cc03a7c218e624e8c7fbeeaf19'
},
SSOTokenProvider._refresher: {'824d41775dbb8a05184f6e9c7b2ea7202b72f2a9'},
SSOTokenProvider.load_token: {'aea8584ef3fb83948ed82f2a2518eec40fb537a0'},
# utils.py
ContainerMetadataFetcher.__init__: {
'46d90a7249ba8389feb487779b0a02e6faa98e57'
Expand Down

0 comments on commit 36940a0

Please sign in to comment.