diff --git a/google/cloud/sql/connector/connection_info.py b/google/cloud/sql/connector/connection_info.py index 3ba460c2..ddd56a33 100644 --- a/google/cloud/sql/connector/connection_info.py +++ b/google/cloud/sql/connector/connection_info.py @@ -61,8 +61,9 @@ async def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLCont return self.context context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - # update ssl.PROTOCOL_TLS_CLIENT default - context.check_hostname = False + if self.server_ca_mode != "GOOGLE_MANAGED_CAS_CA": + # update ssl.PROTOCOL_TLS_CLIENT default + context.check_hostname = False # TODO: remove if/else when Python 3.10 is min version. PEP 644 has been # implemented. The ssl module requires OpenSSL 1.1.1 or newer. diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 20235362..0f753f7c 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -366,6 +366,7 @@ async def connect_async( return await connector( ip_address, await conn_info.create_ssl_context(enable_iam_auth), + conn_info.dns_name, **kwargs, ) # synchronous drivers are blocking and run using executor @@ -373,6 +374,7 @@ async def connect_async( connector, ip_address, await conn_info.create_ssl_context(enable_iam_auth), + conn_info.dns_name, **kwargs, ) return await self._loop.run_in_executor(None, connect_partial) diff --git a/google/cloud/sql/connector/pg8000.py b/google/cloud/sql/connector/pg8000.py index 623738f8..3af5b5c6 100644 --- a/google/cloud/sql/connector/pg8000.py +++ b/google/cloud/sql/connector/pg8000.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import socket import ssl from typing import Any, TYPE_CHECKING @@ -24,21 +25,22 @@ def connect( - ip_address: str, ctx: ssl.SSLContext, **kwargs: Any + ip_address: str, ctx: ssl.SSLContext, server_name: str, **kwargs: Any ) -> "pg8000.dbapi.Connection": """Helper function to create a pg8000 DB-API connection object. - :type ip_address: str - :param ip_address: A string containing an IP address for the Cloud SQL - instance. + Args: + ip_address (str): The IP address for the Cloud SQL instance. - :type ctx: ssl.SSLContext - :param ctx: An SSLContext object created from the Cloud SQL server CA - cert and ephemeral cert. + ctx (ssl.SSLContext): An SSL/TLS object created from the Cloud SQL + server CA cert and ephemeral cert. + server_name (str): The server name of the Cloud SQL instance. Used to + verify the server identity for CAS instances. - :rtype: pg8000.dbapi.Connection - :returns: A pg8000 Connection object for the Cloud SQL instance. + Returns: + (pg8000.dbapi.Connection) A pg8000 connection object to the Cloud SQL + instance. """ try: import pg8000 @@ -46,11 +48,15 @@ def connect( raise ImportError( 'Unable to import module "pg8000." Please install and try again.' ) - + # if CAS instance, check server name + if ctx.check_hostname: + server_name = server_name + else: + server_name = None # Create socket and wrap with context. sock = ctx.wrap_socket( socket.create_connection((ip_address, SERVER_PROXY_PORT)), - server_hostname=ip_address, + server_hostname=server_name, ) user = kwargs.pop("user") diff --git a/google/cloud/sql/connector/pymysql.py b/google/cloud/sql/connector/pymysql.py index 8971ff9b..94d50be7 100644 --- a/google/cloud/sql/connector/pymysql.py +++ b/google/cloud/sql/connector/pymysql.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import socket import ssl from typing import Any, TYPE_CHECKING @@ -24,20 +25,22 @@ def connect( - ip_address: str, ctx: ssl.SSLContext, **kwargs: Any -) -> "pymysql.connections.Connection": + ip_address: str, ctx: ssl.SSLContext, server_name: str, **kwargs: Any +) -> "pymysql.Connection": """Helper function to create a pymysql DB-API connection object. - :type ip_address: str - :param ip_address: A string containing an IP address for the Cloud SQL - instance. + Args: + ip_address (str): The IP address for the Cloud SQL instance. + + ctx (ssl.SSLContext): An SSL/TLS object created from the Cloud SQL + server CA cert and ephemeral cert. - :type ctx: ssl.SSLContext - :param ctx: An SSLContext object created from the Cloud SQL server CA - cert and ephemeral cert. + server_name (str): The server name of the Cloud SQL instance. Used to + verify the server identity for CAS instances. - :rtype: pymysql.Connection - :returns: A PyMySQL Connection object for the Cloud SQL instance. + Returns: + (pymysql.Connection) A pymysql connection object to the Cloud SQL + instance. """ try: import pymysql @@ -48,11 +51,15 @@ def connect( # allow automatic IAM database authentication to not require password kwargs["password"] = kwargs["password"] if "password" in kwargs else None - + # if CAS instance, check server name + if ctx.check_hostname: + server_name = server_name + else: + server_name = None # Create socket and wrap with context. sock = ctx.wrap_socket( socket.create_connection((ip_address, SERVER_PROXY_PORT)), - server_hostname=ip_address, + server_hostname=server_name, ) # pop timeout as timeout arg is called 'connect_timeout' for pymysql timeout = kwargs.pop("timeout") diff --git a/google/cloud/sql/connector/pytds.py b/google/cloud/sql/connector/pytds.py index 5c78fd3f..cf83f700 100644 --- a/google/cloud/sql/connector/pytds.py +++ b/google/cloud/sql/connector/pytds.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import platform import socket import ssl @@ -26,20 +27,23 @@ import pytds -def connect(ip_address: str, ctx: ssl.SSLContext, **kwargs: Any) -> "pytds.Connection": +def connect( + ip_address: str, ctx: ssl.SSLContext, server_name: str, **kwargs: Any +) -> "pytds.Connection": """Helper function to create a pytds DB-API connection object. - :type ip_address: str - :param ip_address: A string containing an IP address for the Cloud SQL - instance. + Args: + ip_address (str): The IP address for the Cloud SQL instance. - :type ctx: ssl.SSLContext - :param ctx: An SSLContext object created from the Cloud SQL server CA - cert and ephemeral cert. + ctx (ssl.SSLContext): An SSL/TLS object created from the Cloud SQL + server CA cert and ephemeral cert. + server_name (str): The server name of the Cloud SQL instance. Used to + verify the server identity for CAS instances. - :rtype: pytds.Connection - :returns: A pytds Connection object for the Cloud SQL instance. + Returns: + (pytds.Connection) A pytds connection object to the Cloud SQL + instance. """ try: import pytds @@ -49,11 +53,15 @@ def connect(ip_address: str, ctx: ssl.SSLContext, **kwargs: Any) -> "pytds.Conne ) db = kwargs.pop("db", None) - + # if CAS instance, check server name + if ctx.check_hostname: + server_name = server_name + else: + server_name = None # Create socket and wrap with context. sock = ctx.wrap_socket( socket.create_connection((ip_address, SERVER_PROXY_PORT)), - server_hostname=ip_address, + server_hostname=server_name, ) if kwargs.pop("active_directory_auth", False): if platform.system() == "Windows": diff --git a/google/cloud/sql/connector/utils.py b/google/cloud/sql/connector/utils.py index 47a318fb..d26be527 100755 --- a/google/cloud/sql/connector/utils.py +++ b/google/cloud/sql/connector/utils.py @@ -14,7 +14,7 @@ limitations under the License. """ -from typing import Tuple +from typing import List, Tuple import aiofiles from cryptography.hazmat.backends import default_backend @@ -60,7 +60,7 @@ async def generate_keys() -> Tuple[bytes, str]: async def write_to_file( - dir_path: str, serverCaCert: str, ephemeralCert: str, priv_key: bytes + dir_path: str, serverCaCert: List[str], ephemeralCert: str, priv_key: bytes ) -> Tuple[str, str, str]: """ Helper function to write the serverCaCert, ephemeral certificate and @@ -71,7 +71,7 @@ async def write_to_file( key_filename = f"{dir_path}/priv.pem" async with aiofiles.open(ca_filename, "w+") as ca_out: - await ca_out.write(serverCaCert) + await ca_out.write("".join(serverCaCert)) async with aiofiles.open(cert_filename, "w+") as ephemeral_out: await ephemeral_out.write(ephemeralCert) async with aiofiles.open(key_filename, "wb") as priv_out: