Skip to content

Commit

Permalink
refactor: move ConnectionInfo to own file and add `get_connection_inf…
Browse files Browse the repository at this point in the history
…o` (#1090)

Refactor revolving around ConnectionInfo in preparation for adding a
LazyRefreshCache.

Moving ConnectionInfo into its own file, connection_info.py as it is
to be shared by both RefreshAheadCache and LazyRefreshCache.

Adding a get_connection_info method to the CloudSQLClient.
This is the equivalent of the Go Connector's refresher.ConnectInfo.

This will allow the lazy refresh to just check expiration and then call
get_connection_info and not need to duplicate code from refresh
ahead cache.
  • Loading branch information
jackwotherspoon authored May 29, 2024
1 parent 9fbe87a commit e52c2d9
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 134 deletions.
81 changes: 81 additions & 0 deletions google/cloud/sql/connector/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,19 @@

from __future__ import annotations

import asyncio
import datetime
import logging
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING

import aiohttp
from cryptography.hazmat.backends import default_backend
from cryptography.x509 import load_pem_x509_certificate
from google.auth.credentials import TokenState
from google.auth.transport import requests

from google.cloud.sql.connector.connection_info import ConnectionInfo
from google.cloud.sql.connector.exceptions import AutoIAMAuthNotSupported
from google.cloud.sql.connector.refresh_utils import _downscope_credentials
from google.cloud.sql.connector.version import __version__ as version

Expand Down Expand Up @@ -212,6 +217,82 @@ async def _get_ephemeral(
expiration = token_expiration
return ephemeral_cert, expiration

async def get_connection_info(
self,
project: str,
region: str,
instance: str,
keys: asyncio.Future,
enable_iam_auth: bool,
) -> ConnectionInfo:
"""Immediately performs a full refresh operation using the Cloud SQL
Admin API.
Args:
project (str): The name of the project the Cloud SQL instance is
located in.
region (str): The region the Cloud SQL instance is located in.
instance (str): Name of the Cloud SQL instance.
keys (asyncio.Future): A future to the client's public-private key
pair.
enable_iam_auth (bool): Whether an automatic IAM database
authentication connection is being requested (Postgres and MySQL).
Returns:
ConnectionInfo: All the information required to connect securely to
the Cloud SQL instance.
Raises:
AutoIAMAuthNotSupported: Database engine does not support automatic
IAM authentication.
"""
priv_key, pub_key = await keys
# before making Cloud SQL Admin API calls, refresh creds if required
if not self._credentials.token_state == TokenState.FRESH:
self._credentials.refresh(requests.Request())

metadata_task = asyncio.create_task(
self._get_metadata(
project,
region,
instance,
)
)

ephemeral_task = asyncio.create_task(
self._get_ephemeral(
project,
instance,
pub_key,
enable_iam_auth,
)
)
try:
metadata = await metadata_task
# check if automatic IAM database authn is supported for database engine
if enable_iam_auth and not metadata["database_version"].startswith(
("POSTGRES", "MYSQL")
):
raise AutoIAMAuthNotSupported(
f"'{metadata['database_version']}' does not support "
"automatic IAM authentication. It is only supported with "
"Cloud SQL Postgres or MySQL instances."
)
except Exception:
# cancel ephemeral cert task if exception occurs before it is awaited
ephemeral_task.cancel()
raise

ephemeral_cert, expiration = await ephemeral_task

return ConnectionInfo(
ephemeral_cert,
metadata["server_ca_cert"],
priv_key,
metadata["ip_addresses"],
metadata["database_version"],
expiration,
)

async def close(self) -> None:
"""Close CloudSQLClient gracefully."""
await self._client.close()
105 changes: 105 additions & 0 deletions google/cloud/sql/connector/connection_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from dataclasses import dataclass
import logging
import ssl
from tempfile import TemporaryDirectory
from typing import Any, Dict, Optional, TYPE_CHECKING

from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError
from google.cloud.sql.connector.exceptions import TLSVersionError
from google.cloud.sql.connector.utils import write_to_file

if TYPE_CHECKING:
import datetime

from google.cloud.sql.connector.instance import IPTypes

logger = logging.getLogger(name=__name__)


@dataclass
class ConnectionInfo:
"""Contains all necessary information to connect securely to the
server-side Proxy running on a Cloud SQL instance."""

client_cert: str
server_ca_cert: str
private_key: bytes
ip_addrs: Dict[str, Any]
database_version: str
expiration: datetime.datetime
context: Optional[ssl.SSLContext] = None

def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLContext:
"""Constructs a SSL/TLS context for the given connection info.
Cache the SSL context to ensure we don't read from disk repeatedly when
configuring a secure connection.
"""
# if SSL context is cached, use it
if self.context is not None:
return self.context
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)

# update ssl.PROTOCOL_TLS_CLIENT default
context.check_hostname = False

# TODO: remove if/else when Python 3.10 is min version. PEP 644 has been
# implemented. The ssl module requires OpenSSL 1.1.1 or newer.
# verify OpenSSL version supports TLSv1.3
if ssl.HAS_TLSv1_3:
# force TLSv1.3 if supported by client
context.minimum_version = ssl.TLSVersion.TLSv1_3
# fallback to TLSv1.2 for older versions of OpenSSL
else:
if enable_iam_auth:
raise TLSVersionError(
f"Your current version of OpenSSL ({ssl.OPENSSL_VERSION}) does not "
"support TLSv1.3, which is required to use IAM Authentication.\n"
"Upgrade your OpenSSL version to 1.1.1 for TLSv1.3 support."
)
logger.warning(
"TLSv1.3 is not supported with your version of OpenSSL "
f"({ssl.OPENSSL_VERSION}), falling back to TLSv1.2\n"
"Upgrade your OpenSSL version to 1.1.1 for TLSv1.3 support."
)
context.minimum_version = ssl.TLSVersion.TLSv1_2

# tmpdir and its contents are automatically deleted after the CA cert
# and ephemeral cert are loaded into the SSLcontext. The values
# need to be written to files in order to be loaded by the SSLContext
with TemporaryDirectory() as tmpdir:
ca_filename, cert_filename, key_filename = write_to_file(
tmpdir, self.server_ca_cert, self.client_cert, self.private_key
)
context.load_cert_chain(cert_filename, keyfile=key_filename)
context.load_verify_locations(cafile=ca_filename)
# set class attribute to cache context for subsequent calls
self.context = context
return context

def get_preferred_ip(self, ip_type: IPTypes) -> str:
"""Returns the first IP address for the instance, according to the preference
supplied by ip_type. If no IP addressess with the given preference are found,
an error is raised."""
if ip_type.value in self.ip_addrs:
return self.ip_addrs[ip_type.value]
raise CloudSQLIPTypeError(
"Cloud SQL instance does not have any IP addresses matching "
f"preference: {ip_type.value})"
)
142 changes: 9 additions & 133 deletions google/cloud/sql/connector/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,19 @@
from __future__ import annotations

import asyncio
from dataclasses import dataclass
from enum import Enum
import logging
import re
import ssl
from tempfile import TemporaryDirectory
from typing import Any, Dict, Tuple, TYPE_CHECKING
from typing import Tuple

import aiohttp
from google.auth.credentials import TokenState
from google.auth.transport import requests

from google.cloud.sql.connector.client import CloudSQLClient
from google.cloud.sql.connector.exceptions import AutoIAMAuthNotSupported
from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError
from google.cloud.sql.connector.connection_info import ConnectionInfo
from google.cloud.sql.connector.exceptions import RefreshNotValidError
from google.cloud.sql.connector.exceptions import TLSVersionError
from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter
from google.cloud.sql.connector.refresh_utils import _is_valid
from google.cloud.sql.connector.refresh_utils import _seconds_until_refresh
from google.cloud.sql.connector.utils import write_to_file

if TYPE_CHECKING:
import datetime

logger = logging.getLogger(name=__name__)

Expand Down Expand Up @@ -83,79 +72,6 @@ def _from_str(cls, ip_type_str: str) -> IPTypes:
return cls(ip_type_str.upper())


@dataclass
class ConnectionInfo:
"""Contains all necessary information to connect securely to the
server-side Proxy running on a Cloud SQL instance."""

client_cert: str
server_ca_cert: str
private_key: bytes
ip_addrs: Dict[str, Any]
database_version: str
expiration: datetime.datetime
context: ssl.SSLContext | None = None

def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLContext:
"""Constructs a SSL/TLS context for the given connection info.
Cache the SSL context to ensure we don't read from disk repeatedly when
configuring a secure connection.
"""
# if SSL context is cached, use it
if self.context is not None:
return self.context
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)

# update ssl.PROTOCOL_TLS_CLIENT default
context.check_hostname = False

# TODO: remove if/else when Python 3.10 is min version. PEP 644 has been
# implemented. The ssl module requires OpenSSL 1.1.1 or newer.
# verify OpenSSL version supports TLSv1.3
if ssl.HAS_TLSv1_3:
# force TLSv1.3 if supported by client
context.minimum_version = ssl.TLSVersion.TLSv1_3
# fallback to TLSv1.2 for older versions of OpenSSL
else:
if enable_iam_auth:
raise TLSVersionError(
f"Your current version of OpenSSL ({ssl.OPENSSL_VERSION}) does not "
"support TLSv1.3, which is required to use IAM Authentication.\n"
"Upgrade your OpenSSL version to 1.1.1 for TLSv1.3 support."
)
logger.warning(
"TLSv1.3 is not supported with your version of OpenSSL "
f"({ssl.OPENSSL_VERSION}), falling back to TLSv1.2\n"
"Upgrade your OpenSSL version to 1.1.1 for TLSv1.3 support."
)
context.minimum_version = ssl.TLSVersion.TLSv1_2

# tmpdir and its contents are automatically deleted after the CA cert
# and ephemeral cert are loaded into the SSLcontext. The values
# need to be written to files in order to be loaded by the SSLContext
with TemporaryDirectory() as tmpdir:
ca_filename, cert_filename, key_filename = write_to_file(
tmpdir, self.server_ca_cert, self.client_cert, self.private_key
)
context.load_cert_chain(cert_filename, keyfile=key_filename)
context.load_verify_locations(cafile=ca_filename)
# set class attribute to cache context for subsequent calls
self.context = context
return context

def get_preferred_ip(self, ip_type: IPTypes) -> str:
"""Returns the first IP address for the instance, according to the preference
supplied by ip_type. If no IP addressess with the given preference are found,
an error is raised."""
if ip_type.value in self.ip_addrs:
return self.ip_addrs[ip_type.value]
raise CloudSQLIPTypeError(
"Cloud SQL instance does not have any IP addresses matching "
f"preference: {ip_type.value})"
)


class RefreshAheadCache:
"""Cache that refreshes connection info in the background prior to expiration.
Expand Down Expand Up @@ -229,45 +145,13 @@ async def _perform_refresh(self) -> ConnectionInfo:

try:
await self._refresh_rate_limiter.acquire()
priv_key, pub_key = await self._keys

logger.debug(f"['{self._instance_connection_string}']: Creating context")

# before making Cloud SQL Admin API calls, refresh creds
if not self._client._credentials.token_state == TokenState.FRESH:
self._client._credentials.refresh(requests.Request())

metadata_task = asyncio.create_task(
self._client._get_metadata(
self._project,
self._region,
self._instance,
)
)

ephemeral_task = asyncio.create_task(
self._client._get_ephemeral(
self._project,
self._instance,
pub_key,
self._enable_iam_auth,
)
connection_info = await self._client.get_connection_info(
self._project,
self._region,
self._instance,
self._keys,
self._enable_iam_auth,
)
try:
metadata = await metadata_task
# check if automatic IAM database authn is supported for database engine
if self._enable_iam_auth and not metadata[
"database_version"
].startswith(("POSTGRES", "MYSQL")):
raise AutoIAMAuthNotSupported(
f"'{metadata['database_version']}' does not support automatic IAM authentication. It is only supported with Cloud SQL Postgres or MySQL instances."
)
except Exception:
# cancel ephemeral cert task if exception occurs before it is awaited
ephemeral_task.cancel()
raise

ephemeral_cert, expiration = await ephemeral_task

except aiohttp.ClientResponseError as e:
logger.debug(
Expand All @@ -285,15 +169,7 @@ async def _perform_refresh(self) -> ConnectionInfo:

finally:
self._refresh_in_progress.clear()

return ConnectionInfo(
ephemeral_cert,
metadata["server_ca_cert"],
priv_key,
metadata["ip_addresses"],
metadata["database_version"],
expiration,
)
return connection_info

def _schedule_refresh(self, delay: int) -> asyncio.Task:
"""
Expand Down
Loading

0 comments on commit e52c2d9

Please sign in to comment.