Skip to content
This repository has been archived by the owner on Nov 29, 2023. It is now read-only.

feat: add always_use_jwt_access #167

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
branch = True

[report]
fail_under = 100
show_missing = True
omit =
google/cloud/bigquery_datatransfer/__init__.py
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,9 @@ def schedule_transfer_runs(
for a time range.

"""
warnings.warn(
"schedule_transfer_runs is deprecated", warnings.DeprecationWarning
)
# Create or coerce a protobuf request object.
# Sanity check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from google.api_core import gapic_v1 # type: ignore
from google.api_core import retry as retries # type: ignore
from google.auth import credentials as ga_credentials # type: ignore
from google.oauth2 import service_account # type: ignore

from google.cloud.bigquery_datatransfer_v1.types import datatransfer
from google.cloud.bigquery_datatransfer_v1.types import transfer
Expand All @@ -47,8 +48,6 @@
except pkg_resources.DistributionNotFound: # pragma: NO COVER
_GOOGLE_AUTH_VERSION = None

_API_CORE_VERSION = google.api_core.__version__


class DataTransferServiceTransport(abc.ABC):
"""Abstract transport class for DataTransferService."""
Expand All @@ -66,6 +65,7 @@ def __init__(
scopes: Optional[Sequence[str]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
always_use_jwt_access: Optional[bool] = False,
**kwargs,
) -> None:
"""Instantiate the transport.
Expand All @@ -89,6 +89,8 @@ def __init__(
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.
always_use_jwt_access (Optional[bool]): Whether self signed JWT should
be used for service account credentials.
"""
# Save the hostname. Default to port 443 (HTTPS) if none is specified.
if ":" not in host:
Expand Down Expand Up @@ -117,13 +119,20 @@ def __init__(
**scopes_kwargs, quota_project_id=quota_project_id
)

# If the credentials is service account credentials, then always try to use self signed JWT.
if (
always_use_jwt_access
and isinstance(credentials, service_account.Credentials)
and hasattr(service_account.Credentials, "with_always_use_jwt_access")
):
credentials = credentials.with_always_use_jwt_access(True)

# Save the credentials.
self._credentials = credentials

# TODO(busunkim): These two class methods are in the base transport
# TODO(busunkim): This method is in the base transport
# to avoid duplicating code across the transport classes. These functions
# should be deleted once the minimum required versions of google-api-core
# and google-auth are increased.
# should be deleted once the minimum required versions of google-auth is increased.

# TODO: Remove this function once google-auth >= 1.25.0 is required
@classmethod
Expand All @@ -144,27 +153,6 @@ def _get_scopes_kwargs(

return scopes_kwargs

# TODO: Remove this function once google-api-core >= 1.26.0 is required
@classmethod
def _get_self_signed_jwt_kwargs(
cls, host: str, scopes: Optional[Sequence[str]]
) -> Dict[str, Union[Optional[Sequence[str]], str]]:
"""Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version"""

self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {}

if _API_CORE_VERSION and (
packaging.version.parse(_API_CORE_VERSION)
>= packaging.version.parse("1.26.0")
):
self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES
self_signed_jwt_kwargs["scopes"] = scopes
self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST
else:
self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES

return self_signed_jwt_kwargs

def _prep_wrapped_messages(self, client_info):
# Precompute the wrapped methods.
self._wrapped_methods = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def __init__(
scopes=scopes,
quota_project_id=quota_project_id,
client_info=client_info,
always_use_jwt_access=True,
)

if not self._grpc_channel:
Expand Down Expand Up @@ -209,14 +210,14 @@ def create_channel(
and ``credentials_file`` are passed.
"""

self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes)

return grpc_helpers.create_channel(
host,
credentials=credentials,
credentials_file=credentials_file,
quota_project_id=quota_project_id,
**self_signed_jwt_kwargs,
default_scopes=cls.AUTH_SCOPES,
scopes=scopes,
default_host=cls.DEFAULT_HOST,
**kwargs,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,14 @@ def create_channel(
aio.Channel: A gRPC AsyncIO channel object.
"""

self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes)

return grpc_helpers_async.create_channel(
host,
credentials=credentials,
credentials_file=credentials_file,
quota_project_id=quota_project_id,
**self_signed_jwt_kwargs,
default_scopes=cls.AUTH_SCOPES,
scopes=scopes,
default_host=cls.DEFAULT_HOST,
**kwargs,
)

Expand Down Expand Up @@ -200,6 +200,7 @@ def __init__(
scopes=scopes,
quota_project_id=quota_project_id,
client_info=client_info,
always_use_jwt_access=True,
)

if not self._grpc_channel:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
# 'Development Status :: 5 - Production/Stable'
release_status = "Development Status :: 5 - Production/Stable"
dependencies = (
"google-api-core[grpc] >= 1.22.2, < 2.0.0dev",
"google-api-core[grpc] >= 1.26.0, <2.0.0dev",
"proto-plus >= 1.15.0",
"packaging >= 14.3",
)
Expand Down
2 changes: 1 addition & 1 deletion testing/constraints-3.6.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#
# e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev",
# Then this file should have foo==1.14.0
google-api-core==1.22.2
google-api-core==1.26.0
proto-plus==1.15.0
libcst==0.2.5
packaging==14.3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@
from google.cloud.bigquery_datatransfer_v1.services.data_transfer_service import (
transports,
)
from google.cloud.bigquery_datatransfer_v1.services.data_transfer_service.transports.base import (
_API_CORE_VERSION,
)
from google.cloud.bigquery_datatransfer_v1.services.data_transfer_service.transports.base import (
_GOOGLE_AUTH_VERSION,
)
Expand All @@ -58,8 +55,9 @@
import google.auth


# TODO(busunkim): Once google-api-core >= 1.26.0 is required:
# - Delete all the api-core and auth "less than" test cases
# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively
# through google-api-core:
# - Delete the auth "less than" test cases
# - Delete these pytest markers (Make the "greater than or equal to" tests the default).
requires_google_auth_lt_1_25_0 = pytest.mark.skipif(
packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"),
Expand All @@ -70,16 +68,6 @@
reason="This test requires google-auth >= 1.25.0",
)

requires_api_core_lt_1_26_0 = pytest.mark.skipif(
packaging.version.parse(_API_CORE_VERSION) >= packaging.version.parse("1.26.0"),
reason="This test requires google-api-core < 1.26.0",
)

requires_api_core_gte_1_26_0 = pytest.mark.skipif(
packaging.version.parse(_API_CORE_VERSION) < packaging.version.parse("1.26.0"),
reason="This test requires google-api-core >= 1.26.0",
)


def client_cert_source_callback():
return b"cert bytes", b"key bytes"
Expand Down Expand Up @@ -143,6 +131,18 @@ def test_data_transfer_service_client_from_service_account_info(client_class):
assert client.transport._host == "bigquerydatatransfer.googleapis.com:443"


@pytest.mark.parametrize(
"client_class", [DataTransferServiceClient, DataTransferServiceAsyncClient,]
)
def test_data_transfer_service_client_service_account_always_use_jwt(client_class):
with mock.patch.object(
service_account.Credentials, "with_always_use_jwt_access", create=True
) as use_jwt:
creds = service_account.Credentials(None, None, None)
client = client_class(credentials=creds)
use_jwt.assert_called_with(True)


@pytest.mark.parametrize(
"client_class", [DataTransferServiceClient, DataTransferServiceAsyncClient,]
)
Expand Down Expand Up @@ -2776,8 +2776,12 @@ def test_schedule_transfer_runs_flattened():
assert len(call.mock_calls) == 1
_, args, _ = call.mock_calls[0]
assert args[0].parent == "parent_value"
# assert args[0].start_time == timestamp_pb2.Timestamp(seconds=751)
# assert args[0].end_time == timestamp_pb2.Timestamp(seconds=751)
assert TimestampRule().to_proto(args[0].start_time) == timestamp_pb2.Timestamp(
seconds=751
)
assert TimestampRule().to_proto(args[0].end_time) == timestamp_pb2.Timestamp(
seconds=751
)


def test_schedule_transfer_runs_flattened_error():
Expand Down Expand Up @@ -2825,8 +2829,12 @@ async def test_schedule_transfer_runs_flattened_async():
assert len(call.mock_calls)
_, args, _ = call.mock_calls[0]
assert args[0].parent == "parent_value"
# assert args[0].start_time == timestamp_pb2.Timestamp(seconds=751)
# assert args[0].end_time == timestamp_pb2.Timestamp(seconds=751)
assert TimestampRule().to_proto(args[0].start_time) == timestamp_pb2.Timestamp(
seconds=751
)
assert TimestampRule().to_proto(args[0].end_time) == timestamp_pb2.Timestamp(
seconds=751
)


@pytest.mark.asyncio
Expand Down Expand Up @@ -4684,7 +4692,6 @@ def test_data_transfer_service_transport_auth_adc_old_google_auth(transport_clas
(transports.DataTransferServiceGrpcAsyncIOTransport, grpc_helpers_async),
],
)
@requires_api_core_gte_1_26_0
def test_data_transfer_service_transport_create_channel(transport_class, grpc_helpers):
# If credentials and host are not provided, the transport class should use
# ADC credentials.
Expand Down Expand Up @@ -4713,79 +4720,6 @@ def test_data_transfer_service_transport_create_channel(transport_class, grpc_he
)


@pytest.mark.parametrize(
"transport_class,grpc_helpers",
[
(transports.DataTransferServiceGrpcTransport, grpc_helpers),
(transports.DataTransferServiceGrpcAsyncIOTransport, grpc_helpers_async),
],
)
@requires_api_core_lt_1_26_0
def test_data_transfer_service_transport_create_channel_old_api_core(
transport_class, grpc_helpers
):
# If credentials and host are not provided, the transport class should use
# ADC credentials.
with mock.patch.object(
google.auth, "default", autospec=True
) as adc, mock.patch.object(
grpc_helpers, "create_channel", autospec=True
) as create_channel:
creds = ga_credentials.AnonymousCredentials()
adc.return_value = (creds, None)
transport_class(quota_project_id="octopus")

create_channel.assert_called_with(
"bigquerydatatransfer.googleapis.com:443",
credentials=creds,
credentials_file=None,
quota_project_id="octopus",
scopes=("https://www.googleapis.com/auth/cloud-platform",),
ssl_credentials=None,
options=[
("grpc.max_send_message_length", -1),
("grpc.max_receive_message_length", -1),
],
)


@pytest.mark.parametrize(
"transport_class,grpc_helpers",
[
(transports.DataTransferServiceGrpcTransport, grpc_helpers),
(transports.DataTransferServiceGrpcAsyncIOTransport, grpc_helpers_async),
],
)
@requires_api_core_lt_1_26_0
def test_data_transfer_service_transport_create_channel_user_scopes(
transport_class, grpc_helpers
):
# If credentials and host are not provided, the transport class should use
# ADC credentials.
with mock.patch.object(
google.auth, "default", autospec=True
) as adc, mock.patch.object(
grpc_helpers, "create_channel", autospec=True
) as create_channel:
creds = ga_credentials.AnonymousCredentials()
adc.return_value = (creds, None)

transport_class(quota_project_id="octopus", scopes=["1", "2"])

create_channel.assert_called_with(
"bigquerydatatransfer.googleapis.com:443",
credentials=creds,
credentials_file=None,
quota_project_id="octopus",
scopes=["1", "2"],
ssl_credentials=None,
options=[
("grpc.max_send_message_length", -1),
("grpc.max_receive_message_length", -1),
],
)


@pytest.mark.parametrize(
"transport_class",
[
Expand Down