diff --git a/google/cloud/sql/connector/client.py b/google/cloud/sql/connector/client.py index 759c9401..847f874e 100644 --- a/google/cloud/sql/connector/client.py +++ b/google/cloud/sql/connector/client.py @@ -14,6 +14,7 @@ from __future__ import annotations +import asyncio import datetime import logging from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING @@ -21,7 +22,11 @@ import aiohttp from cryptography.hazmat.backends import default_backend from cryptography.x509 import load_pem_x509_certificate +from google.auth.credentials import TokenState +from google.auth.transport import requests +from google.cloud.sql.connector.connection_info import ConnectionInfo +from google.cloud.sql.connector.exceptions import AutoIAMAuthNotSupported from google.cloud.sql.connector.refresh_utils import _downscope_credentials from google.cloud.sql.connector.version import __version__ as version @@ -212,6 +217,82 @@ async def _get_ephemeral( expiration = token_expiration return ephemeral_cert, expiration + async def get_connection_info( + self, + project: str, + region: str, + instance: str, + keys: asyncio.Future, + enable_iam_auth: bool, + ) -> ConnectionInfo: + """Immediately performs a full refresh operation using the Cloud SQL + Admin API. + + Args: + project (str): The name of the project the Cloud SQL instance is + located in. + region (str): The region the Cloud SQL instance is located in. + instance (str): Name of the Cloud SQL instance. + keys (asyncio.Future): A future to the client's public-private key + pair. + enable_iam_auth (bool): Whether an automatic IAM database + authentication connection is being requested (Postgres and MySQL). + + Returns: + ConnectionInfo: All the information required to connect securely to + the Cloud SQL instance. + Raises: + AutoIAMAuthNotSupported: Database engine does not support automatic + IAM authentication. + """ + priv_key, pub_key = await keys + # before making Cloud SQL Admin API calls, refresh creds if required + if not self._credentials.token_state == TokenState.FRESH: + self._credentials.refresh(requests.Request()) + + metadata_task = asyncio.create_task( + self._get_metadata( + project, + region, + instance, + ) + ) + + ephemeral_task = asyncio.create_task( + self._get_ephemeral( + project, + instance, + pub_key, + enable_iam_auth, + ) + ) + try: + metadata = await metadata_task + # check if automatic IAM database authn is supported for database engine + if enable_iam_auth and not metadata["database_version"].startswith( + ("POSTGRES", "MYSQL") + ): + raise AutoIAMAuthNotSupported( + f"'{metadata['database_version']}' does not support " + "automatic IAM authentication. It is only supported with " + "Cloud SQL Postgres or MySQL instances." + ) + except Exception: + # cancel ephemeral cert task if exception occurs before it is awaited + ephemeral_task.cancel() + raise + + ephemeral_cert, expiration = await ephemeral_task + + return ConnectionInfo( + ephemeral_cert, + metadata["server_ca_cert"], + priv_key, + metadata["ip_addresses"], + metadata["database_version"], + expiration, + ) + async def close(self) -> None: """Close CloudSQLClient gracefully.""" await self._client.close() diff --git a/google/cloud/sql/connector/connection_info.py b/google/cloud/sql/connector/connection_info.py new file mode 100644 index 00000000..0f1f1638 --- /dev/null +++ b/google/cloud/sql/connector/connection_info.py @@ -0,0 +1,105 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +import logging +import ssl +from tempfile import TemporaryDirectory +from typing import Any, Dict, Optional, TYPE_CHECKING + +from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError +from google.cloud.sql.connector.exceptions import TLSVersionError +from google.cloud.sql.connector.utils import write_to_file + +if TYPE_CHECKING: + import datetime + + from google.cloud.sql.connector.instance import IPTypes + +logger = logging.getLogger(name=__name__) + + +@dataclass +class ConnectionInfo: + """Contains all necessary information to connect securely to the + server-side Proxy running on a Cloud SQL instance.""" + + client_cert: str + server_ca_cert: str + private_key: bytes + ip_addrs: Dict[str, Any] + database_version: str + expiration: datetime.datetime + context: Optional[ssl.SSLContext] = None + + def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLContext: + """Constructs a SSL/TLS context for the given connection info. + + Cache the SSL context to ensure we don't read from disk repeatedly when + configuring a secure connection. + """ + # if SSL context is cached, use it + if self.context is not None: + return self.context + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + + # update ssl.PROTOCOL_TLS_CLIENT default + context.check_hostname = False + + # TODO: remove if/else when Python 3.10 is min version. PEP 644 has been + # implemented. The ssl module requires OpenSSL 1.1.1 or newer. + # verify OpenSSL version supports TLSv1.3 + if ssl.HAS_TLSv1_3: + # force TLSv1.3 if supported by client + context.minimum_version = ssl.TLSVersion.TLSv1_3 + # fallback to TLSv1.2 for older versions of OpenSSL + else: + if enable_iam_auth: + raise TLSVersionError( + f"Your current version of OpenSSL ({ssl.OPENSSL_VERSION}) does not " + "support TLSv1.3, which is required to use IAM Authentication.\n" + "Upgrade your OpenSSL version to 1.1.1 for TLSv1.3 support." + ) + logger.warning( + "TLSv1.3 is not supported with your version of OpenSSL " + f"({ssl.OPENSSL_VERSION}), falling back to TLSv1.2\n" + "Upgrade your OpenSSL version to 1.1.1 for TLSv1.3 support." + ) + context.minimum_version = ssl.TLSVersion.TLSv1_2 + + # tmpdir and its contents are automatically deleted after the CA cert + # and ephemeral cert are loaded into the SSLcontext. The values + # need to be written to files in order to be loaded by the SSLContext + with TemporaryDirectory() as tmpdir: + ca_filename, cert_filename, key_filename = write_to_file( + tmpdir, self.server_ca_cert, self.client_cert, self.private_key + ) + context.load_cert_chain(cert_filename, keyfile=key_filename) + context.load_verify_locations(cafile=ca_filename) + # set class attribute to cache context for subsequent calls + self.context = context + return context + + def get_preferred_ip(self, ip_type: IPTypes) -> str: + """Returns the first IP address for the instance, according to the preference + supplied by ip_type. If no IP addressess with the given preference are found, + an error is raised.""" + if ip_type.value in self.ip_addrs: + return self.ip_addrs[ip_type.value] + raise CloudSQLIPTypeError( + "Cloud SQL instance does not have any IP addresses matching " + f"preference: {ip_type.value})" + ) diff --git a/google/cloud/sql/connector/instance.py b/google/cloud/sql/connector/instance.py index b8b509ea..3a49f02e 100644 --- a/google/cloud/sql/connector/instance.py +++ b/google/cloud/sql/connector/instance.py @@ -17,30 +17,19 @@ from __future__ import annotations import asyncio -from dataclasses import dataclass from enum import Enum import logging import re -import ssl -from tempfile import TemporaryDirectory -from typing import Any, Dict, Tuple, TYPE_CHECKING +from typing import Tuple import aiohttp -from google.auth.credentials import TokenState -from google.auth.transport import requests from google.cloud.sql.connector.client import CloudSQLClient -from google.cloud.sql.connector.exceptions import AutoIAMAuthNotSupported -from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError +from google.cloud.sql.connector.connection_info import ConnectionInfo from google.cloud.sql.connector.exceptions import RefreshNotValidError -from google.cloud.sql.connector.exceptions import TLSVersionError from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter from google.cloud.sql.connector.refresh_utils import _is_valid from google.cloud.sql.connector.refresh_utils import _seconds_until_refresh -from google.cloud.sql.connector.utils import write_to_file - -if TYPE_CHECKING: - import datetime logger = logging.getLogger(name=__name__) @@ -83,79 +72,6 @@ def _from_str(cls, ip_type_str: str) -> IPTypes: return cls(ip_type_str.upper()) -@dataclass -class ConnectionInfo: - """Contains all necessary information to connect securely to the - server-side Proxy running on a Cloud SQL instance.""" - - client_cert: str - server_ca_cert: str - private_key: bytes - ip_addrs: Dict[str, Any] - database_version: str - expiration: datetime.datetime - context: ssl.SSLContext | None = None - - def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLContext: - """Constructs a SSL/TLS context for the given connection info. - - Cache the SSL context to ensure we don't read from disk repeatedly when - configuring a secure connection. - """ - # if SSL context is cached, use it - if self.context is not None: - return self.context - context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - - # update ssl.PROTOCOL_TLS_CLIENT default - context.check_hostname = False - - # TODO: remove if/else when Python 3.10 is min version. PEP 644 has been - # implemented. The ssl module requires OpenSSL 1.1.1 or newer. - # verify OpenSSL version supports TLSv1.3 - if ssl.HAS_TLSv1_3: - # force TLSv1.3 if supported by client - context.minimum_version = ssl.TLSVersion.TLSv1_3 - # fallback to TLSv1.2 for older versions of OpenSSL - else: - if enable_iam_auth: - raise TLSVersionError( - f"Your current version of OpenSSL ({ssl.OPENSSL_VERSION}) does not " - "support TLSv1.3, which is required to use IAM Authentication.\n" - "Upgrade your OpenSSL version to 1.1.1 for TLSv1.3 support." - ) - logger.warning( - "TLSv1.3 is not supported with your version of OpenSSL " - f"({ssl.OPENSSL_VERSION}), falling back to TLSv1.2\n" - "Upgrade your OpenSSL version to 1.1.1 for TLSv1.3 support." - ) - context.minimum_version = ssl.TLSVersion.TLSv1_2 - - # tmpdir and its contents are automatically deleted after the CA cert - # and ephemeral cert are loaded into the SSLcontext. The values - # need to be written to files in order to be loaded by the SSLContext - with TemporaryDirectory() as tmpdir: - ca_filename, cert_filename, key_filename = write_to_file( - tmpdir, self.server_ca_cert, self.client_cert, self.private_key - ) - context.load_cert_chain(cert_filename, keyfile=key_filename) - context.load_verify_locations(cafile=ca_filename) - # set class attribute to cache context for subsequent calls - self.context = context - return context - - def get_preferred_ip(self, ip_type: IPTypes) -> str: - """Returns the first IP address for the instance, according to the preference - supplied by ip_type. If no IP addressess with the given preference are found, - an error is raised.""" - if ip_type.value in self.ip_addrs: - return self.ip_addrs[ip_type.value] - raise CloudSQLIPTypeError( - "Cloud SQL instance does not have any IP addresses matching " - f"preference: {ip_type.value})" - ) - - class RefreshAheadCache: """Cache that refreshes connection info in the background prior to expiration. @@ -229,45 +145,13 @@ async def _perform_refresh(self) -> ConnectionInfo: try: await self._refresh_rate_limiter.acquire() - priv_key, pub_key = await self._keys - - logger.debug(f"['{self._instance_connection_string}']: Creating context") - - # before making Cloud SQL Admin API calls, refresh creds - if not self._client._credentials.token_state == TokenState.FRESH: - self._client._credentials.refresh(requests.Request()) - - metadata_task = asyncio.create_task( - self._client._get_metadata( - self._project, - self._region, - self._instance, - ) - ) - - ephemeral_task = asyncio.create_task( - self._client._get_ephemeral( - self._project, - self._instance, - pub_key, - self._enable_iam_auth, - ) + connection_info = await self._client.get_connection_info( + self._project, + self._region, + self._instance, + self._keys, + self._enable_iam_auth, ) - try: - metadata = await metadata_task - # check if automatic IAM database authn is supported for database engine - if self._enable_iam_auth and not metadata[ - "database_version" - ].startswith(("POSTGRES", "MYSQL")): - raise AutoIAMAuthNotSupported( - f"'{metadata['database_version']}' does not support automatic IAM authentication. It is only supported with Cloud SQL Postgres or MySQL instances." - ) - except Exception: - # cancel ephemeral cert task if exception occurs before it is awaited - ephemeral_task.cancel() - raise - - ephemeral_cert, expiration = await ephemeral_task except aiohttp.ClientResponseError as e: logger.debug( @@ -285,15 +169,7 @@ async def _perform_refresh(self) -> ConnectionInfo: finally: self._refresh_in_progress.clear() - - return ConnectionInfo( - ephemeral_cert, - metadata["server_ca_cert"], - priv_key, - metadata["ip_addresses"], - metadata["database_version"], - expiration, - ) + return connection_info def _schedule_refresh(self, delay: int) -> asyncio.Task: """ diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index ef2ea96c..bd3b3a8c 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -27,10 +27,10 @@ import pytest # noqa F401 Needed to run the tests from google.cloud.sql.connector.client import CloudSQLClient +from google.cloud.sql.connector.connection_info import ConnectionInfo from google.cloud.sql.connector.exceptions import AutoIAMAuthNotSupported from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError from google.cloud.sql.connector.instance import _parse_instance_connection_name -from google.cloud.sql.connector.instance import ConnectionInfo from google.cloud.sql.connector.instance import IPTypes from google.cloud.sql.connector.instance import RefreshAheadCache from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter