Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add experimental GDCH support #1022

Merged
merged 5 commits into from
May 10, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions google/auth/_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@
_SERVICE_ACCOUNT_TYPE = "service_account"
_EXTERNAL_ACCOUNT_TYPE = "external_account"
_IMPERSONATED_SERVICE_ACCOUNT_TYPE = "impersonated_service_account"
_GDCH_SERVICE_ACCOUNT_TYPE = "gdch_service_account"
_VALID_TYPES = (
_AUTHORIZED_USER_TYPE,
_SERVICE_ACCOUNT_TYPE,
_EXTERNAL_ACCOUNT_TYPE,
_IMPERSONATED_SERVICE_ACCOUNT_TYPE,
_GDCH_SERVICE_ACCOUNT_TYPE,
)

# Help message when no credentials can be found.
Expand Down Expand Up @@ -158,6 +160,8 @@ def _load_credentials_from_info(
credentials, project_id = _get_impersonated_service_account_credentials(
filename, info, scopes
)
elif credential_type == _GDCH_SERVICE_ACCOUNT_TYPE:
credentials, project_id = _get_gdch_service_account_credentials(info)
else:
raise exceptions.DefaultCredentialsError(
"The file {file} does not have a valid type. "
Expand Down Expand Up @@ -421,6 +425,36 @@ def _get_impersonated_service_account_credentials(filename, info, scopes):
return credentials, None


def _get_gdch_service_account_credentials(info):
from google.oauth2 import gdch_credentials

k8s_ca_cert_path = info.get("k8s_ca_cert_path")
k8s_cert_path = info.get("k8s_cert_path")
k8s_key_path = info.get("k8s_key_path")
k8s_token_endpoint = info.get("k8s_token_endpoint")
ais_ca_cert_path = info.get("ais_ca_cert_path")
ais_token_endpoint = info.get("ais_token_endpoint")

format_version = info.get("format_version")
if format_version != "v1":
raise exceptions.DefaultCredentialsError(
"format_version is not provided or unsupported. Supported version is: v1"
)

return (
gdch_credentials.ServiceAccountCredentials(
k8s_ca_cert_path,
k8s_cert_path,
k8s_key_path,
k8s_token_endpoint,
ais_ca_cert_path,
ais_token_endpoint,
None,
),
None,
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved
)


def _apply_quota_project_id(credentials, quota_project_id):
if quota_project_id:
credentials = credentials.with_quota_project(quota_project_id)
Expand Down Expand Up @@ -456,6 +490,11 @@ def default(scopes=None, request=None, quota_project_id=None, default_scopes=Non
endpoint.
The project ID returned in this case is the one corresponding to the
underlying workload identity pool resource if determinable.

If the environment variable is set to the path of a valid GDCH service
account JSON file (`Google Distributed Cloud Hosted`_), then a GDCH
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved
credential will be returned. The project ID returned is None unless it
is set via `GOOGLE_CLOUD_PROJECT` environment variable.
2. If the `Google Cloud SDK`_ is installed and has application default
credentials set they are loaded and returned.

Expand Down Expand Up @@ -490,6 +529,8 @@ def default(scopes=None, request=None, quota_project_id=None, default_scopes=Non
.. _Metadata Service: https://cloud.google.com/compute/docs\
/storing-retrieving-metadata
.. _Cloud Run: https://cloud.google.com/run
.. _Google Distributed Cloud Hosted: https://cloud.google.com/blog/topics\
/hybrid-cloud/announcing-google-distributed-cloud-edge-and-hosted

Example::

Expand Down
72 changes: 54 additions & 18 deletions google/oauth2/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,13 @@ def _handle_error_response(response_data):
"""Translates an error response into an exception.

Args:
response_data (Mapping): The decoded response data.
response_data (Mapping | str): The decoded response data.

Raises:
google.auth.exceptions.RefreshError: The errors contained in response_data.
"""
if isinstance(response_data, six.string_types):
raise exceptions.RefreshError(response_data)
try:
error_details = "{}: {}".format(
response_data["error"], response_data.get("error_description")
Expand Down Expand Up @@ -79,7 +81,13 @@ def _parse_expiry(response_data):


def _token_endpoint_request_no_throw(
request, token_uri, body, access_token=None, use_json=False
request,
token_uri,
body,
access_token=None,
use_json=False,
expected_status_code=http_client.OK,
**kwargs
):
"""Makes a request to the OAuth 2.0 authorization server's token endpoint.
This function doesn't throw on response errors.
Expand All @@ -93,6 +101,10 @@ def _token_endpoint_request_no_throw(
access_token (Optional(str)): The access token needed to make the request.
use_json (Optional(bool)): Use urlencoded format or json format for the
content type. The default value is False.
expected_status_code (Optional(int)): The expected the status code of
the token response. The default value is 200. We may expect other
status code like 201 for GDCH credentials.
kwargs: Additional arguments passed on to the request method.

Returns:
Tuple(bool, Mapping[str, str]): A boolean indicating if the request is
Expand All @@ -112,32 +124,46 @@ def _token_endpoint_request_no_throw(
# retry to fetch token for maximum of two times if any internal failure
# occurs.
while True:
response = request(method="POST", url=token_uri, headers=headers, body=body)
response = request(
method="POST", url=token_uri, headers=headers, body=body, **kwargs
)
response_body = (
response.data.decode("utf-8")
if hasattr(response.data, "decode")
else response.data
)
response_data = json.loads(response_body)

if response.status == http_client.OK:
if response.status == expected_status_code:
# response_body should be a JSON
response_data = json.loads(response_body)
break
else:
error_desc = response_data.get("error_description") or ""
error_code = response_data.get("error") or ""
if (
any(e == "internal_failure" for e in (error_code, error_desc))
and retry < 1
):
retry += 1
continue
return response.status == http_client.OK, response_data

return response.status == http_client.OK, response_data
# For a failed response, response_body could be a string
try:
response_data = json.loads(response_body)
error_desc = response_data.get("error_description") or ""
error_code = response_data.get("error") or ""
if (
any(e == "internal_failure" for e in (error_code, error_desc))
and retry < 1
):
retry += 1
continue
except ValueError:
response_data = response_body
return response.status == expected_status_code, response_data
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved

return response.status == expected_status_code, response_data


def _token_endpoint_request(
request, token_uri, body, access_token=None, use_json=False
request,
token_uri,
body,
access_token=None,
use_json=False,
expected_status_code=http_client.OK,
**kwargs
):
"""Makes a request to the OAuth 2.0 authorization server's token endpoint.

Expand All @@ -150,6 +176,10 @@ def _token_endpoint_request(
access_token (Optional(str)): The access token needed to make the request.
use_json (Optional(bool)): Use urlencoded format or json format for the
content type. The default value is False.
expected_status_code (Optional(int)): The expected the status code of
the token response. The default value is 200. We may expect other
status code like 201 for GDCH credentials.
kwargs: Additional arguments passed on to the request method.

Returns:
Mapping[str, str]: The JSON-decoded response data.
Expand All @@ -159,7 +189,13 @@ def _token_endpoint_request(
an error.
"""
response_status_ok, response_data = _token_endpoint_request_no_throw(
request, token_uri, body, access_token=access_token, use_json=use_json
request,
token_uri,
body,
access_token=access_token,
use_json=use_json,
expected_status_code=expected_status_code,
**kwargs
)
if not response_status_ok:
_handle_error_response(response_data)
Expand Down
Loading