Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use asyncio.Lock over Event #1095

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 26 additions & 29 deletions google/cloud/sql/connector/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def __init__(
max_capacity=2,
rate=1 / 30,
)
self._refresh_in_progress = asyncio.locks.Event()
self._lock = asyncio.Lock()
self._current: asyncio.Task = self._schedule_refresh(0)
self._next: asyncio.Task = self._current

Expand All @@ -138,7 +138,7 @@ async def force_refresh(self) -> None:
Forces a new refresh attempt immediately to be used for future connection attempts.
"""
# if next refresh is not already in progress, cancel it and schedule new one immediately
if not self._refresh_in_progress.is_set():
if not self._lock.locked():
self._next.cancel()
self._next = self._schedule_refresh(0)
Copy link
Contributor

@hessjcg hessjcg Jun 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Go and Java connectors use the mutex to protect access to, and modification of, self._next and self._current, avoiding the race conditions that occur when multiple threads request a refresh concurrently. The Python connector probably needs to do the same. See the Go RefreshAheadCache instantiation and scheduleRefresh().

# block all sequential connection attempts on the next refresh result if current is invalid
Expand All @@ -155,37 +155,34 @@ async def _perform_refresh(self) -> ConnectionInfo:
a string representing a PEM-encoded private key and a string
representing a PEM-encoded certificate authority.
"""
self._refresh_in_progress.set()
logger.debug(
f"['{self._instance_connection_string}']: Entered _perform_refresh"
)

try:
await self._refresh_rate_limiter.acquire()
connection_info = await self._client.get_connection_info(
self._project,
self._region,
self._instance,
self._keys,
self._enable_iam_auth,
)

except aiohttp.ClientResponseError as e:
async with self._lock:
logger.debug(
f"['{self._instance_connection_string}']: Error occurred during _perform_refresh."
f"['{self._instance_connection_string}']: Entered _perform_refresh"
)
if e.status == 403:
e.message = "Forbidden: Authenticated IAM principal does not seeem authorized to make API request. Verify 'Cloud SQL Admin API' is enabled within your GCP project and 'Cloud SQL Client' role has been granted to IAM principal."
raise

except Exception:
logger.debug(
f"['{self._instance_connection_string}']: Error occurred during _perform_refresh."
)
raise
try:
await self._refresh_rate_limiter.acquire()
connection_info = await self._client.get_connection_info(
self._project,
self._region,
self._instance,
self._keys,
self._enable_iam_auth,
)

except aiohttp.ClientResponseError as e:
logger.debug(
f"['{self._instance_connection_string}']: Error occurred during _perform_refresh."
)
if e.status == 403:
e.message = "Forbidden: Authenticated IAM principal does not seeem authorized to make API request. Verify 'Cloud SQL Admin API' is enabled within your GCP project and 'Cloud SQL Client' role has been granted to IAM principal."
raise

finally:
self._refresh_in_progress.clear()
except Exception:
logger.debug(
f"['{self._instance_connection_string}']: Error occurred during _perform_refresh."
)
raise
return connection_info

def _schedule_refresh(self, delay: int) -> asyncio.Task:
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/test_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,16 +177,16 @@ async def test_force_refresh_cancels_pending_refresh(
test_rate_limiter: AsyncRateLimiter,
) -> None:
"""
Test that force_refresh cancels pending task if refresh_in_progress event is not set.
Test that force_refresh cancels pending task if lock is not acquired.
"""
# allow more frequent refreshes for tests
setattr(cache, "_refresh_rate_limiter", test_rate_limiter)
# make sure initial refresh is finished
await cache._current
# since the pending refresh isn't for another 55 min, the refresh_in_progress event
# shouldn't be set
# since the pending refresh isn't for another 55 min, the lock should not
# be acquired
pending_refresh = cache._next
assert cache._refresh_in_progress.is_set() is False
assert cache._lock.locked() is False

await cache.force_refresh()

Expand Down
Loading