Skip to content

Commit

Permalink
Add connect_fn kwarg to Pool to better support GCP's CloudSQL (#1170)
Browse files Browse the repository at this point in the history
Co-authored-by: Elvis Pranskevichus <elvis@edgedb.com>
  • Loading branch information
d1manson and elprans authored Oct 18, 2024
1 parent 73f2209 commit 3ee19ba
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 12 deletions.
13 changes: 10 additions & 3 deletions asyncpg/_testbase/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ def create_pool(dsn=None, *,
max_size=10,
max_queries=50000,
max_inactive_connection_lifetime=60.0,
connect=None,
setup=None,
init=None,
loop=None,
Expand All @@ -283,12 +284,18 @@ def create_pool(dsn=None, *,
**connect_kwargs):
return pool_class(
dsn,
min_size=min_size, max_size=max_size,
max_queries=max_queries, loop=loop, setup=setup, init=init,
min_size=min_size,
max_size=max_size,
max_queries=max_queries,
loop=loop,
connect=connect,
setup=setup,
init=init,
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
connection_class=connection_class,
record_class=record_class,
**connect_kwargs)
**connect_kwargs,
)


class ClusterTestCase(TestCase):
Expand Down
49 changes: 41 additions & 8 deletions asyncpg/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ class Pool:

__slots__ = (
'_queue', '_loop', '_minsize', '_maxsize',
'_init', '_connect_args', '_connect_kwargs',
'_init', '_connect', '_connect_args', '_connect_kwargs',
'_holders', '_initialized', '_initializing', '_closing',
'_closed', '_connection_class', '_record_class', '_generation',
'_setup', '_max_queries', '_max_inactive_connection_lifetime'
Expand All @@ -324,8 +324,9 @@ def __init__(self, *connect_args,
max_size,
max_queries,
max_inactive_connection_lifetime,
setup,
init,
connect=None,
setup=None,
init=None,
loop,
connection_class,
record_class,
Expand Down Expand Up @@ -385,11 +386,14 @@ def __init__(self, *connect_args,
self._closing = False
self._closed = False
self._generation = 0
self._init = init

self._connect = connect if connect is not None else connection.connect
self._connect_args = connect_args
self._connect_kwargs = connect_kwargs

self._setup = setup
self._init = init

self._max_queries = max_queries
self._max_inactive_connection_lifetime = \
max_inactive_connection_lifetime
Expand Down Expand Up @@ -503,13 +507,25 @@ def set_connect_args(self, dsn=None, **connect_kwargs):
self._connect_kwargs = connect_kwargs

async def _get_new_connection(self):
con = await connection.connect(
con = await self._connect(
*self._connect_args,
loop=self._loop,
connection_class=self._connection_class,
record_class=self._record_class,
**self._connect_kwargs,
)
if not isinstance(con, self._connection_class):
good = self._connection_class
good_n = f'{good.__module__}.{good.__name__}'
bad = type(con)
if bad.__module__ == "builtins":
bad_n = bad.__name__
else:
bad_n = f'{bad.__module__}.{bad.__name__}'
raise exceptions.InterfaceError(
"expected pool connect callback to return an instance of "
f"'{good_n}', got " f"'{bad_n}'"
)

if self._init is not None:
try:
Expand Down Expand Up @@ -1017,6 +1033,7 @@ def create_pool(dsn=None, *,
max_size=10,
max_queries=50000,
max_inactive_connection_lifetime=300.0,
connect=None,
setup=None,
init=None,
loop=None,
Expand Down Expand Up @@ -1099,6 +1116,13 @@ def create_pool(dsn=None, *,
Number of seconds after which inactive connections in the
pool will be closed. Pass ``0`` to disable this mechanism.
:param coroutine connect:
A coroutine that is called instead of
:func:`~asyncpg.connection.connect` whenever the pool needs to make a
new connection. Must return an instance of type specified by
*connection_class* or :class:`~asyncpg.connection.Connection` if
*connection_class* was not specified.
:param coroutine setup:
A coroutine to prepare a connection right before it is returned
from :meth:`Pool.acquire() <pool.Pool.acquire>`. An example use
Expand Down Expand Up @@ -1139,12 +1163,21 @@ def create_pool(dsn=None, *,
.. versionchanged:: 0.22.0
Added the *record_class* parameter.
.. versionchanged:: 0.30.0
Added the *connect* parameter.
"""
return Pool(
dsn,
connection_class=connection_class,
record_class=record_class,
min_size=min_size, max_size=max_size,
max_queries=max_queries, loop=loop, setup=setup, init=init,
min_size=min_size,
max_size=max_size,
max_queries=max_queries,
loop=loop,
connect=connect,
setup=setup,
init=init,
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
**connect_kwargs)
**connect_kwargs,
)
21 changes: 20 additions & 1 deletion tests/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ async def setup(con):

async def test_pool_07(self):
cons = set()
connect_called = 0

async def connect(*args, **kwargs):
nonlocal connect_called
connect_called += 1
return await pg_connection.connect(*args, **kwargs)

async def setup(con):
if con._con not in cons: # `con` is `PoolConnectionProxy`.
Expand All @@ -152,13 +158,26 @@ async def user(pool):
raise RuntimeError('init was not called')

async with self.create_pool(database='postgres',
min_size=2, max_size=5,
min_size=2,
max_size=5,
connect=connect,
init=init,
setup=setup) as pool:
users = asyncio.gather(*[user(pool) for _ in range(10)])
await users

self.assertEqual(len(cons), 5)
self.assertEqual(connect_called, 5)

async def bad_connect(*args, **kwargs):
return 1

with self.assertRaisesRegex(
asyncpg.InterfaceError,
"expected pool connect callback to return an instance of "
"'asyncpg\\.connection\\.Connection', got 'int'"
):
await self.create_pool(database='postgres', connect=bad_connect)

async def test_pool_08(self):
pool = await self.create_pool(database='postgres',
Expand Down

0 comments on commit 3ee19ba

Please sign in to comment.