From dc2a15f988bb2b3597eb4cd627c6e29a151accd4 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Thu, 3 Oct 2024 14:24:36 +0000 Subject: [PATCH 1/2] feat: improve caching of connection info caches --- google/cloud/sql/connector/connector.py | 29 ++++++------ tests/unit/test_connector.py | 62 ++++++++++++++++--------- 2 files changed, 55 insertions(+), 36 deletions(-) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 5bb0bfc9..51107f2c 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -21,7 +21,7 @@ import logging from threading import Thread from types import TracebackType -from typing import Any, Dict, Optional, Type, Union +from typing import Any, Dict, Optional, Tuple, Type, Union import google.auth from google.auth.credentials import Credentials @@ -131,7 +131,11 @@ def __init__( asyncio.run_coroutine_threadsafe(generate_keys(), self._loop), loop=self._loop, ) - self._cache: Dict[str, Union[RefreshAheadCache, LazyRefreshCache]] = {} + # initialize dict to store caches, key is a tuple consisting of instance + # connection name string and enable_iam_auth boolean flag + self._cache: Dict[ + Tuple[str, bool], Union[RefreshAheadCache, LazyRefreshCache] + ] = {} self._client: Optional[CloudSQLClient] = None # initialize credentials @@ -262,15 +266,8 @@ async def connect_async( driver=driver, ) enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) - if instance_connection_string in self._cache: - cache = self._cache[instance_connection_string] - if enable_iam_auth != cache._enable_iam_auth: - raise ValueError( - f"connect() called with 'enable_iam_auth={enable_iam_auth}', " - f"but previously used 'enable_iam_auth={cache._enable_iam_auth}'. " - "If you require both for your use case, please use a new " - "connector.Connector object." - ) + if (instance_connection_string, enable_iam_auth) in self._cache: + cache = self._cache[(instance_connection_string, enable_iam_auth)] else: if self._refresh_strategy == RefreshStrategy.LAZY: logger.debug( @@ -297,7 +294,7 @@ async def connect_async( logger.debug( f"['{instance_connection_string}']: Connection info added to cache" ) - self._cache[instance_connection_string] = cache + self._cache[(instance_connection_string, enable_iam_auth)] = cache connect_func = { "pymysql": pymysql.connect, @@ -333,7 +330,7 @@ async def connect_async( except Exception: # with an error from Cloud SQL Admin API call or IP type, invalidate # the cache and re-raise the error - await self._remove_cached(instance_connection_string) + await self._remove_cached(instance_connection_string, enable_iam_auth) raise logger.debug( f"['{instance_connection_string}']: Connecting to {ip_address}:3307" @@ -370,7 +367,9 @@ async def connect_async( await cache.force_refresh() raise - async def _remove_cached(self, instance_connection_string: str) -> None: + async def _remove_cached( + self, instance_connection_string: str, enable_iam_auth: bool + ) -> None: """Stops all background refreshes and deletes the connection info cache from the map of caches. """ @@ -378,7 +377,7 @@ async def _remove_cached(self, instance_connection_string: str) -> None: f"['{instance_connection_string}']: Removing connection info from cache" ) # remove cache from stored caches and close it - cache = self._cache.pop(instance_connection_string) + cache = self._cache.pop((instance_connection_string, enable_iam_auth)) await cache.close() def __enter__(self) -> Any: diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index da463067..fd18f2d5 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -31,26 +31,46 @@ from google.cloud.sql.connector.instance import RefreshAheadCache -def test_connect_enable_iam_auth_error( - fake_credentials: Credentials, cache: RefreshAheadCache +@pytest.mark.asyncio +async def test_connect_enable_iam_auth_error( + fake_credentials: Credentials, fake_client: CloudSQLClient ) -> None: """Test that calling connect() with different enable_iam_auth - argument values throws error.""" + argument values creates two cache entries.""" connect_string = "test-project:test-region:test-instance" - connector = Connector(credentials=fake_credentials) - # set cache - connector._cache[connect_string] = cache - # try to connect using enable_iam_auth=True, should raise error - with pytest.raises(ValueError) as exc_info: - connector.connect(connect_string, "pg8000", enable_iam_auth=True) - assert ( - exc_info.value.args[0] == "connect() called with 'enable_iam_auth=True', " - "but previously used 'enable_iam_auth=False'. " - "If you require both for your use case, please use a new " - "connector.Connector object." - ) - # remove cache entry to avoid destructor warnings - connector._cache = {} + async with Connector( + credentials=fake_credentials, loop=asyncio.get_running_loop() + ) as connector: + connector._client = fake_client + # patch db connection creation + with patch("google.cloud.sql.connector.asyncpg.connect") as mock_connect: + mock_connect.return_value = True + # connect with enable_iam_auth False + connection = await connector.connect_async( + connect_string, + "asyncpg", + user="my-user", + password="my-pass", + db="my-db", + enable_iam_auth=False, + ) + # verify connector made connection call + assert connection is True + # connect with enable_iam_auth True + connection = await connector.connect_async( + connect_string, + "asyncpg", + user="my-user", + password="my-pass", + db="my-db", + enable_iam_auth=True, + ) + # verify connector made connection call + assert connection is True + # verify both cache entries for same instance exist + assert len(connector._cache) == 2 + assert (connect_string, True) in connector._cache + assert (connect_string, False) in connector._cache async def test_connect_incompatible_driver_error( @@ -305,7 +325,7 @@ async def test_Connector_remove_cached_bad_instance( conn_name = "bad-project:bad-region:bad-inst" # populate cache cache = RefreshAheadCache(conn_name, fake_client, connector._keys) - connector._cache[conn_name] = cache + connector._cache[(conn_name, False)] = cache # aiohttp client should throw a 404 ClientResponseError with pytest.raises(ClientResponseError): await connector.connect_async( @@ -313,7 +333,7 @@ async def test_Connector_remove_cached_bad_instance( "pg8000", ) # check that cache has been removed from dict - assert conn_name not in connector._cache + assert (conn_name, False) not in connector._cache async def test_Connector_remove_cached_no_ip_type( @@ -331,7 +351,7 @@ async def test_Connector_remove_cached_no_ip_type( conn_name = "test-project:test-region:test-instance" # populate cache cache = RefreshAheadCache(conn_name, fake_client, connector._keys) - connector._cache[conn_name] = cache + connector._cache[(conn_name, False)] = cache # test instance does not have Private IP, thus should invalidate cache with pytest.raises(CloudSQLIPTypeError): await connector.connect_async( @@ -342,7 +362,7 @@ async def test_Connector_remove_cached_no_ip_type( ip_type="private", ) # check that cache has been removed from dict - assert conn_name not in connector._cache + assert (conn_name, False) not in connector._cache def test_default_universe_domain(fake_credentials: Credentials) -> None: From f87fdfe3c08b1184a8b04a377be96a26fe874359 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Thu, 3 Oct 2024 14:39:47 +0000 Subject: [PATCH 2/2] chore: update system test --- tests/system/test_connector_object.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/tests/system/test_connector_object.py b/tests/system/test_connector_object.py index 8d3a2df2..c2b5cf12 100644 --- a/tests/system/test_connector_object.py +++ b/tests/system/test_connector_object.py @@ -79,11 +79,21 @@ def test_multiple_connectors() -> None: conn.execute(sqlalchemy.text("SELECT 1")) instance_connection_string = os.environ["MYSQL_CONNECTION_NAME"] - assert instance_connection_string in first_connector._cache - assert instance_connection_string in second_connector._cache assert ( - first_connector._cache[instance_connection_string] - != second_connector._cache[instance_connection_string] + instance_connection_string, + first_connector._enable_iam_auth, + ) in first_connector._cache + assert ( + instance_connection_string, + second_connector._enable_iam_auth, + ) in second_connector._cache + assert ( + first_connector._cache[ + (instance_connection_string, first_connector._enable_iam_auth) + ] + != second_connector._cache[ + (instance_connection_string, second_connector._enable_iam_auth) + ] ) except Exception: raise