Skip to content

Commit

Permalink
Allow customizing connection state reset (#1191)
Browse files Browse the repository at this point in the history
A coroutine can be passed to the new `reset` argument of `create_pool`
to control what happens to the connection when it is returned back to
the pool by `release()`.  By default `Connection.reset()` is called.
Additionally, `Connection.get_reset_query` is renamed from
`Connection._get_reset_query` to enable an alternative way of
customizing the reset process via subclassing.

Closes: #780
Closes: #1146
  • Loading branch information
elprans authored Oct 18, 2024
1 parent 3ef884e commit f6ec755
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 11 deletions.
45 changes: 39 additions & 6 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1515,11 +1515,10 @@ def terminate(self):
self._abort()
self._cleanup()

async def reset(self, *, timeout=None):
async def _reset(self):
self._check_open()
self._listeners.clear()
self._log_listeners.clear()
reset_query = self._get_reset_query()

if self._protocol.is_in_transaction() or self._top_xact is not None:
if self._top_xact is None or not self._top_xact._managed:
Expand All @@ -1531,10 +1530,36 @@ async def reset(self, *, timeout=None):
})

self._top_xact = None
reset_query = 'ROLLBACK;\n' + reset_query
await self.execute("ROLLBACK")

async def reset(self, *, timeout=None):
"""Reset the connection state.
Calling this will reset the connection session state to a state
resembling that of a newly obtained connection. Namely, an open
transaction (if any) is rolled back, open cursors are closed,
all `LISTEN <https://www.postgresql.org/docs/current/sql-listen.html>`_
registrations are removed, all session configuration
variables are reset to their default values, and all advisory locks
are released.
Note that the above describes the default query returned by
:meth:`Connection.get_reset_query`. If one overloads the method
by subclassing ``Connection``, then this method will do whatever
the overloaded method returns, except open transactions are always
terminated and any callbacks registered by
:meth:`Connection.add_listener` or :meth:`Connection.add_log_listener`
are removed.
if reset_query:
await self.execute(reset_query, timeout=timeout)
:param float timeout:
A timeout for resetting the connection. If not specified, defaults
to no timeout.
"""
async with compat.timeout(timeout):
await self._reset()
reset_query = self.get_reset_query()
if reset_query:
await self.execute(reset_query)

def _abort(self):
# Put the connection into the aborted state.
Expand Down Expand Up @@ -1695,7 +1720,15 @@ def _unwrap(self):
con_ref = self._proxy
return con_ref

def _get_reset_query(self):
def get_reset_query(self):
"""Return the query sent to server on connection release.
The query returned by this method is used by :meth:`Connection.reset`,
which is, in turn, used by :class:`~asyncpg.pool.Pool` before making
the connection available to another acquirer.
.. versionadded:: 0.30.0
"""
if self._reset_query is not None:
return self._reset_query

Expand Down
36 changes: 32 additions & 4 deletions asyncpg/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,12 @@ async def release(self, timeout):
if budget is not None:
budget -= time.monotonic() - started

await self._con.reset(timeout=budget)
if self._pool._reset is not None:
async with compat.timeout(budget):
await self._con._reset()
await self._pool._reset(self._con)
else:
await self._con.reset(timeout=budget)
except (Exception, asyncio.CancelledError) as ex:
# If the `reset` call failed, terminate the connection.
# A new one will be created when `acquire` is called
Expand Down Expand Up @@ -313,7 +318,7 @@ class Pool:

__slots__ = (
'_queue', '_loop', '_minsize', '_maxsize',
'_init', '_connect', '_connect_args', '_connect_kwargs',
'_init', '_connect', '_reset', '_connect_args', '_connect_kwargs',
'_holders', '_initialized', '_initializing', '_closing',
'_closed', '_connection_class', '_record_class', '_generation',
'_setup', '_max_queries', '_max_inactive_connection_lifetime'
Expand All @@ -327,6 +332,7 @@ def __init__(self, *connect_args,
connect=None,
setup=None,
init=None,
reset=None,
loop,
connection_class,
record_class,
Expand Down Expand Up @@ -393,6 +399,7 @@ def __init__(self, *connect_args,

self._setup = setup
self._init = init
self._reset = reset

self._max_queries = max_queries
self._max_inactive_connection_lifetime = \
Expand Down Expand Up @@ -1036,6 +1043,7 @@ def create_pool(dsn=None, *,
connect=None,
setup=None,
init=None,
reset=None,
loop=None,
connection_class=connection.Connection,
record_class=protocol.Record,
Expand Down Expand Up @@ -1125,7 +1133,7 @@ def create_pool(dsn=None, *,
:param coroutine setup:
A coroutine to prepare a connection right before it is returned
from :meth:`Pool.acquire() <pool.Pool.acquire>`. An example use
from :meth:`Pool.acquire()`. An example use
case would be to automatically set up notifications listeners for
all connections of a pool.
Expand All @@ -1137,6 +1145,25 @@ def create_pool(dsn=None, *,
or :meth:`Connection.set_type_codec() <\
asyncpg.connection.Connection.set_type_codec>`.
:param coroutine reset:
A coroutine to reset a connection before it is returned to the pool by
:meth:`Pool.release()`. The function is supposed
to reset any changes made to the database session so that the next
acquirer gets the connection in a well-defined state.
The default implementation calls :meth:`Connection.reset() <\
asyncpg.connection.Connection.reset>`, which runs the following::
SELECT pg_advisory_unlock_all();
CLOSE ALL;
UNLISTEN *;
RESET ALL;
The exact reset query is determined by detected server capabilities,
and a custom *reset* implementation can obtain the default query
by calling :meth:`Connection.get_reset_query() <\
asyncpg.connection.Connection.get_reset_query>`.
:param loop:
An asyncio event loop instance. If ``None``, the default
event loop will be used.
Expand Down Expand Up @@ -1165,7 +1192,7 @@ def create_pool(dsn=None, *,
Added the *record_class* parameter.
.. versionchanged:: 0.30.0
Added the *connect* parameter.
Added the *connect* and *reset* parameters.
"""
return Pool(
dsn,
Expand All @@ -1178,6 +1205,7 @@ def create_pool(dsn=None, *,
connect=connect,
setup=setup,
init=init,
reset=reset,
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
**connect_kwargs,
)
17 changes: 16 additions & 1 deletion tests/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,20 +137,31 @@ async def setup(con):
async def test_pool_07(self):
cons = set()
connect_called = 0
init_called = 0
setup_called = 0
reset_called = 0

async def connect(*args, **kwargs):
nonlocal connect_called
connect_called += 1
return await pg_connection.connect(*args, **kwargs)

async def setup(con):
nonlocal setup_called
if con._con not in cons: # `con` is `PoolConnectionProxy`.
raise RuntimeError('init was not called before setup')
setup_called += 1

async def init(con):
nonlocal init_called
if con in cons:
raise RuntimeError('init was called more than once')
cons.add(con)
init_called += 1

async def reset(con):
nonlocal reset_called
reset_called += 1

async def user(pool):
async with pool.acquire() as con:
Expand All @@ -162,12 +173,16 @@ async def user(pool):
max_size=5,
connect=connect,
init=init,
setup=setup) as pool:
setup=setup,
reset=reset) as pool:
users = asyncio.gather(*[user(pool) for _ in range(10)])
await users

self.assertEqual(len(cons), 5)
self.assertEqual(connect_called, 5)
self.assertEqual(init_called, 5)
self.assertEqual(setup_called, 10)
self.assertEqual(reset_called, 10)

async def bad_connect(*args, **kwargs):
return 1
Expand Down

0 comments on commit f6ec755

Please sign in to comment.