Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use asyncio.Lock over Event #1095

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 46 additions & 46 deletions google/cloud/sql/connector/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
max_capacity=2,
rate=1 / 30,
)
self._refresh_in_progress = asyncio.locks.Event()
self._lock = asyncio.Lock()
self._current: asyncio.Task = self._schedule_refresh(0)
self._next: asyncio.Task = self._current

Expand All @@ -103,12 +103,14 @@ async def force_refresh(self) -> None:
Forces a new refresh attempt immediately to be used for future connection attempts.
"""
# if next refresh is not already in progress, cancel it and schedule new one immediately
if not self._refresh_in_progress.is_set():
self._next.cancel()
self._next = self._schedule_refresh(0)
if not self._lock.locked():
async with self._lock:
self._next.cancel()
self._next = self._schedule_refresh(0)
# block all sequential connection attempts on the next refresh result if current is invalid
if not await _is_valid(self._current):
self._current = self._next
async with self._lock:
if not await _is_valid(self._current):
self._current = self._next

async def _perform_refresh(self) -> ConnectionInfo:
"""Retrieves instance metadata and ephemeral certificate from the
Expand All @@ -120,48 +122,45 @@ async def _perform_refresh(self) -> ConnectionInfo:
a string representing a PEM-encoded private key and a string
representing a PEM-encoded certificate authority.
"""
self._refresh_in_progress.set()
logger.debug(
f"['{self._instance_connection_string}']: Connection info refresh "
"operation started"
)

try:
await self._refresh_rate_limiter.acquire()
connection_info = await self._client.get_connection_info(
self._project,
self._region,
self._instance,
self._keys,
self._enable_iam_auth,
)
async with self._lock:
logger.debug(
f"['{self._instance_connection_string}']: Connection info "
"refresh operation complete"
)
logger.debug(
f"['{self._instance_connection_string}']: Current certificate "
f"expiration = {connection_info.expiration.isoformat()}"
"refresh operation started"
)

except aiohttp.ClientResponseError as e:
logger.debug(
f"['{self._instance_connection_string}']: Connection info "
f"refresh operation failed: {str(e)}"
)
if e.status == 403:
e.message = "Forbidden: Authenticated IAM principal does not seeem authorized to make API request. Verify 'Cloud SQL Admin API' is enabled within your GCP project and 'Cloud SQL Client' role has been granted to IAM principal."
raise
try:
await self._refresh_rate_limiter.acquire()
connection_info = await self._client.get_connection_info(
self._project,
self._region,
self._instance,
self._keys,
self._enable_iam_auth,
)
logger.debug(
f"['{self._instance_connection_string}']: Connection info "
"refresh operation complete"
)
logger.debug(
f"['{self._instance_connection_string}']: Current certificate "
f"expiration = {connection_info.expiration.isoformat()}"
)

except Exception as e:
logger.debug(
f"['{self._instance_connection_string}']: Connection info "
f"refresh operation failed: {str(e)}"
)
raise
except aiohttp.ClientResponseError as e:
logger.debug(
f"['{self._instance_connection_string}']: Connection info "
f"refresh operation failed: {str(e)}"
)
if e.status == 403:
e.message = "Forbidden: Authenticated IAM principal does not seeem authorized to make API request. Verify 'Cloud SQL Admin API' is enabled within your GCP project and 'Cloud SQL Client' role has been granted to IAM principal."
raise

finally:
self._refresh_in_progress.clear()
except Exception as e:
logger.debug(
f"['{self._instance_connection_string}']: Connection info "
f"refresh operation failed: {str(e)}"
)
raise
return connection_info

def _schedule_refresh(self, delay: int) -> asyncio.Task:
Expand Down Expand Up @@ -244,8 +243,9 @@ async def close(self) -> None:
f"['{self._instance_connection_string}']: Canceling connection info "
"refresh operation tasks"
)
self._current.cancel()
self._next.cancel()
# gracefully wait for tasks to cancel
tasks = asyncio.gather(self._current, self._next, return_exceptions=True)
async with self._lock:
self._current.cancel()
self._next.cancel()
# gracefully wait for tasks to cancel
tasks = asyncio.gather(self._current, self._next, return_exceptions=True)
await asyncio.wait_for(tasks, timeout=2.0)
8 changes: 4 additions & 4 deletions tests/unit/test_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,16 +177,16 @@ async def test_force_refresh_cancels_pending_refresh(
test_rate_limiter: AsyncRateLimiter,
) -> None:
"""
Test that force_refresh cancels pending task if refresh_in_progress event is not set.
Test that force_refresh cancels pending task if lock is not acquired.
"""
# allow more frequent refreshes for tests
setattr(cache, "_refresh_rate_limiter", test_rate_limiter)
# make sure initial refresh is finished
await cache._current
# since the pending refresh isn't for another 55 min, the refresh_in_progress event
# shouldn't be set
# since the pending refresh isn't for another 55 min, the lock should not
# be acquired
pending_refresh = cache._next
assert cache._refresh_in_progress.is_set() is False
assert cache._lock.locked() is False

await cache.force_refresh()

Expand Down
Loading