From 5a5adfed5f52607d89a07cd04d796c7cec55ec12 Mon Sep 17 00:00:00 2001 From: arithmetic1728 Date: Thu, 25 Mar 2021 01:15:28 -0700 Subject: [PATCH] feat: add reauth support to oauth2 credentials --- google/auth/_helpers.py | 34 ++++ google/auth/exceptions.py | 9 + google/oauth2/_client.py | 135 ++++++++++--- google/oauth2/challenges.py | 159 +++++++++++++++ google/oauth2/credentials.py | 33 +++- google/oauth2/reauth.py | 329 +++++++++++++++++++++++++++++++ noxfile.py | 1 + setup.py | 1 + tests/oauth2/test__client.py | 42 +++- tests/oauth2/test_challenges.py | 141 +++++++++++++ tests/oauth2/test_credentials.py | 46 ++++- tests/oauth2/test_reauth.py | 299 ++++++++++++++++++++++++++++ tests/test__helpers.py | 11 ++ 13 files changed, 1192 insertions(+), 48 deletions(-) create mode 100644 google/oauth2/challenges.py create mode 100644 google/oauth2/reauth.py create mode 100644 tests/oauth2/test_challenges.py create mode 100644 tests/oauth2/test_reauth.py diff --git a/google/auth/_helpers.py b/google/auth/_helpers.py index 21c987a73..fae160454 100644 --- a/google/auth/_helpers.py +++ b/google/auth/_helpers.py @@ -17,6 +17,8 @@ import base64 import calendar import datetime +import getpass +import sys import six from six.moves import urllib @@ -230,3 +232,35 @@ def unpadded_urlsafe_b64encode(value): Union[str|bytes]: The encoded value """ return base64.urlsafe_b64encode(value).rstrip(b"=") + + +def get_user_password(text): + """Get password from user. + + Override this function with a different logic if you are using this library + outside a CLI. + + Args: + text (str): message for the password prompt. + + Returns: + str: password string. + """ + return getpass.getpass(text) + + +def is_interactive(): + """Check if we are in an interractive environment. + + Override this function with a different logic if you are using this library + outside a CLI. + + If the rapt token needs refreshing, the user needs to answer the challenges. + If the user is not in an interractive environment, the challenges can not + be answered and we just wait for timeout for no reason. + + Returns: + bool: True if is interactive environment, False otherwise. + """ + + return sys.stdin.isatty() diff --git a/google/auth/exceptions.py b/google/auth/exceptions.py index b6f686bbb..57f181ea1 100644 --- a/google/auth/exceptions.py +++ b/google/auth/exceptions.py @@ -48,3 +48,12 @@ class ClientCertError(GoogleAuthError): class OAuthError(GoogleAuthError): """Used to indicate an error occurred during an OAuth related HTTP request.""" + + +class ReauthFailError(RefreshError): + """An exception for when reauth failed.""" + + def __init__(self, message=None): + super(ReauthFailError, self).__init__( + "Reauthentication failed. {0}".format(message) + ) diff --git a/google/oauth2/_client.py b/google/oauth2/_client.py index 448716329..add563f29 100644 --- a/google/oauth2/_client.py +++ b/google/oauth2/_client.py @@ -39,25 +39,24 @@ _REFRESH_GRANT_TYPE = "refresh_token" -def _handle_error_response(response_body): - """"Translates an error response into an exception. +def _handle_error_response(response_data): + """Translates an error response into an exception. Args: - response_body (str): The decoded response data. + response_data (Mapping): The decoded response data. Raises: - google.auth.exceptions.RefreshError + google.auth.exceptions.RefreshError: The errors contained in response_data. """ try: - error_data = json.loads(response_body) error_details = "{}: {}".format( - error_data["error"], error_data.get("error_description") + response_data["error"], response_data.get("error_description") ) # If no details could be extracted, use the response data. except (KeyError, ValueError): - error_details = response_body + error_details = json.dumps(response_data) - raise exceptions.RefreshError(error_details, response_body) + raise exceptions.RefreshError(error_details, response_data) def _parse_expiry(response_data): @@ -78,8 +77,11 @@ def _parse_expiry(response_data): return None -def _token_endpoint_request(request, token_uri, body): +def _token_endpoint_request_no_throw( + request, token_uri, body, access_token=None, use_json=False +): """Makes a request to the OAuth 2.0 authorization server's token endpoint. + This function doesn't throw on response errors. Args: request (google.auth.transport.Request): A callable used to make @@ -87,22 +89,33 @@ def _token_endpoint_request(request, token_uri, body): token_uri (str): The OAuth 2.0 authorizations server's token endpoint URI. body (Mapping[str, str]): The parameters to send in the request body. + access_token (Optional(str)): The access token needed to make the request. + use_json (Optional(bool)): Use urlencoded format or json format for the + content type. The default value is False. Returns: - Mapping[str, str]: The JSON-decoded response data. - - Raises: - google.auth.exceptions.RefreshError: If the token endpoint returned - an error. + Tuple(bool, Mapping[str, str]): A boolean indicating if the request is + successful, and a mapping for the JSON-decoded response data. """ - body = urllib.parse.urlencode(body).encode("utf-8") - headers = {"content-type": _URLENCODED_CONTENT_TYPE} + if use_json: + headers = {"Content-Type": "application/json"} + json_body = body + body = None + else: + headers = {"content-type": _URLENCODED_CONTENT_TYPE} + json_body = None + body = urllib.parse.urlencode(body).encode("utf-8") + + if access_token: + headers["Authorization"] = "Bearer {}".format(access_token) retry = 0 # retry to fetch token for maximum of two times if any internal failure # occurs. while True: - response = request(method="POST", url=token_uri, headers=headers, body=body) + response = request( + method="POST", url=token_uri, headers=headers, body=body, json=json_body + ) response_body = ( response.data.decode("utf-8") if hasattr(response.data, "decode") @@ -121,8 +134,38 @@ def _token_endpoint_request(request, token_uri, body): ): retry += 1 continue - _handle_error_response(response_body) + return response.status == http_client.OK, response_data + + return response.status == http_client.OK, response_data + + +def _token_endpoint_request( + request, token_uri, body, access_token=None, use_json=False +): + """Makes a request to the OAuth 2.0 authorization server's token endpoint. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + token_uri (str): The OAuth 2.0 authorizations server's token endpoint + URI. + body (Mapping[str, str]): The parameters to send in the request body. + access_token (Optional(str)): The access token needed to make the request. + use_json (Optional(bool)): Use urlencoded format or json format for the + content type. The default value is False. + Returns: + Mapping[str, str]: The JSON-decoded response data. + + Raises: + google.auth.exceptions.RefreshError: If the token endpoint returned + an error. + """ + response_status_ok, response_data = _token_endpoint_request_no_throw( + request, token_uri, body, access_token=access_token, use_json=use_json + ) + if not response_status_ok: + _handle_error_response(response_data) return response_data @@ -204,8 +247,43 @@ def id_token_jwt_grant(request, token_uri, assertion): return id_token, expiry, response_data +def _handle_refresh_grant_response(response_data, refresh_token): + """Extract tokens from refresh grant response. + + Args: + response_data (Mapping[str, str]): Refresh grant response data. + refresh_token (str): Current refresh token. + + Returns: + Tuple[str, str, Optional[datetime], Mapping[str, str]]: The access token, + refresh token, expiration, and additional data returned by the token + endpoint. If response_data doesn't have refresh token, then the current + refresh token will be returned. + + Raises: + google.auth.exceptions.RefreshError: If the token endpoint returned + an error. + """ + try: + access_token = response_data["access_token"] + except KeyError as caught_exc: + new_exc = exceptions.RefreshError("No access token in response.", response_data) + six.raise_from(new_exc, caught_exc) + + refresh_token = response_data.get("refresh_token", refresh_token) + expiry = _parse_expiry(response_data) + + return access_token, refresh_token, expiry, response_data + + def refresh_grant( - request, token_uri, refresh_token, client_id, client_secret, scopes=None + request, + token_uri, + refresh_token, + client_id, + client_secret, + scopes=None, + rapt_token=None, ): """Implements the OAuth 2.0 refresh token grant. @@ -224,10 +302,11 @@ def refresh_grant( scopes must be authorized for the refresh token. Useful if refresh token has a wild card scope (e.g. 'https://www.googleapis.com/auth/any-api'). + rapt_token (Optional(str)): The reauth Proof Token. Returns: - Tuple[str, Optional[str], Optional[datetime], Mapping[str, str]]: The - access token, new refresh token, expiration, and additional data + Tuple[str, str, Optional[datetime], Mapping[str, str]]: The access + token, new or current refresh token, expiration, and additional data returned by the token endpoint. Raises: @@ -244,16 +323,8 @@ def refresh_grant( } if scopes: body["scope"] = " ".join(scopes) + if rapt_token: + body["rapt"] = rapt_token response_data = _token_endpoint_request(request, token_uri, body) - - try: - access_token = response_data["access_token"] - except KeyError as caught_exc: - new_exc = exceptions.RefreshError("No access token in response.", response_data) - six.raise_from(new_exc, caught_exc) - - refresh_token = response_data.get("refresh_token", refresh_token) - expiry = _parse_expiry(response_data) - - return access_token, refresh_token, expiry, response_data + return _handle_refresh_grant_response(response_data, refresh_token) diff --git a/google/oauth2/challenges.py b/google/oauth2/challenges.py new file mode 100644 index 000000000..d41677027 --- /dev/null +++ b/google/oauth2/challenges.py @@ -0,0 +1,159 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Challenges for reauthentication. +""" + +import abc +import base64 +import sys + +import six + +from google.auth import _helpers +from google.auth import exceptions + + +REAUTH_ORIGIN = "https://accounts.google.com" + + +@six.add_metaclass(abc.ABCMeta) +class ReauthChallenge(object): + """Base class for reauth challenges.""" + + @property + @abc.abstractmethod + def name(self): # pragma: NO COVER + """Returns the name of the challenge.""" + pass + + @property + @abc.abstractmethod + def is_locally_eligible(self): # pragma: NO COVER + """Returns true if a challenge is supported locally on this machine.""" + pass + + @abc.abstractmethod + def obtain_challenge_input(self, metadata): # pragma: NO COVER + """Performs logic required to obtain credentials and returns it. + + Args: + metadata: challenge metadata returned in the 'challenges' field in + the initial reauth request. Includes the 'challengeType' field + and other challenge-specific fields. + + Returns: + response that will be send to the reauth service as the content of + the 'proposalResponse' field in the request body. Usually a dict + with the keys specific to the challenge. For example, + {'credential': password} for password challenge. + """ + pass + + +class PasswordChallenge(ReauthChallenge): + """Challenge that asks for user's password.""" + + @property + def name(self): + return "PASSWORD" + + @property + def is_locally_eligible(self): + return True + + def obtain_challenge_input(self, unused_metadata): + passwd = _helpers.get_user_password("Please enter your password:") + if not passwd: + passwd = " " # avoid the server crashing in case of no password :D + return {"credential": passwd} + + +class SecurityKeyChallenge(ReauthChallenge): + """Challenge that asks for user's security key touch.""" + + @property + def name(self): + return "SECURITY_KEY" + + @property + def is_locally_eligible(self): + return True + + def obtain_challenge_input(self, metadata): + try: + import pyu2f.convenience.authenticator + import pyu2f.errors + import pyu2f.model + except ImportError: + raise exceptions.ReauthFailError( + "pyu2f dependency is required to use Security key reauth feature. " + "It can be installed via `pip install pyu2f` or `pip install google-auth[reauth]`." + ) + sk = metadata["securityKey"] + challenges = sk["challenges"] + app_id = sk["applicationId"] + + challenge_data = [] + for c in challenges: + kh = c["keyHandle"].encode("ascii") + key = pyu2f.model.RegisteredKey(bytearray(base64.urlsafe_b64decode(kh))) + challenge = c["challenge"].encode("ascii") + challenge = base64.urlsafe_b64decode(challenge) + challenge_data.append({"key": key, "challenge": challenge}) + + try: + api = pyu2f.convenience.authenticator.CreateCompositeAuthenticator( + REAUTH_ORIGIN + ) + response = api.Authenticate( + app_id, challenge_data, print_callback=sys.stderr.write + ) + return {"securityKey": response} + except pyu2f.errors.U2FError as e: + if e.code == pyu2f.errors.U2FError.DEVICE_INELIGIBLE: + sys.stderr.write("Ineligible security key.\n") + elif e.code == pyu2f.errors.U2FError.TIMEOUT: + sys.stderr.write("Timed out while waiting for security key touch.\n") + else: + raise e + except pyu2f.errors.NoDeviceFoundError: + sys.stderr.write("No security key found.\n") + return None + + +class SamlChallenge(ReauthChallenge): + """Challenge that asks the users to browse to their ID Providers.""" + + @property + def name(self): + return "SAML" + + @property + def is_locally_eligible(self): + return True + + def obtain_challenge_input(self, metadata): + # Magic Arch has not fully supported returning a proper dedirect URL + # for programmatic SAML users today. So we error our here and request + # users to complete a web login. + raise exceptions.ReauthFailError( + "SAML login is required for the current account to complete reauthentication." + ) + + +AVAILABLE_CHALLENGES = { + challenge.name: challenge + for challenge in [SecurityKeyChallenge(), PasswordChallenge(), SamlChallenge()] +} diff --git a/google/oauth2/credentials.py b/google/oauth2/credentials.py index 464cc4878..dcfa5f912 100644 --- a/google/oauth2/credentials.py +++ b/google/oauth2/credentials.py @@ -41,7 +41,7 @@ from google.auth import _helpers from google.auth import credentials from google.auth import exceptions -from google.oauth2 import _client +from google.oauth2 import reauth # The Google OAuth 2.0 token endpoint. Used for authorized user credentials. @@ -55,6 +55,10 @@ class Credentials(credentials.ReadOnlyScoped, credentials.CredentialsWithQuotaPr quota project, use :meth:`with_quota_project` or :: credentials = credentials.with_quota_project('myproject-123) + + If reauth is enabled, `pyu2f` dependency has to be installed in order to use security + key reauth feature. Dependency can be installed via `pip install pyu2f` or `pip install + google-auth[reauth]`. """ def __init__( @@ -69,6 +73,7 @@ def __init__( default_scopes=None, quota_project_id=None, expiry=None, + rapt_token=None, ): """ Args: @@ -97,6 +102,7 @@ def __init__( quota_project_id (Optional[str]): The project ID used for quota and billing. This project may be different from the project used to create the credentials. + rapt_token (Optional[str]): The reauth Proof Token. """ super(Credentials, self).__init__() self.token = token @@ -109,6 +115,7 @@ def __init__( self._client_id = client_id self._client_secret = client_secret self._quota_project_id = quota_project_id + self._rapt_token = rapt_token def __getstate__(self): """A __getstate__ method must exist for the __setstate__ to be called @@ -130,6 +137,7 @@ def __setstate__(self, d): self._client_id = d.get("_client_id") self._client_secret = d.get("_client_secret") self._quota_project_id = d.get("_quota_project_id") + self._rapt_token = d.get("_rapt_token") @property def refresh_token(self): @@ -174,6 +182,11 @@ def requires_scopes(self): the initial token is requested and can not be changed.""" return False + @property + def rapt_token(self): + """Optional[str]: The reauth Proof Token.""" + return self._rapt_token + @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) def with_quota_project(self, quota_project_id): @@ -187,6 +200,7 @@ def with_quota_project(self, quota_project_id): scopes=self.scopes, default_scopes=self.default_scopes, quota_project_id=quota_project_id, + rapt_token=self.rapt_token, ) @_helpers.copy_docstring(credentials.Credentials) @@ -205,23 +219,31 @@ def refresh(self, request): scopes = self._scopes if self._scopes is not None else self._default_scopes - access_token, refresh_token, expiry, grant_response = _client.refresh_grant( + ( + access_token, + refresh_token, + expiry, + grant_response, + rapt_token, + ) = reauth.refresh_grant( request, self._token_uri, self._refresh_token, self._client_id, self._client_secret, - scopes, + scopes=scopes, + rapt_token=self._rapt_token, ) self.token = access_token self.expiry = expiry self._refresh_token = refresh_token self._id_token = grant_response.get("id_token") + self._rapt_token = rapt_token - if scopes and "scopes" in grant_response: + if scopes and "scope" in grant_response: requested_scopes = frozenset(scopes) - granted_scopes = frozenset(grant_response["scopes"].split()) + granted_scopes = frozenset(grant_response["scope"].split()) scopes_requested_but_not_granted = requested_scopes - granted_scopes if scopes_requested_but_not_granted: raise exceptions.RefreshError( @@ -323,6 +345,7 @@ def to_json(self, strip=None): "client_id": self.client_id, "client_secret": self.client_secret, "scopes": self.scopes, + "rapt_token": self.rapt_token, } if self.expiry: # flatten expiry timestamp prep["expiry"] = self.expiry.isoformat() + "Z" diff --git a/google/oauth2/reauth.py b/google/oauth2/reauth.py new file mode 100644 index 000000000..517a15994 --- /dev/null +++ b/google/oauth2/reauth.py @@ -0,0 +1,329 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A module that provides functions for handling rapt authentication. + +Reauth is a process of obtaining additional authentication (such as password, +security token, etc.) while refreshing OAuth 2.0 credentials for a user. + +Credentials that use the Reauth flow must have the reauth scope, +``https://www.googleapis.com/auth/accounts.reauth``. + +This module provides a high-level function for executing the Reauth process, +:func:`refresh_grant`, and lower-level helpers for doing the individual +steps of the reauth process. + +Those steps are: + +1. Obtaining a list of challenges from the reauth server. +2. Running through each challenge and sending the result back to the reauth + server. +3. Refreshing the access token using the returned rapt token. +""" + +import sys + +from six.moves import range + +from google.auth import _helpers +from google.auth import exceptions +from google.oauth2 import _client +from google.oauth2 import challenges + + +_REAUTH_SCOPE = "https://www.googleapis.com/auth/accounts.reauth" +_REAUTH_API = "https://reauth.googleapis.com/v2/sessions" + +_REAUTH_NEEDED_ERROR = "invalid_grant" +_REAUTH_NEEDED_ERROR_INVALID_RAPT = "invalid_rapt" +_REAUTH_NEEDED_ERROR_RAPT_REQUIRED = "rapt_required" + +_AUTHENTICATED = "AUTHENTICATED" +_CHALLENGE_REQUIRED = "CHALLENGE_REQUIRED" +_CHALLENGE_PENDING = "CHALLENGE_PENDING" + + +def _get_challenges( + request, supported_challenge_types, access_token, requested_scopes=None +): + """Does initial request to reauth API to get the challenges. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + supported_challenge_types (Sequence[str]): list of challenge names + supported by the manager. + access_token (str): Access token with reauth scopes. + requested_scopes (Optional(Sequence[str])): Authorized scopes for the credentials. + + Returns: + dict: The response from the reauth API. + """ + body = {"supportedChallengeTypes": supported_challenge_types} + if requested_scopes: + body["oauthScopesForDomainPolicyLookup"] = requested_scopes + + return _client._token_endpoint_request( + request, _REAUTH_API + ":start", body, access_token=access_token, use_json=True + ) + + +def _send_challenge_result( + request, session_id, challenge_id, client_input, access_token +): + """Attempt to refresh access token by sending next challenge result. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + session_id (str): session id returned by the initial reauth call. + challenge_id (str): challenge id returned by the initial reauth call. + client_input: dict with a challenge-specific client input. For example: + ``{'credential': password}`` for password challenge. + access_token (str): Access token with reauth scopes. + + Returns: + dict: The response from the reauth API. + """ + body = { + "sessionId": session_id, + "challengeId": challenge_id, + "action": "RESPOND", + "proposalResponse": client_input, + } + + return _client._token_endpoint_request( + request, + _REAUTH_API + "/{}:continue".format(session_id), + body, + access_token=access_token, + use_json=True, + ) + + +def _run_next_challenge(msg, request, access_token): + """Get the next challenge from msg and run it. + + Args: + msg (dict): Reauth API response body (either from the initial request to + https://reauth.googleapis.com/v2/sessions:start or from sending the + previous challenge response to + https://reauth.googleapis.com/v2/sessions/id:continue) + request (google.auth.transport.Request): A callable used to make + HTTP requests. + access_token (str): reauth access token + + Returns: + dict: The response from the reauth API. + + Raises: + google.auth.exceptions.ReauthError: if reauth failed. + """ + for challenge in msg["challenges"]: + if challenge["status"] != "READY": + # Skip non-activated challenges. + continue + c = challenges.AVAILABLE_CHALLENGES.get(challenge["challengeType"], None) + if not c: + raise exceptions.ReauthFailError( + "Unsupported challenge type {0}. Supported types: {1}".format( + challenge["challengeType"], + ",".join(list(challenges.AVAILABLE_CHALLENGES.keys())), + ) + ) + if not c.is_locally_eligible: + raise exceptions.ReauthFailError( + "Challenge {0} is not locally eligible".format( + challenge["challengeType"] + ) + ) + client_input = c.obtain_challenge_input(challenge) + if not client_input: + return None + return _send_challenge_result( + request, + msg["sessionId"], + challenge["challengeId"], + client_input, + access_token, + ) + return None + + +def _obtain_rapt(request, access_token, requested_scopes, rounds_num=5): + """Given an http request method and reauth access token, get rapt token. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + access_token (str): reauth access token + requested_scopes (Sequence[str]): scopes required by the client application + rounds_num (Optional(int)): max number of attempts to get a rapt after the next + challenge, before failing the reauth. This defines total number of + challenges + number of additional retries if the chalenge input + wasn't accepted. + + Returns: + str: The rapt token. + + Raises: + google.auth.exceptions.ReauthError: if reauth failed + """ + msg = None + + for _ in range(0, rounds_num): + + if not msg: + msg = _get_challenges( + request, + list(challenges.AVAILABLE_CHALLENGES.keys()), + access_token, + requested_scopes, + ) + + if msg["status"] == _AUTHENTICATED: + return msg["encodedProofOfReauthToken"] + + if not ( + msg["status"] == _CHALLENGE_REQUIRED or msg["status"] == _CHALLENGE_PENDING + ): + raise exceptions.ReauthFailError( + "Reauthentication challenge failed due to API error: {}".format( + msg["status"] + ) + ) + + """Check if we are in an interractive environment. + + If the rapt token needs refreshing, the user needs to answer the + challenges. + """ + if not _helpers.is_interactive(): + raise exceptions.ReauthFailError( + "Reauthentication challenge could not be answered because you are not in an interactive session." + ) + + msg = _run_next_challenge(msg, request, access_token) + + # If we got here it means we didn't get authenticated. + raise exceptions.ReauthFailError() + + +def get_rapt_token( + request, client_id, client_secret, refresh_token, token_uri, scopes=None +): + """Given an http request method and refresh_token, get rapt token. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + client_id (str): client id to get access token for reauth scope. + client_secret (str): client secret for the client_id + refresh_token (str): refresh token to refresh access token + token_uri (str): uri to refresh access token + scopes (Optional(Sequence[str])): scopes required by the client application + + Returns: + str: The rapt token. + Raises: + google.auth.exceptions.RefreshError: If reauth failed. + """ + sys.stderr.write("Reauthentication required.\n") + + # Get access token for reauth. + access_token, _, _, _ = _client.refresh_grant( + request=request, + client_id=client_id, + client_secret=client_secret, + refresh_token=refresh_token, + token_uri=token_uri, + scopes=[_REAUTH_SCOPE], + ) + + # Get rapt token from reauth API. + rapt_token = _obtain_rapt(request, access_token, requested_scopes=scopes) + + return rapt_token + + +def refresh_grant( + request, + token_uri, + refresh_token, + client_id, + client_secret, + scopes=None, + rapt_token=None, +): + """Implements the reauthentication flow. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + token_uri (str): The OAuth 2.0 authorizations server's token endpoint + URI. + refresh_token (str): The refresh token to use to get a new access + token. + client_id (str): The OAuth 2.0 application's client ID. + client_secret (str): The Oauth 2.0 appliaction's client secret. + scopes (Optional(Sequence[str])): Scopes to request. If present, all + scopes must be authorized for the refresh token. Useful if refresh + token has a wild card scope (e.g. + 'https://www.googleapis.com/auth/any-api'). + rapt_token (Optional(str)): The rapt token for reauth. + + Returns: + Tuple[str, Optional[str], Optional[datetime], Mapping[str, str]]: The + access token, new refresh token, expiration, and additional data + returned by the token endpoint. + + Raises: + google.auth.exceptions.RefreshError: If the token endpoint returned + an error. + """ + body = { + "grant_type": _client._REFRESH_GRANT_TYPE, + "client_id": client_id, + "client_secret": client_secret, + "refresh_token": refresh_token, + } + if scopes: + body["scope"] = " ".join(scopes) + if rapt_token: + body["rapt"] = rapt_token + + response_status_ok, response_data = _client._token_endpoint_request_no_throw( + request, token_uri, body + ) + if ( + not response_status_ok + and response_data.get("error") == _REAUTH_NEEDED_ERROR + and ( + response_data.get("error_subtype") == _REAUTH_NEEDED_ERROR_INVALID_RAPT + or response_data.get("error_subtype") == _REAUTH_NEEDED_ERROR_RAPT_REQUIRED + ) + ): + rapt_token = get_rapt_token( + request, client_id, client_secret, refresh_token, token_uri, scopes=scopes + ) + body["rapt"] = rapt_token + (response_status_ok, response_data) = _client._token_endpoint_request_no_throw( + request, token_uri, body + ) + + if not response_status_ok: + _client._handle_error_response(response_data) + return _client._handle_refresh_grant_response(response_data, refresh_token) + ( + rapt_token, + ) diff --git a/noxfile.py b/noxfile.py index 3b4863c2d..0bd7f6c6c 100644 --- a/noxfile.py +++ b/noxfile.py @@ -25,6 +25,7 @@ "pytest", "pytest-cov", "pytest-localserver", + "pyu2f", "requests", "urllib3", "cryptography", diff --git a/setup.py b/setup.py index 16ba98cfd..ef723f8af 100644 --- a/setup.py +++ b/setup.py @@ -33,6 +33,7 @@ extras = { "aiohttp": "aiohttp >= 3.6.2, < 4.0.0dev; python_version>='3.6'", "pyopenssl": "pyopenssl>=20.0.0", + "reauth": "pyu2f>=0.1.5", } with io.open("README.rst", "r") as fh: diff --git a/tests/oauth2/test__client.py b/tests/oauth2/test__client.py index c3ae2af98..907b6dd1c 100644 --- a/tests/oauth2/test__client.py +++ b/tests/oauth2/test__client.py @@ -48,7 +48,7 @@ def test__handle_error_response(): - response_data = json.dumps({"error": "help", "error_description": "I'm alive"}) + response_data = {"error": "help", "error_description": "I'm alive"} with pytest.raises(exceptions.RefreshError) as excinfo: _client._handle_error_response(response_data) @@ -57,12 +57,12 @@ def test__handle_error_response(): def test__handle_error_response_non_json(): - response_data = "Help, I'm alive" + response_data = {"foo": "bar"} with pytest.raises(exceptions.RefreshError) as excinfo: _client._handle_error_response(response_data) - assert excinfo.match(r"Help, I\'m alive") + assert excinfo.match(r"{\"foo\": \"bar\"}") @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) @@ -97,6 +97,34 @@ def test__token_endpoint_request(): url="http://example.com", headers={"content-type": "application/x-www-form-urlencoded"}, body="test=params".encode("utf-8"), + json=None, + ) + + # Check result + assert result == {"test": "response"} + + +def test__token_endpoint_request_use_json(): + request = make_request({"test": "response"}) + + result = _client._token_endpoint_request( + request, + "http://example.com", + {"test": "params"}, + access_token="access_token", + use_json=True, + ) + + # Check request call + request.assert_called_with( + method="POST", + url="http://example.com", + headers={ + "Content-Type": "application/json", + "Authorization": "Bearer access_token", + }, + body=None, + json={"test": "params"}, ) # Check result @@ -220,7 +248,12 @@ def test_refresh_grant(unused_utcnow): ) token, refresh_token, expiry, extra_data = _client.refresh_grant( - request, "http://example.com", "refresh_token", "client_id", "client_secret" + request, + "http://example.com", + "refresh_token", + "client_id", + "client_secret", + rapt_token="rapt_token", ) # Check request call @@ -231,6 +264,7 @@ def test_refresh_grant(unused_utcnow): "refresh_token": "refresh_token", "client_id": "client_id", "client_secret": "client_secret", + "rapt": "rapt_token", }, ) diff --git a/tests/oauth2/test_challenges.py b/tests/oauth2/test_challenges.py new file mode 100644 index 000000000..854266401 --- /dev/null +++ b/tests/oauth2/test_challenges.py @@ -0,0 +1,141 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the reauth module.""" + +import base64 +import sys + +import mock +import pytest +import pyu2f + +from google.auth import exceptions +from google.oauth2 import challenges + + +def test_security_key(): + metadata = { + "status": "READY", + "challengeId": 2, + "challengeType": "SECURITY_KEY", + "securityKey": { + "applicationId": "security_key_application_id", + "challenges": [ + { + "keyHandle": "some_key", + "challenge": base64.urlsafe_b64encode( + "some_challenge".encode("ascii") + ).decode("ascii"), + } + ], + }, + } + mock_key = mock.Mock() + + challenge = challenges.SecurityKeyChallenge() + + # Test the case that security key challenge is passed. + with mock.patch("pyu2f.model.RegisteredKey", return_value=mock_key): + with mock.patch( + "pyu2f.convenience.authenticator.CompositeAuthenticator.Authenticate" + ) as mock_authenticate: + mock_authenticate.return_value = "security key response" + assert challenge.name == "SECURITY_KEY" + assert challenge.is_locally_eligible + assert challenge.obtain_challenge_input(metadata) == { + "securityKey": "security key response" + } + mock_authenticate.assert_called_with( + "security_key_application_id", + [{"key": mock_key, "challenge": b"some_challenge"}], + print_callback=sys.stderr.write, + ) + + # Test various types of exceptions. + with mock.patch("pyu2f.model.RegisteredKey", return_value=mock_key): + with mock.patch( + "pyu2f.convenience.authenticator.CompositeAuthenticator.Authenticate" + ) as mock_authenticate: + mock_authenticate.side_effect = pyu2f.errors.U2FError( + pyu2f.errors.U2FError.DEVICE_INELIGIBLE + ) + assert challenge.obtain_challenge_input(metadata) is None + + with mock.patch( + "pyu2f.convenience.authenticator.CompositeAuthenticator.Authenticate" + ) as mock_authenticate: + mock_authenticate.side_effect = pyu2f.errors.U2FError( + pyu2f.errors.U2FError.TIMEOUT + ) + assert challenge.obtain_challenge_input(metadata) is None + + with mock.patch( + "pyu2f.convenience.authenticator.CompositeAuthenticator.Authenticate" + ) as mock_authenticate: + mock_authenticate.side_effect = pyu2f.errors.U2FError( + pyu2f.errors.U2FError.BAD_REQUEST + ) + with pytest.raises(pyu2f.errors.U2FError): + challenge.obtain_challenge_input(metadata) + + with mock.patch( + "pyu2f.convenience.authenticator.CompositeAuthenticator.Authenticate" + ) as mock_authenticate: + mock_authenticate.side_effect = pyu2f.errors.NoDeviceFoundError() + assert challenge.obtain_challenge_input(metadata) is None + + with mock.patch( + "pyu2f.convenience.authenticator.CompositeAuthenticator.Authenticate" + ) as mock_authenticate: + mock_authenticate.side_effect = pyu2f.errors.UnsupportedVersionException() + with pytest.raises(pyu2f.errors.UnsupportedVersionException): + challenge.obtain_challenge_input(metadata) + + with mock.patch.dict("sys.modules"): + sys.modules["pyu2f"] = None + with pytest.raises(exceptions.ReauthFailError) as excinfo: + challenge.obtain_challenge_input(metadata) + assert excinfo.match(r"pyu2f dependency is required") + + +@mock.patch("getpass.getpass", return_value="foo") +def test_password_challenge(getpass_mock): + challenge = challenges.PasswordChallenge() + + with mock.patch("getpass.getpass", return_value="foo"): + assert challenge.is_locally_eligible + assert challenge.name == "PASSWORD" + assert challenges.PasswordChallenge().obtain_challenge_input({}) == { + "credential": "foo" + } + + with mock.patch("getpass.getpass", return_value=None): + assert challenges.PasswordChallenge().obtain_challenge_input({}) == { + "credential": " " + } + + +def test_saml_challenge(): + metadata = { + "status": "READY", + "challengeId": 1, + "challengeType": "SAML", + "securityKey": {}, + } + challenge = challenges.SamlChallenge() + assert challenge.is_locally_eligible + assert challenge.name == "SAML" + with pytest.raises(exceptions.ReauthFailError): + challenge.obtain_challenge_input(metadata) diff --git a/tests/oauth2/test_credentials.py b/tests/oauth2/test_credentials.py index b885d2973..4a387a58e 100644 --- a/tests/oauth2/test_credentials.py +++ b/tests/oauth2/test_credentials.py @@ -38,6 +38,7 @@ class TestCredentials(object): TOKEN_URI = "https://example.com/oauth2/token" REFRESH_TOKEN = "refresh_token" + RAPT_TOKEN = "rapt_token" CLIENT_ID = "client_id" CLIENT_SECRET = "client_secret" @@ -49,6 +50,7 @@ def make_credentials(cls): token_uri=cls.TOKEN_URI, client_id=cls.CLIENT_ID, client_secret=cls.CLIENT_SECRET, + rapt_token=cls.RAPT_TOKEN, ) def test_default_state(self): @@ -63,14 +65,16 @@ def test_default_state(self): assert credentials.token_uri == self.TOKEN_URI assert credentials.client_id == self.CLIENT_ID assert credentials.client_secret == self.CLIENT_SECRET + assert credentials.rapt_token == self.RAPT_TOKEN - @mock.patch("google.oauth2._client.refresh_grant", autospec=True) + @mock.patch("google.oauth2.reauth.refresh_grant", autospec=True) @mock.patch( "google.auth._helpers.utcnow", return_value=datetime.datetime.min + _helpers.CLOCK_SKEW, ) def test_refresh_success(self, unused_utcnow, refresh_grant): token = "token" + new_rapt_token = "new_rapt_token" expiry = _helpers.utcnow() + datetime.timedelta(seconds=500) grant_response = {"id_token": mock.sentinel.id_token} refresh_grant.return_value = ( @@ -82,6 +86,8 @@ def test_refresh_success(self, unused_utcnow, refresh_grant): expiry, # Extra data grant_response, + # rapt_token + new_rapt_token, ) request = mock.create_autospec(transport.Request) @@ -98,12 +104,14 @@ def test_refresh_success(self, unused_utcnow, refresh_grant): self.CLIENT_ID, self.CLIENT_SECRET, None, + self.RAPT_TOKEN, ) # Check that the credentials have the token and expiry assert credentials.token == token assert credentials.expiry == expiry assert credentials.id_token == mock.sentinel.id_token + assert credentials.rapt_token == new_rapt_token # Check that the credentials are valid (have a token and are not # expired) @@ -118,7 +126,7 @@ def test_refresh_no_refresh_token(self): request.assert_not_called() - @mock.patch("google.oauth2._client.refresh_grant", autospec=True) + @mock.patch("google.oauth2.reauth.refresh_grant", autospec=True) @mock.patch( "google.auth._helpers.utcnow", return_value=datetime.datetime.min + _helpers.CLOCK_SKEW, @@ -129,8 +137,9 @@ def test_credentials_with_scopes_requested_refresh_success( scopes = ["email", "profile"] default_scopes = ["https://www.googleapis.com/auth/cloud-platform"] token = "token" + new_rapt_token = "new_rapt_token" expiry = _helpers.utcnow() + datetime.timedelta(seconds=500) - grant_response = {"id_token": mock.sentinel.id_token} + grant_response = {"id_token": mock.sentinel.id_token, "scope": "email profile"} refresh_grant.return_value = ( # Access token token, @@ -140,6 +149,8 @@ def test_credentials_with_scopes_requested_refresh_success( expiry, # Extra data grant_response, + # rapt token + new_rapt_token, ) request = mock.create_autospec(transport.Request) @@ -151,6 +162,7 @@ def test_credentials_with_scopes_requested_refresh_success( client_secret=self.CLIENT_SECRET, scopes=scopes, default_scopes=default_scopes, + rapt_token=self.RAPT_TOKEN, ) # Refresh credentials @@ -164,6 +176,7 @@ def test_credentials_with_scopes_requested_refresh_success( self.CLIENT_ID, self.CLIENT_SECRET, scopes, + self.RAPT_TOKEN, ) # Check that the credentials have the token and expiry @@ -171,12 +184,13 @@ def test_credentials_with_scopes_requested_refresh_success( assert creds.expiry == expiry assert creds.id_token == mock.sentinel.id_token assert creds.has_scopes(scopes) + assert creds.rapt_token == new_rapt_token # Check that the credentials are valid (have a token and are not # expired.) assert creds.valid - @mock.patch("google.oauth2._client.refresh_grant", autospec=True) + @mock.patch("google.oauth2.reauth.refresh_grant", autospec=True) @mock.patch( "google.auth._helpers.utcnow", return_value=datetime.datetime.min + _helpers.CLOCK_SKEW, @@ -186,6 +200,7 @@ def test_credentials_with_only_default_scopes_requested( ): default_scopes = ["email", "profile"] token = "token" + new_rapt_token = "new_rapt_token" expiry = _helpers.utcnow() + datetime.timedelta(seconds=500) grant_response = {"id_token": mock.sentinel.id_token} refresh_grant.return_value = ( @@ -197,6 +212,8 @@ def test_credentials_with_only_default_scopes_requested( expiry, # Extra data grant_response, + # rapt token + new_rapt_token, ) request = mock.create_autospec(transport.Request) @@ -207,6 +224,7 @@ def test_credentials_with_only_default_scopes_requested( client_id=self.CLIENT_ID, client_secret=self.CLIENT_SECRET, default_scopes=default_scopes, + rapt_token=self.RAPT_TOKEN, ) # Refresh credentials @@ -220,6 +238,7 @@ def test_credentials_with_only_default_scopes_requested( self.CLIENT_ID, self.CLIENT_SECRET, default_scopes, + self.RAPT_TOKEN, ) # Check that the credentials have the token and expiry @@ -227,12 +246,13 @@ def test_credentials_with_only_default_scopes_requested( assert creds.expiry == expiry assert creds.id_token == mock.sentinel.id_token assert creds.has_scopes(default_scopes) + assert creds.rapt_token == new_rapt_token # Check that the credentials are valid (have a token and are not # expired.) assert creds.valid - @mock.patch("google.oauth2._client.refresh_grant", autospec=True) + @mock.patch("google.oauth2.reauth.refresh_grant", autospec=True) @mock.patch( "google.auth._helpers.utcnow", return_value=datetime.datetime.min + _helpers.CLOCK_SKEW, @@ -242,6 +262,7 @@ def test_credentials_with_scopes_returned_refresh_success( ): scopes = ["email", "profile"] token = "token" + new_rapt_token = "new_rapt_token" expiry = _helpers.utcnow() + datetime.timedelta(seconds=500) grant_response = { "id_token": mock.sentinel.id_token, @@ -256,6 +277,8 @@ def test_credentials_with_scopes_returned_refresh_success( expiry, # Extra data grant_response, + # rapt token + new_rapt_token, ) request = mock.create_autospec(transport.Request) @@ -266,6 +289,7 @@ def test_credentials_with_scopes_returned_refresh_success( client_id=self.CLIENT_ID, client_secret=self.CLIENT_SECRET, scopes=scopes, + rapt_token=self.RAPT_TOKEN, ) # Refresh credentials @@ -279,6 +303,7 @@ def test_credentials_with_scopes_returned_refresh_success( self.CLIENT_ID, self.CLIENT_SECRET, scopes, + self.RAPT_TOKEN, ) # Check that the credentials have the token and expiry @@ -286,12 +311,13 @@ def test_credentials_with_scopes_returned_refresh_success( assert creds.expiry == expiry assert creds.id_token == mock.sentinel.id_token assert creds.has_scopes(scopes) + assert creds.rapt_token == new_rapt_token # Check that the credentials are valid (have a token and are not # expired.) assert creds.valid - @mock.patch("google.oauth2._client.refresh_grant", autospec=True) + @mock.patch("google.oauth2.reauth.refresh_grant", autospec=True) @mock.patch( "google.auth._helpers.utcnow", return_value=datetime.datetime.min + _helpers.CLOCK_SKEW, @@ -302,10 +328,11 @@ def test_credentials_with_scopes_refresh_failure_raises_refresh_error( scopes = ["email", "profile"] scopes_returned = ["email"] token = "token" + new_rapt_token = "new_rapt_token" expiry = _helpers.utcnow() + datetime.timedelta(seconds=500) grant_response = { "id_token": mock.sentinel.id_token, - "scopes": " ".join(scopes_returned), + "scope": " ".join(scopes_returned), } refresh_grant.return_value = ( # Access token @@ -316,6 +343,8 @@ def test_credentials_with_scopes_refresh_failure_raises_refresh_error( expiry, # Extra data grant_response, + # rapt token + new_rapt_token, ) request = mock.create_autospec(transport.Request) @@ -326,6 +355,7 @@ def test_credentials_with_scopes_refresh_failure_raises_refresh_error( client_id=self.CLIENT_ID, client_secret=self.CLIENT_SECRET, scopes=scopes, + rapt_token=self.RAPT_TOKEN, ) # Refresh credentials @@ -342,6 +372,7 @@ def test_credentials_with_scopes_refresh_failure_raises_refresh_error( self.CLIENT_ID, self.CLIENT_SECRET, scopes, + self.RAPT_TOKEN, ) # Check that the credentials have the token and expiry @@ -349,6 +380,7 @@ def test_credentials_with_scopes_refresh_failure_raises_refresh_error( assert creds.expiry == expiry assert creds.id_token == mock.sentinel.id_token assert creds.has_scopes(scopes) + assert creds.rapt_token == new_rapt_token # Check that the credentials are valid (have a token and are not # expired.) diff --git a/tests/oauth2/test_reauth.py b/tests/oauth2/test_reauth.py new file mode 100644 index 000000000..4c036e1a0 --- /dev/null +++ b/tests/oauth2/test_reauth.py @@ -0,0 +1,299 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +import mock +import pytest + +from google.auth import exceptions +from google.oauth2 import reauth + + +MOCK_REQUEST = mock.Mock() +CHALLENGES_RESPONSE_TEMPLATE = { + "status": "CHALLENGE_REQUIRED", + "sessionId": "123", + "challenges": [ + { + "status": "READY", + "challengeId": 1, + "challengeType": "PASSWORD", + "securityKey": {}, + } + ], +} +CHALLENGES_RESPONSE_AUTHENTICATED = { + "status": "AUTHENTICATED", + "sessionId": "123", + "encodedProofOfReauthToken": "new_rapt_token", +} + + +class MockChallenge(object): + def __init__(self, name, locally_eligible, challenge_input): + self.name = name + self.is_locally_eligible = locally_eligible + self.challenge_input = challenge_input + + def obtain_challenge_input(self, metadata): + return self.challenge_input + + +def test__get_challenges(): + with mock.patch( + "google.oauth2._client._token_endpoint_request" + ) as mock_token_endpoint_request: + reauth._get_challenges(MOCK_REQUEST, ["SAML"], "token") + mock_token_endpoint_request.assert_called_with( + MOCK_REQUEST, + reauth._REAUTH_API + ":start", + {"supportedChallengeTypes": ["SAML"]}, + access_token="token", + use_json=True, + ) + + +def test__get_challenges_with_scopes(): + with mock.patch( + "google.oauth2._client._token_endpoint_request" + ) as mock_token_endpoint_request: + reauth._get_challenges( + MOCK_REQUEST, ["SAML"], "token", requested_scopes=["scope"] + ) + mock_token_endpoint_request.assert_called_with( + MOCK_REQUEST, + reauth._REAUTH_API + ":start", + { + "supportedChallengeTypes": ["SAML"], + "oauthScopesForDomainPolicyLookup": ["scope"], + }, + access_token="token", + use_json=True, + ) + + +def test__send_challenge_result(): + with mock.patch( + "google.oauth2._client._token_endpoint_request" + ) as mock_token_endpoint_request: + reauth._send_challenge_result( + MOCK_REQUEST, "123", "1", {"credential": "password"}, "token" + ) + mock_token_endpoint_request.assert_called_with( + MOCK_REQUEST, + reauth._REAUTH_API + "/123:continue", + { + "sessionId": "123", + "challengeId": "1", + "action": "RESPOND", + "proposalResponse": {"credential": "password"}, + }, + access_token="token", + use_json=True, + ) + + +def test__run_next_challenge_not_ready(): + challenges_response = copy.deepcopy(CHALLENGES_RESPONSE_TEMPLATE) + challenges_response["challenges"][0]["status"] = "STATUS_UNSPECIFIED" + assert ( + reauth._run_next_challenge(challenges_response, MOCK_REQUEST, "token") is None + ) + + +def test__run_next_challenge_not_supported(): + challenges_response = copy.deepcopy(CHALLENGES_RESPONSE_TEMPLATE) + challenges_response["challenges"][0]["challengeType"] = "CHALLENGE_TYPE_UNSPECIFIED" + with pytest.raises(exceptions.ReauthFailError) as excinfo: + reauth._run_next_challenge(challenges_response, MOCK_REQUEST, "token") + assert excinfo.match(r"Unsupported challenge type CHALLENGE_TYPE_UNSPECIFIED") + + +def test__run_next_challenge_not_locally_eligible(): + mock_challenge = MockChallenge("PASSWORD", False, "challenge_input") + with mock.patch( + "google.oauth2.challenges.AVAILABLE_CHALLENGES", {"PASSWORD": mock_challenge} + ): + with pytest.raises(exceptions.ReauthFailError) as excinfo: + reauth._run_next_challenge( + CHALLENGES_RESPONSE_TEMPLATE, MOCK_REQUEST, "token" + ) + assert excinfo.match(r"Challenge PASSWORD is not locally eligible") + + +def test__run_next_challenge_no_challenge_input(): + mock_challenge = MockChallenge("PASSWORD", True, None) + with mock.patch( + "google.oauth2.challenges.AVAILABLE_CHALLENGES", {"PASSWORD": mock_challenge} + ): + assert ( + reauth._run_next_challenge( + CHALLENGES_RESPONSE_TEMPLATE, MOCK_REQUEST, "token" + ) + is None + ) + + +def test__run_next_challenge_success(): + mock_challenge = MockChallenge("PASSWORD", True, {"credential": "password"}) + with mock.patch( + "google.oauth2.challenges.AVAILABLE_CHALLENGES", {"PASSWORD": mock_challenge} + ): + with mock.patch( + "google.oauth2.reauth._send_challenge_result" + ) as mock_send_challenge_result: + reauth._run_next_challenge( + CHALLENGES_RESPONSE_TEMPLATE, MOCK_REQUEST, "token" + ) + mock_send_challenge_result.assert_called_with( + MOCK_REQUEST, "123", 1, {"credential": "password"}, "token" + ) + + +def test__obtain_rapt_not_authenticated(): + with pytest.raises(exceptions.ReauthFailError) as excinfo: + reauth._obtain_rapt(MOCK_REQUEST, "token", None, rounds_num=0) + assert excinfo.match(r"Reauthentication failed. None") + + +def test__obtain_rapt_authenticated(): + with mock.patch( + "google.oauth2.reauth._get_challenges", + return_value=CHALLENGES_RESPONSE_AUTHENTICATED, + ): + assert ( + reauth._obtain_rapt(MOCK_REQUEST, "token", None, rounds_num=1) + == "new_rapt_token" + ) + + +def test__obtain_rapt_unsupported_status(): + challenges_response = copy.deepcopy(CHALLENGES_RESPONSE_TEMPLATE) + challenges_response["status"] = "STATUS_UNSPECIFIED" + with mock.patch( + "google.oauth2.reauth._get_challenges", return_value=challenges_response + ): + with pytest.raises(exceptions.ReauthFailError) as excinfo: + reauth._obtain_rapt(MOCK_REQUEST, "token", None, rounds_num=1) + assert excinfo.match(r"API error: STATUS_UNSPECIFIED") + + +def test__obtain_rapt_not_interactive(): + with mock.patch( + "google.oauth2.reauth._get_challenges", + return_value=CHALLENGES_RESPONSE_TEMPLATE, + ): + with mock.patch("google.auth._helpers.is_interactive", return_value=False): + with pytest.raises(exceptions.ReauthFailError) as excinfo: + reauth._obtain_rapt(MOCK_REQUEST, "token", None, rounds_num=1) + assert excinfo.match(r"not in an interactive session") + + +def test__obtain_rapt_run_next_challenge(): + with mock.patch( + "google.oauth2.reauth._get_challenges", + return_value=CHALLENGES_RESPONSE_TEMPLATE, + ): + with mock.patch( + "google.oauth2.reauth._run_next_challenge", + return_value=CHALLENGES_RESPONSE_AUTHENTICATED, + ): + with mock.patch("google.auth._helpers.is_interactive", return_value=True): + assert ( + reauth._obtain_rapt(MOCK_REQUEST, "token", None, rounds_num=2) + == "new_rapt_token" + ) + + +def test_get_rapt_token(): + with mock.patch( + "google.oauth2._client.refresh_grant", return_value=("token", None, None, None) + ) as mock_refresh_grant: + with mock.patch( + "google.oauth2.reauth._obtain_rapt", return_value="new_rapt_token" + ) as mock_obtain_rapt: + assert ( + reauth.get_rapt_token( + MOCK_REQUEST, + "client_id", + "client_secret", + "refresh_token", + "token_uri", + ) + == "new_rapt_token" + ) + mock_refresh_grant.assert_called_with( + request=MOCK_REQUEST, + client_id="client_id", + client_secret="client_secret", + refresh_token="refresh_token", + token_uri="token_uri", + scopes=[reauth._REAUTH_SCOPE], + ) + mock_obtain_rapt.assert_called_with( + MOCK_REQUEST, "token", requested_scopes=None + ) + + +def test_refresh_grant_failed(): + with mock.patch( + "google.oauth2._client._token_endpoint_request_no_throw" + ) as mock_token_request: + mock_token_request.return_value = (False, {"error": "Bad request"}) + with pytest.raises(exceptions.RefreshError) as excinfo: + reauth.refresh_grant( + MOCK_REQUEST, + "token_uri", + "refresh_token", + "client_id", + "client_secret", + scopes=["foo", "bar"], + rapt_token="rapt_token", + ) + assert excinfo.match(r"Bad request") + mock_token_request.assert_called_with( + MOCK_REQUEST, + "token_uri", + { + "grant_type": "refresh_token", + "client_id": "client_id", + "client_secret": "client_secret", + "refresh_token": "refresh_token", + "scope": "foo bar", + "rapt": "rapt_token", + }, + ) + + +def test_refresh_grant_success(): + with mock.patch( + "google.oauth2._client._token_endpoint_request_no_throw" + ) as mock_token_request: + mock_token_request.side_effect = [ + (False, {"error": "invalid_grant", "error_subtype": "rapt_required"}), + (True, {"access_token": "access_token"}), + ] + with mock.patch( + "google.oauth2.reauth.get_rapt_token", return_value="new_rapt_token" + ): + assert reauth.refresh_grant( + MOCK_REQUEST, "token_uri", "refresh_token", "client_id", "client_secret" + ) == ( + "access_token", + "refresh_token", + None, + {"access_token": "access_token"}, + "new_rapt_token", + ) diff --git a/tests/test__helpers.py b/tests/test__helpers.py index 0c0bad2d2..69f034e62 100644 --- a/tests/test__helpers.py +++ b/tests/test__helpers.py @@ -14,6 +14,7 @@ import datetime +import mock import pytest from six.moves import urllib @@ -168,3 +169,13 @@ def test_unpadded_urlsafe_b64encode(): for case, expected in cases: assert _helpers.unpadded_urlsafe_b64encode(case) == expected + + +def test_get_user_password(): + with mock.patch("getpass.getpass", return_value="foo"): + assert _helpers.get_user_password("") == "foo" + + +def test_is_interactive(): + with mock.patch("sys.stdin.isatty", return_value=True): + assert _helpers.is_interactive()