Skip to content

Commit

Permalink
feat: add support for CAS-based instances
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwotherspoon committed Sep 3, 2024
1 parent bb80425 commit 2c549a3
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 39 deletions.
5 changes: 3 additions & 2 deletions google/cloud/sql/connector/connection_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ async def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLCont
return self.context
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)

# update ssl.PROTOCOL_TLS_CLIENT default
context.check_hostname = False
if self.server_ca_mode != "GOOGLE_MANAGED_CAS_CA":
# update ssl.PROTOCOL_TLS_CLIENT default
context.check_hostname = False

# TODO: remove if/else when Python 3.10 is min version. PEP 644 has been
# implemented. The ssl module requires OpenSSL 1.1.1 or newer.
Expand Down
2 changes: 2 additions & 0 deletions google/cloud/sql/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,13 +366,15 @@ async def connect_async(
return await connector(
ip_address,
await conn_info.create_ssl_context(enable_iam_auth),
conn_info.dns_name,
**kwargs,
)
# synchronous drivers are blocking and run using executor
connect_partial = partial(
connector,
ip_address,
await conn_info.create_ssl_context(enable_iam_auth),
conn_info.dns_name,
**kwargs,
)
return await self._loop.run_in_executor(None, connect_partial)
Expand Down
28 changes: 17 additions & 11 deletions google/cloud/sql/connector/pg8000.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import socket
import ssl
from typing import Any, TYPE_CHECKING
Expand All @@ -24,33 +25,38 @@


def connect(
ip_address: str, ctx: ssl.SSLContext, **kwargs: Any
ip_address: str, ctx: ssl.SSLContext, server_name: str, **kwargs: Any
) -> "pg8000.dbapi.Connection":
"""Helper function to create a pg8000 DB-API connection object.
:type ip_address: str
:param ip_address: A string containing an IP address for the Cloud SQL
instance.
Args:
ip_address (str): The IP address for the Cloud SQL instance.
:type ctx: ssl.SSLContext
:param ctx: An SSLContext object created from the Cloud SQL server CA
cert and ephemeral cert.
ctx (ssl.SSLContext): An SSL/TLS object created from the Cloud SQL
server CA cert and ephemeral cert.
server_name (str): The server name of the Cloud SQL instance. Used to
verify the server identity for CAS instances.
:rtype: pg8000.dbapi.Connection
:returns: A pg8000 Connection object for the Cloud SQL instance.
Returns:
(pg8000.dbapi.Connection) A pg8000 connection object to the Cloud SQL
instance.
"""
try:
import pg8000
except ImportError:
raise ImportError(
'Unable to import module "pg8000." Please install and try again.'
)

# if CAS instance, check server name
if ctx.check_hostname:
server_name = server_name
else:
server_name = None
# Create socket and wrap with context.
sock = ctx.wrap_socket(
socket.create_connection((ip_address, SERVER_PROXY_PORT)),
server_hostname=ip_address,
server_hostname=server_name,
)

user = kwargs.pop("user")
Expand Down
31 changes: 19 additions & 12 deletions google/cloud/sql/connector/pymysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import socket
import ssl
from typing import Any, TYPE_CHECKING
Expand All @@ -24,20 +25,22 @@


def connect(
ip_address: str, ctx: ssl.SSLContext, **kwargs: Any
) -> "pymysql.connections.Connection":
ip_address: str, ctx: ssl.SSLContext, server_name: str, **kwargs: Any
) -> "pymysql.Connection":
"""Helper function to create a pymysql DB-API connection object.
:type ip_address: str
:param ip_address: A string containing an IP address for the Cloud SQL
instance.
Args:
ip_address (str): The IP address for the Cloud SQL instance.
ctx (ssl.SSLContext): An SSL/TLS object created from the Cloud SQL
server CA cert and ephemeral cert.
:type ctx: ssl.SSLContext
:param ctx: An SSLContext object created from the Cloud SQL server CA
cert and ephemeral cert.
server_name (str): The server name of the Cloud SQL instance. Used to
verify the server identity for CAS instances.
:rtype: pymysql.Connection
:returns: A PyMySQL Connection object for the Cloud SQL instance.
Returns:
(pymysql.Connection) A pymysql connection object to the Cloud SQL
instance.
"""
try:
import pymysql
Expand All @@ -48,11 +51,15 @@ def connect(

# allow automatic IAM database authentication to not require password
kwargs["password"] = kwargs["password"] if "password" in kwargs else None

# if CAS instance, check server name
if ctx.check_hostname:
server_name = server_name
else:
server_name = None
# Create socket and wrap with context.
sock = ctx.wrap_socket(
socket.create_connection((ip_address, SERVER_PROXY_PORT)),
server_hostname=ip_address,
server_hostname=server_name,
)
# pop timeout as timeout arg is called 'connect_timeout' for pymysql
timeout = kwargs.pop("timeout")
Expand Down
30 changes: 19 additions & 11 deletions google/cloud/sql/connector/pytds.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import platform
import socket
import ssl
Expand All @@ -26,20 +27,23 @@
import pytds


def connect(ip_address: str, ctx: ssl.SSLContext, **kwargs: Any) -> "pytds.Connection":
def connect(
ip_address: str, ctx: ssl.SSLContext, server_name: str, **kwargs: Any
) -> "pytds.Connection":
"""Helper function to create a pytds DB-API connection object.
:type ip_address: str
:param ip_address: A string containing an IP address for the Cloud SQL
instance.
Args:
ip_address (str): The IP address for the Cloud SQL instance.
:type ctx: ssl.SSLContext
:param ctx: An SSLContext object created from the Cloud SQL server CA
cert and ephemeral cert.
ctx (ssl.SSLContext): An SSL/TLS object created from the Cloud SQL
server CA cert and ephemeral cert.
server_name (str): The server name of the Cloud SQL instance. Used to
verify the server identity for CAS instances.
:rtype: pytds.Connection
:returns: A pytds Connection object for the Cloud SQL instance.
Returns:
(pytds.Connection) A pytds connection object to the Cloud SQL
instance.
"""
try:
import pytds
Expand All @@ -49,11 +53,15 @@ def connect(ip_address: str, ctx: ssl.SSLContext, **kwargs: Any) -> "pytds.Conne
)

db = kwargs.pop("db", None)

# if CAS instance, check server name
if ctx.check_hostname:
server_name = server_name
else:
server_name = None
# Create socket and wrap with context.
sock = ctx.wrap_socket(
socket.create_connection((ip_address, SERVER_PROXY_PORT)),
server_hostname=ip_address,
server_hostname=server_name,
)
if kwargs.pop("active_directory_auth", False):
if platform.system() == "Windows":
Expand Down
6 changes: 3 additions & 3 deletions google/cloud/sql/connector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
limitations under the License.
"""

from typing import Tuple
from typing import List, Tuple

import aiofiles
from cryptography.hazmat.backends import default_backend
Expand Down Expand Up @@ -60,7 +60,7 @@ async def generate_keys() -> Tuple[bytes, str]:


async def write_to_file(
dir_path: str, serverCaCert: str, ephemeralCert: str, priv_key: bytes
dir_path: str, serverCaCert: List[str], ephemeralCert: str, priv_key: bytes
) -> Tuple[str, str, str]:
"""
Helper function to write the serverCaCert, ephemeral certificate and
Expand All @@ -71,7 +71,7 @@ async def write_to_file(
key_filename = f"{dir_path}/priv.pem"

async with aiofiles.open(ca_filename, "w+") as ca_out:
await ca_out.write(serverCaCert)
await ca_out.write("".join(serverCaCert))
async with aiofiles.open(cert_filename, "w+") as ephemeral_out:
await ephemeral_out.write(ephemeralCert)
async with aiofiles.open(key_filename, "wb") as priv_out:
Expand Down

0 comments on commit 2c549a3

Please sign in to comment.