Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add mTLS to ads template #384

Merged
merged 3 commits into from
Apr 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

{% block content %}
from collections import OrderedDict
from typing import Dict, {% if service.any_server_streaming %}Iterable, {% endif %}{% if service.any_client_streaming %}Iterator, {% endif %}Sequence, Tuple, Type, Union
import re
from typing import Callable, Dict, {% if service.any_server_streaming %}Iterable, {% endif %}{% if service.any_client_streaming %}Iterator, {% endif %}Sequence, Tuple, Type, Union
import pkg_resources

import google.api_core.client_options as ClientOptions # type: ignore
Expand Down Expand Up @@ -57,7 +58,39 @@ class {{ service.client_name }}Meta(type):
class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
"""{{ service.meta.doc|rst(width=72, indent=4) }}"""

DEFAULT_OPTIONS = ClientOptions.ClientOptions({% if service.host %}api_endpoint='{{ service.host }}'{% endif %})
@staticmethod
def _get_default_mtls_endpoint(api_endpoint):
"""Convert api endpoint to mTLS endpoint.
Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
"*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
Args:
api_endpoint (Optional[str]): the api endpoint to convert.
Returns:
str: converted mTLS api endpoint.
"""
if not api_endpoint:
return api_endpoint

mtls_endpoint_re = re.compile(
r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
)

m = mtls_endpoint_re.match(api_endpoint)
name, mtls, sandbox, googledomain = m.groups()
if mtls or not googledomain:
return api_endpoint

if sandbox:
return api_endpoint.replace(
"sandbox.googleapis.com", "mtls.sandbox.googleapis.com"
)

return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")

DEFAULT_ENDPOINT = {% if service.host %}'{{ service.host }}'{% else %}None{% endif %}
DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore
DEFAULT_ENDPOINT
)

@classmethod
def from_service_account_file(cls, filename: str, *args, **kwargs):
Expand Down Expand Up @@ -92,7 +125,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
def __init__(self, *,
credentials: credentials.Credentials = None,
transport: Union[str, {{ service.name }}Transport] = None,
client_options: ClientOptions = DEFAULT_OPTIONS,
client_options: ClientOptions = None,
) -> None:
"""Instantiate the {{ (service.client_name|snake_case).replace('_', ' ') }}.

Expand All @@ -106,6 +139,17 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
transport to use. If set to None, a transport is chosen
automatically.
client_options (ClientOptions): Custom options for the client.
(1) The ``api_endpoint`` property can be used to override the
default endpoint provided by the client.
(2) If ``transport`` argument is None, ``client_options`` can be
used to create a mutual TLS transport. If ``client_cert_source``
is provided, mutual TLS transport will be created with the given
``api_endpoint`` or the default mTLS endpoint, and the client
SSL credentials obtained from ``client_cert_source``.

Raises:
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
creation failed for any reason.
"""
if isinstance(client_options, dict):
client_options = ClientOptions.from_dict(client_options)
Expand All @@ -114,16 +158,45 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
# Ordinarily, we provide the transport, but allowing a custom transport
# instance provides an extensibility point for unusual situations.
if isinstance(transport, {{ service.name }}Transport):
# transport is a {{ service.name }}Transport instance.
if credentials:
raise ValueError('When providing a transport instance, '
'provide its credentials directly.')
self._transport = transport
else:
elif client_options is None or (
client_options.api_endpoint == None
and client_options.client_cert_source is None
):
# Don't trigger mTLS if we get an empty ClientOptions.
Transport = type(self).get_transport_class(transport)
self._transport = Transport(
credentials=credentials,
host=client_options.api_endpoint{% if service.host %} or '{{ service.host }}'{% endif %},
credentials=credentials, host=self.DEFAULT_ENDPOINT
)
else:
# We have a non-empty ClientOptions. If client_cert_source is
# provided, trigger mTLS with user provided endpoint or the default
# mTLS endpoint.
if client_options.client_cert_source:
api_mtls_endpoint = (
client_options.api_endpoint
if client_options.api_endpoint
else self.DEFAULT_MTLS_ENDPOINT
)
else:
api_mtls_endpoint = None

api_endpoint = (
client_options.api_endpoint
if client_options.api_endpoint
else self.DEFAULT_ENDPOINT
)

self._transport = {{ service.name }}GrpcTransport(
credentials=credentials,
host=api_endpoint,
api_mtls_endpoint=api_mtls_endpoint,
client_cert_source=client_options.client_cert_source,
)

{% for method in service.methods.values() -%}
def {{ method.name|snake_case }}(self,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
{% extends '_base.py.j2' %}

{% block content %}
from typing import Callable, Dict
from typing import Callable, Dict, Tuple

from google.api_core import grpc_helpers # type: ignore
{%- if service.has_lro %}
from google.api_core import operations_v1 # type: ignore
{%- endif %}
from google.auth import credentials # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore


import grpc # type: ignore

Expand Down Expand Up @@ -35,7 +37,9 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
def __init__(self, *,
host: str{% if service.host %} = '{{ service.host }}'{% endif %},
credentials: credentials.Credentials = None,
channel: grpc.Channel = None) -> None:
channel: grpc.Channel = None,
api_mtls_endpoint: str = None,
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None) -> None:
"""Instantiate the transport.

Args:
Expand All @@ -49,19 +53,51 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
This argument is ignored if ``channel`` is provided.
channel (Optional[grpc.Channel]): A ``Channel`` instance through
which to make calls.
api_mtls_endpoint (Optional[str]): The mutual TLS endpoint. If
provided, it overrides the ``host`` argument and tries to create
a mutual TLS channel with client SSL credentials from
``client_cert_source`` or applicatin default SSL credentials.
client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): A
callback to provide client SSL certificate bytes and private key
bytes, both in PEM format. It is ignored if ``api_mtls_endpoint``
is None.

Raises:
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
creation failed for any reason.
"""
# Sanity check: Ensure that channel and credentials are not both
# provided.
if channel:
# Sanity check: Ensure that channel and credentials are not both
# provided.
credentials = False

# If a channel was explicitly provided, set it.
self._grpc_channel = channel
elif api_mtls_endpoint:
host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"

# Create SSL credentials with client_cert_source or application
# default SSL credentials.
if client_cert_source:
cert, key = client_cert_source()
ssl_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)
else:
ssl_credentials = SslCredentials().ssl_credentials

# create a new channel. The provided one is ignored.
self._grpc_channel = grpc_helpers.create_channel(
host,
credentials=credentials,
ssl_credentials=ssl_credentials,
scopes=self.AUTH_SCOPES,
)

# Run the base constructor.
super().__init__(host=host, credentials=credentials)
self._stubs = {} # type: Dict[str, Callable]
self._stubs = {} # type: Dict[str, Callable]

# If a channel was explicitly provided, set it.
if channel:
self._grpc_channel = channel

@classmethod
def create_channel(cls,
Expand Down
Loading