Skip to content

Commit

Permalink
feat: add SslCredentials class for mTLS ADC (#448)
Browse files Browse the repository at this point in the history
feat: add SslCredentials class for mTLS ADC (linux)
  • Loading branch information
arithmetic1728 authored Mar 19, 2020
1 parent a6d6329 commit dafb41f
Show file tree
Hide file tree
Showing 5 changed files with 723 additions and 56 deletions.
116 changes: 116 additions & 0 deletions google/auth/transport/_mtls_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper functions for getting mTLS cert and key, for internal use only."""

import json
import logging
from os import path
import re
import subprocess

CONTEXT_AWARE_METADATA_PATH = "~/.secureConnect/context_aware_metadata.json"
_CERT_PROVIDER_COMMAND = "cert_provider_command"
_CERT_REGEX = re.compile(
b"-----BEGIN CERTIFICATE-----.+-----END CERTIFICATE-----\r?\n?", re.DOTALL
)

# support various format of key files, e.g.
# "-----BEGIN PRIVATE KEY-----...",
# "-----BEGIN EC PRIVATE KEY-----...",
# "-----BEGIN RSA PRIVATE KEY-----..."
_KEY_REGEX = re.compile(
b"-----BEGIN [A-Z ]*PRIVATE KEY-----.+-----END [A-Z ]*PRIVATE KEY-----\r?\n?",
re.DOTALL,
)

_LOGGER = logging.getLogger(__name__)


def _check_dca_metadata_path(metadata_path):
"""Checks for context aware metadata. If it exists, returns the absolute path;
otherwise returns None.
Args:
metadata_path (str): context aware metadata path.
Returns:
str: absolute path if exists and None otherwise.
"""
metadata_path = path.expanduser(metadata_path)
if not path.exists(metadata_path):
_LOGGER.debug("%s is not found, skip client SSL authentication.", metadata_path)
return None
return metadata_path


def _read_dca_metadata_file(metadata_path):
"""Loads context aware metadata from the given path.
Args:
metadata_path (str): context aware metadata path.
Returns:
Dict[str, str]: The metadata.
Raises:
ValueError: If failed to parse metadata as JSON.
"""
with open(metadata_path) as f:
metadata = json.load(f)

return metadata


def get_client_ssl_credentials(metadata_json):
"""Returns the client side mTLS cert and key.
Args:
metadata_json (Dict[str, str]): metadata JSON file which contains the cert
provider command.
Returns:
Tuple[bytes, bytes]: client certificate and key, both in PEM format.
Raises:
OSError: If the cert provider command failed to run.
RuntimeError: If the cert provider command has a runtime error.
ValueError: If the metadata json file doesn't contain the cert provider command or if the command doesn't produce both the client certificate and client key.
"""
# TODO: implement an in-memory cache of cert and key so we don't have to
# run cert provider command every time.

# Check the cert provider command existence in the metadata json file.
if _CERT_PROVIDER_COMMAND not in metadata_json:
raise ValueError("Cert provider command is not found")

# Execute the command. It throws OsError in case of system failure.
command = metadata_json[_CERT_PROVIDER_COMMAND]
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = process.communicate()

# Check cert provider command execution error.
if process.returncode != 0:
raise RuntimeError(
"Cert provider command returns non-zero status code %s" % process.returncode
)

# Extract certificate (chain) and key.
cert_match = re.findall(_CERT_REGEX, stdout)
if len(cert_match) != 1:
raise ValueError("Client SSL certificate is missing or invalid")
key_match = re.findall(_KEY_REGEX, stdout)
if len(key_match) != 1:
raise ValueError("Client SSL key is missing or invalid")
return cert_match[0], key_match[0]
188 changes: 184 additions & 4 deletions google/auth/transport/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
from __future__ import absolute_import

from concurrent import futures
import logging

import six

from google.auth.transport import _mtls_helper

try:
import grpc
except ImportError as caught_exc: # pragma: NO COVER
Expand All @@ -31,6 +34,8 @@
caught_exc,
)

_LOGGER = logging.getLogger(__name__)


class AuthMetadataPlugin(grpc.AuthMetadataPlugin):
"""A `gRPC AuthMetadataPlugin`_ that inserts the credentials into each
Expand Down Expand Up @@ -92,7 +97,12 @@ def __del__(self):


def secure_authorized_channel(
credentials, request, target, ssl_credentials=None, **kwargs
credentials,
request,
target,
ssl_credentials=None,
client_cert_callback=None,
**kwargs
):
"""Creates a secure authorized gRPC channel.
Expand All @@ -114,11 +124,86 @@ def secure_authorized_channel(
# Create a channel.
channel = google.auth.transport.grpc.secure_authorized_channel(
credentials, 'speech.googleapis.com:443', request)
credentials, regular_endpoint, request,
ssl_credentials=grpc.ssl_channel_credentials())
# Use the channel to create a stub.
cloud_speech.create_Speech_stub(channel)
Usage:
There are actually a couple of options to create a channel, depending on if
you want to create a regular or mutual TLS channel.
First let's list the endpoints (regular vs mutual TLS) to choose from::
regular_endpoint = 'speech.googleapis.com:443'
mtls_endpoint = 'speech.mtls.googleapis.com:443'
Option 1: create a regular (non-mutual) TLS channel by explicitly setting
the ssl_credentials::
regular_ssl_credentials = grpc.ssl_channel_credentials()
channel = google.auth.transport.grpc.secure_authorized_channel(
credentials, regular_endpoint, request,
ssl_credentials=regular_ssl_credentials)
Option 2: create a mutual TLS channel by calling a callback which returns
the client side certificate and the key::
def my_client_cert_callback():
code_to_load_client_cert_and_key()
if loaded:
return (pem_cert_bytes, pem_key_bytes)
raise MyClientCertFailureException()
try:
channel = google.auth.transport.grpc.secure_authorized_channel(
credentials, mtls_endpoint, request,
client_cert_callback=my_client_cert_callback)
except MyClientCertFailureException:
# handle the exception
Option 3: use application default SSL credentials. It searches and uses
the command in a context aware metadata file, which is available on devices
with endpoint verification support.
See https://cloud.google.com/endpoint-verification/docs/overview::
try:
default_ssl_credentials = SslCredentials()
except:
# Exception can be raised if the context aware metadata is malformed.
# See :class:`SslCredentials` for the possible exceptions.
# Choose the endpoint based on the SSL credentials type.
if default_ssl_credentials.is_mtls:
endpoint_to_use = mtls_endpoint
else:
endpoint_to_use = regular_endpoint
channel = google.auth.transport.grpc.secure_authorized_channel(
credentials, endpoint_to_use, request,
ssl_credentials=default_ssl_credentials)
Option 4: not setting ssl_credentials and client_cert_callback. For devices
without endpoint verification support, a regular TLS channel is created;
otherwise, a mutual TLS channel is created, however, the call should be
wrapped in a try/except block in case of malformed context aware metadata.
The following code uses regular_endpoint, it works the same no matter the
created channle is regular or mutual TLS. Regular endpoint ignores client
certificate and key::
channel = google.auth.transport.grpc.secure_authorized_channel(
credentials, regular_endpoint, request)
The following code uses mtls_endpoint, if the created channle is regular,
and API mtls_endpoint is confgured to require client SSL credentials, API
calls using this channel will be rejected::
channel = google.auth.transport.grpc.secure_authorized_channel(
credentials, mtls_endpoint, request)
Args:
credentials (google.auth.credentials.Credentials): The credentials to
add to requests.
Expand All @@ -129,23 +214,118 @@ def secure_authorized_channel(
target (str): The host and port of the service.
ssl_credentials (grpc.ChannelCredentials): Optional SSL channel
credentials. This can be used to specify different certificates.
This argument is mutually exclusive with client_cert_callback;
providing both will raise an exception.
If ssl_credentials and client_cert_callback are None, application
default SSL credentials will be used.
client_cert_callback (Callable[[], (bytes, bytes)]): Optional
callback function to obtain client certicate and key for mutual TLS
connection. This argument is mutually exclusive with
ssl_credentials; providing both will raise an exception.
If ssl_credentials and client_cert_callback are None, application
default SSL credentials will be used.
kwargs: Additional arguments to pass to :func:`grpc.secure_channel`.
Returns:
grpc.Channel: The created gRPC channel.
Raises:
OSError: If the cert provider command launch fails during the application
default SSL credentials loading process on devices with endpoint
verification support.
RuntimeError: If the cert provider command has a runtime error during the
application default SSL credentials loading process on devices with
endpoint verification support.
ValueError:
If the context aware metadata file is malformed or if the cert provider
command doesn't produce both client certificate and key during the
application default SSL credentials loading process on devices with
endpoint verification support.
"""
# Create the metadata plugin for inserting the authorization header.
metadata_plugin = AuthMetadataPlugin(credentials, request)

# Create a set of grpc.CallCredentials using the metadata plugin.
google_auth_credentials = grpc.metadata_call_credentials(metadata_plugin)

if ssl_credentials is None:
ssl_credentials = grpc.ssl_channel_credentials()
if ssl_credentials and client_cert_callback:
raise ValueError(
"Received both ssl_credentials and client_cert_callback; "
"these are mutually exclusive."
)

# If SSL credentials are not explicitly set, try client_cert_callback and ADC.
if not ssl_credentials:
if client_cert_callback:
# Use the callback if provided.
cert, key = client_cert_callback()
ssl_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)
else:
# Use application default SSL credentials.
adc_ssl_credentils = SslCredentials()
ssl_credentials = adc_ssl_credentils.ssl_credentials

# Combine the ssl credentials and the authorization credentials.
composite_credentials = grpc.composite_channel_credentials(
ssl_credentials, google_auth_credentials
)

return grpc.secure_channel(target, composite_credentials, **kwargs)


class SslCredentials:
"""Class for application default SSL credentials.
For devices with endpoint verification support, a device certificate will be
automatically loaded and mutual TLS will be established.
See https://cloud.google.com/endpoint-verification/docs/overview.
"""

def __init__(self):
# Load client SSL credentials.
self._context_aware_metadata_path = _mtls_helper._check_dca_metadata_path(
_mtls_helper.CONTEXT_AWARE_METADATA_PATH
)
if self._context_aware_metadata_path:
self._is_mtls = True
else:
self._is_mtls = False

@property
def ssl_credentials(self):
"""Get the created SSL channel credentials.
For devices with endpoint verification support, if the device certificate
loading has any problems, corresponding exceptions will be raised. For
a device without endpoint verification support, no exceptions will be
raised.
Returns:
grpc.ChannelCredentials: The created grpc channel credentials.
Raises:
OSError: If the cert provider command launch fails.
RuntimeError: If the cert provider command has a runtime error.
ValueError:
If the context aware metadata file is malformed or if the cert provider
command doesn't produce both the client certificate and key.
"""
if self._context_aware_metadata_path:
metadata = _mtls_helper._read_dca_metadata_file(
self._context_aware_metadata_path
)
cert, key = _mtls_helper.get_client_ssl_credentials(metadata)
self._ssl_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)
else:
self._ssl_credentials = grpc.ssl_channel_credentials()

return self._ssl_credentials

@property
def is_mtls(self):
"""Indicates if the created SSL channel credentials is mutual TLS."""
return self._is_mtls
6 changes: 6 additions & 0 deletions tests/data/context_aware_metadata.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"cert_provider_command":[
"/opt/google/endpoint-verification/bin/SecureConnectHelper",
"--print_certificate"],
"device_resource_ids":["11111111-1111-1111"]
}
Loading

0 comments on commit dafb41f

Please sign in to comment.