Skip to content

Commit

Permalink
async fixes, remove __del__ and other things (#2870)
Browse files Browse the repository at this point in the history
* Use correct redis url if not default when creating Connection

* Make resource-warning __del__ code safer during shutdown

* Remove __del__ handler, fix pubsub weakref callback handling

* Clarify comment, since there is no __del__ on asyncio.connection.ConnectionPool

* Remove remaining __del__ from async parser.  They are not needed.

* make connect callback methods internal

* similarly make non-async connect callbacks internal, use same system as for async.

* Reformat __del__()
  • Loading branch information
kristjanvalur authored Sep 20, 2023
1 parent c46a28d commit ded9f7c
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 39 deletions.
6 changes: 0 additions & 6 deletions redis/_parsers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,6 @@ def __init__(self, socket_read_size: int):
self._stream: Optional[StreamReader] = None
self._read_size = socket_read_size

def __del__(self):
try:
self.on_disconnect()
except Exception:
pass

async def can_read_destructive(self) -> bool:
raise NotImplementedError()

Expand Down
28 changes: 18 additions & 10 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,6 @@ def __init__(
lib_version: Optional[str] = get_lib_version(),
username: Optional[str] = None,
retry: Optional[Retry] = None,
# deprecated. create a pool and use connection_pool instead
auto_close_connection_pool: Optional[bool] = None,
redis_connect_func=None,
credential_provider: Optional[CredentialProvider] = None,
Expand All @@ -241,7 +240,9 @@ def __init__(
To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
"""
kwargs: Dict[str, Any]

# auto_close_connection_pool only has an effect if connection_pool is
# None. It is assumed that if connection_pool is not None, the user
# wants to manage the connection pool themselves.
if auto_close_connection_pool is not None:
warnings.warn(
DeprecationWarning(
Expand Down Expand Up @@ -531,13 +532,20 @@ async def __aexit__(self, exc_type, exc_value, traceback):

_DEL_MESSAGE = "Unclosed Redis client"

def __del__(self, _warnings: Any = warnings) -> None:
# passing _warnings and _grl as argument default since they may be gone
# by the time __del__ is called at shutdown
def __del__(
self,
_warn: Any = warnings.warn,
_grl: Any = asyncio.get_running_loop,
) -> None:
if hasattr(self, "connection") and (self.connection is not None):
_warnings.warn(
f"Unclosed client session {self!r}", ResourceWarning, source=self
)
context = {"client": self, "message": self._DEL_MESSAGE}
asyncio.get_running_loop().call_exception_handler(context)
_warn(f"Unclosed client session {self!r}", ResourceWarning, source=self)
try:
context = {"client": self, "message": self._DEL_MESSAGE}
_grl().call_exception_handler(context)
except RuntimeError:
pass

async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
"""
Expand Down Expand Up @@ -786,7 +794,7 @@ async def aclose(self):
async with self._lock:
if self.connection:
await self.connection.disconnect()
self.connection.clear_connect_callbacks()
self.connection._deregister_connect_callback(self.on_connect)
await self.connection_pool.release(self.connection)
self.connection = None
self.channels = {}
Expand Down Expand Up @@ -849,7 +857,7 @@ async def connect(self):
)
# register a callback that re-subscribes to any channels we
# were listening to when we were disconnected
self.connection.register_connect_callback(self.on_connect)
self.connection._register_connect_callback(self.on_connect)
else:
await self.connection.connect()
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
Expand Down
27 changes: 17 additions & 10 deletions redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,14 +433,18 @@ def __await__(self) -> Generator[Any, None, "RedisCluster"]:

_DEL_MESSAGE = "Unclosed RedisCluster client"

def __del__(self) -> None:
def __del__(
self,
_warn: Any = warnings.warn,
_grl: Any = asyncio.get_running_loop,
) -> None:
if hasattr(self, "_initialize") and not self._initialize:
warnings.warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)
_warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)
try:
context = {"client": self, "message": self._DEL_MESSAGE}
asyncio.get_running_loop().call_exception_handler(context)
_grl().call_exception_handler(context)
except RuntimeError:
...
pass

async def on_connect(self, connection: Connection) -> None:
await connection.on_connect()
Expand Down Expand Up @@ -969,17 +973,20 @@ def __eq__(self, obj: Any) -> bool:

_DEL_MESSAGE = "Unclosed ClusterNode object"

def __del__(self) -> None:
def __del__(
self,
_warn: Any = warnings.warn,
_grl: Any = asyncio.get_running_loop,
) -> None:
for connection in self._connections:
if connection.is_connected:
warnings.warn(
f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self
)
_warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)

try:
context = {"client": self, "message": self._DEL_MESSAGE}
asyncio.get_running_loop().call_exception_handler(context)
_grl().call_exception_handler(context)
except RuntimeError:
...
pass
break

async def disconnect(self) -> None:
Expand Down
15 changes: 11 additions & 4 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,11 +216,16 @@ def repr_pieces(self):
def is_connected(self):
return self._reader is not None and self._writer is not None

def register_connect_callback(self, callback):
self._connect_callbacks.append(weakref.WeakMethod(callback))
def _register_connect_callback(self, callback):
wm = weakref.WeakMethod(callback)
if wm not in self._connect_callbacks:
self._connect_callbacks.append(wm)

def clear_connect_callbacks(self):
self._connect_callbacks = []
def _deregister_connect_callback(self, callback):
try:
self._connect_callbacks.remove(weakref.WeakMethod(callback))
except ValueError:
pass

def set_parser(self, parser_class: Type[BaseParser]) -> None:
"""
Expand Down Expand Up @@ -263,6 +268,8 @@ async def connect(self):

# run any user callbacks. right now the only internal callback
# is for pubsub channel/pattern resubscription
# first, remove any dead weakrefs
self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
for ref in self._connect_callbacks:
callback = ref()
task = callback(self)
Expand Down
4 changes: 2 additions & 2 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,7 @@ def __del__(self):
def reset(self):
if self.connection:
self.connection.disconnect()
self.connection.clear_connect_callbacks()
self.connection._deregister_connect_callback(self.on_connect)
self.connection_pool.release(self.connection)
self.connection = None
self.health_check_response_counter = 0
Expand Down Expand Up @@ -748,7 +748,7 @@ def execute_command(self, *args):
)
# register a callback that re-subscribes to any channels we
# were listening to when we were disconnected
self.connection.register_connect_callback(self.on_connect)
self.connection._register_connect_callback(self.on_connect)
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
self.connection._parser.set_push_handler(self.push_handler_func)
connection = self.connection
Expand Down
2 changes: 1 addition & 1 deletion redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1775,7 +1775,7 @@ def execute_command(self, *args):
)
# register a callback that re-subscribes to any channels we
# were listening to when we were disconnected
self.connection.register_connect_callback(self.on_connect)
self.connection._register_connect_callback(self.on_connect)
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
self.connection._parser.set_push_handler(self.push_handler_func)
connection = self.connection
Expand Down
15 changes: 11 additions & 4 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,16 @@ def _construct_command_packer(self, packer):
else:
return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode)

def register_connect_callback(self, callback):
self._connect_callbacks.append(weakref.WeakMethod(callback))
def _register_connect_callback(self, callback):
wm = weakref.WeakMethod(callback)
if wm not in self._connect_callbacks:
self._connect_callbacks.append(wm)

def clear_connect_callbacks(self):
self._connect_callbacks = []
def _deregister_connect_callback(self, callback):
try:
self._connect_callbacks.remove(weakref.WeakMethod(callback))
except ValueError:
pass

def set_parser(self, parser_class):
"""
Expand Down Expand Up @@ -279,6 +284,8 @@ def connect(self):

# run any user callbacks. right now the only internal callback
# is for pubsub channel/pattern resubscription
# first, remove any dead weakrefs
self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
for ref in self._connect_callbacks:
callback = ref()
if callback:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_asyncio/test_sentinel_managed_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
pytestmark = pytest.mark.asyncio


async def test_connect_retry_on_timeout_error():
async def test_connect_retry_on_timeout_error(connect_args):
"""Test that the _connect function is retried in case of a timeout"""
connection_pool = mock.AsyncMock()
connection_pool.get_master_address = mock.AsyncMock(
return_value=("localhost", 6379)
return_value=(connect_args["host"], connect_args["port"])
)
conn = SentinelManagedConnection(
retry_on_timeout=True,
Expand Down

0 comments on commit ded9f7c

Please sign in to comment.