Skip to content

Commit

Permalink
feat: add reauth support to oauth2 credentials
Browse files Browse the repository at this point in the history
  • Loading branch information
arithmetic1728 committed Apr 7, 2021
1 parent 48e8be3 commit 5a5adfe
Show file tree
Hide file tree
Showing 13 changed files with 1,192 additions and 48 deletions.
34 changes: 34 additions & 0 deletions google/auth/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import base64
import calendar
import datetime
import getpass
import sys

import six
from six.moves import urllib
Expand Down Expand Up @@ -230,3 +232,35 @@ def unpadded_urlsafe_b64encode(value):
Union[str|bytes]: The encoded value
"""
return base64.urlsafe_b64encode(value).rstrip(b"=")


def get_user_password(text):
"""Get password from user.
Override this function with a different logic if you are using this library
outside a CLI.
Args:
text (str): message for the password prompt.
Returns:
str: password string.
"""
return getpass.getpass(text)


def is_interactive():
"""Check if we are in an interractive environment.
Override this function with a different logic if you are using this library
outside a CLI.
If the rapt token needs refreshing, the user needs to answer the challenges.
If the user is not in an interractive environment, the challenges can not
be answered and we just wait for timeout for no reason.
Returns:
bool: True if is interactive environment, False otherwise.
"""

return sys.stdin.isatty()
9 changes: 9 additions & 0 deletions google/auth/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,12 @@ class ClientCertError(GoogleAuthError):
class OAuthError(GoogleAuthError):
"""Used to indicate an error occurred during an OAuth related HTTP
request."""


class ReauthFailError(RefreshError):
"""An exception for when reauth failed."""

def __init__(self, message=None):
super(ReauthFailError, self).__init__(
"Reauthentication failed. {0}".format(message)
)
135 changes: 103 additions & 32 deletions google/oauth2/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,25 +39,24 @@
_REFRESH_GRANT_TYPE = "refresh_token"


def _handle_error_response(response_body):
""""Translates an error response into an exception.
def _handle_error_response(response_data):
"""Translates an error response into an exception.
Args:
response_body (str): The decoded response data.
response_data (Mapping): The decoded response data.
Raises:
google.auth.exceptions.RefreshError
google.auth.exceptions.RefreshError: The errors contained in response_data.
"""
try:
error_data = json.loads(response_body)
error_details = "{}: {}".format(
error_data["error"], error_data.get("error_description")
response_data["error"], response_data.get("error_description")
)
# If no details could be extracted, use the response data.
except (KeyError, ValueError):
error_details = response_body
error_details = json.dumps(response_data)

raise exceptions.RefreshError(error_details, response_body)
raise exceptions.RefreshError(error_details, response_data)


def _parse_expiry(response_data):
Expand All @@ -78,31 +77,45 @@ def _parse_expiry(response_data):
return None


def _token_endpoint_request(request, token_uri, body):
def _token_endpoint_request_no_throw(
request, token_uri, body, access_token=None, use_json=False
):
"""Makes a request to the OAuth 2.0 authorization server's token endpoint.
This function doesn't throw on response errors.
Args:
request (google.auth.transport.Request): A callable used to make
HTTP requests.
token_uri (str): The OAuth 2.0 authorizations server's token endpoint
URI.
body (Mapping[str, str]): The parameters to send in the request body.
access_token (Optional(str)): The access token needed to make the request.
use_json (Optional(bool)): Use urlencoded format or json format for the
content type. The default value is False.
Returns:
Mapping[str, str]: The JSON-decoded response data.
Raises:
google.auth.exceptions.RefreshError: If the token endpoint returned
an error.
Tuple(bool, Mapping[str, str]): A boolean indicating if the request is
successful, and a mapping for the JSON-decoded response data.
"""
body = urllib.parse.urlencode(body).encode("utf-8")
headers = {"content-type": _URLENCODED_CONTENT_TYPE}
if use_json:
headers = {"Content-Type": "application/json"}
json_body = body
body = None
else:
headers = {"content-type": _URLENCODED_CONTENT_TYPE}
json_body = None
body = urllib.parse.urlencode(body).encode("utf-8")

if access_token:
headers["Authorization"] = "Bearer {}".format(access_token)

retry = 0
# retry to fetch token for maximum of two times if any internal failure
# occurs.
while True:
response = request(method="POST", url=token_uri, headers=headers, body=body)
response = request(
method="POST", url=token_uri, headers=headers, body=body, json=json_body
)
response_body = (
response.data.decode("utf-8")
if hasattr(response.data, "decode")
Expand All @@ -121,8 +134,38 @@ def _token_endpoint_request(request, token_uri, body):
):
retry += 1
continue
_handle_error_response(response_body)
return response.status == http_client.OK, response_data

return response.status == http_client.OK, response_data


def _token_endpoint_request(
request, token_uri, body, access_token=None, use_json=False
):
"""Makes a request to the OAuth 2.0 authorization server's token endpoint.
Args:
request (google.auth.transport.Request): A callable used to make
HTTP requests.
token_uri (str): The OAuth 2.0 authorizations server's token endpoint
URI.
body (Mapping[str, str]): The parameters to send in the request body.
access_token (Optional(str)): The access token needed to make the request.
use_json (Optional(bool)): Use urlencoded format or json format for the
content type. The default value is False.
Returns:
Mapping[str, str]: The JSON-decoded response data.
Raises:
google.auth.exceptions.RefreshError: If the token endpoint returned
an error.
"""
response_status_ok, response_data = _token_endpoint_request_no_throw(
request, token_uri, body, access_token=access_token, use_json=use_json
)
if not response_status_ok:
_handle_error_response(response_data)
return response_data


Expand Down Expand Up @@ -204,8 +247,43 @@ def id_token_jwt_grant(request, token_uri, assertion):
return id_token, expiry, response_data


def _handle_refresh_grant_response(response_data, refresh_token):
"""Extract tokens from refresh grant response.
Args:
response_data (Mapping[str, str]): Refresh grant response data.
refresh_token (str): Current refresh token.
Returns:
Tuple[str, str, Optional[datetime], Mapping[str, str]]: The access token,
refresh token, expiration, and additional data returned by the token
endpoint. If response_data doesn't have refresh token, then the current
refresh token will be returned.
Raises:
google.auth.exceptions.RefreshError: If the token endpoint returned
an error.
"""
try:
access_token = response_data["access_token"]
except KeyError as caught_exc:
new_exc = exceptions.RefreshError("No access token in response.", response_data)
six.raise_from(new_exc, caught_exc)

refresh_token = response_data.get("refresh_token", refresh_token)
expiry = _parse_expiry(response_data)

return access_token, refresh_token, expiry, response_data


def refresh_grant(
request, token_uri, refresh_token, client_id, client_secret, scopes=None
request,
token_uri,
refresh_token,
client_id,
client_secret,
scopes=None,
rapt_token=None,
):
"""Implements the OAuth 2.0 refresh token grant.
Expand All @@ -224,10 +302,11 @@ def refresh_grant(
scopes must be authorized for the refresh token. Useful if refresh
token has a wild card scope (e.g.
'https://www.googleapis.com/auth/any-api').
rapt_token (Optional(str)): The reauth Proof Token.
Returns:
Tuple[str, Optional[str], Optional[datetime], Mapping[str, str]]: The
access token, new refresh token, expiration, and additional data
Tuple[str, str, Optional[datetime], Mapping[str, str]]: The access
token, new or current refresh token, expiration, and additional data
returned by the token endpoint.
Raises:
Expand All @@ -244,16 +323,8 @@ def refresh_grant(
}
if scopes:
body["scope"] = " ".join(scopes)
if rapt_token:
body["rapt"] = rapt_token

response_data = _token_endpoint_request(request, token_uri, body)

try:
access_token = response_data["access_token"]
except KeyError as caught_exc:
new_exc = exceptions.RefreshError("No access token in response.", response_data)
six.raise_from(new_exc, caught_exc)

refresh_token = response_data.get("refresh_token", refresh_token)
expiry = _parse_expiry(response_data)

return access_token, refresh_token, expiry, response_data
return _handle_refresh_grant_response(response_data, refresh_token)
Loading

0 comments on commit 5a5adfe

Please sign in to comment.