diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index f77296df6a..5de04c0f94 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -138,12 +138,6 @@ def __init__(self, socket_read_size: int): self._stream: Optional[StreamReader] = None self._read_size = socket_read_size - def __del__(self): - try: - self.on_disconnect() - except Exception: - pass - async def can_read_destructive(self) -> bool: raise NotImplementedError() diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index c340d851b1..381df50ccc 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -227,7 +227,6 @@ def __init__( lib_version: Optional[str] = get_lib_version(), username: Optional[str] = None, retry: Optional[Retry] = None, - # deprecated. create a pool and use connection_pool instead auto_close_connection_pool: Optional[bool] = None, redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, @@ -241,7 +240,9 @@ def __init__( To retry on TimeoutError, `retry_on_timeout` can also be set to `True`. """ kwargs: Dict[str, Any] - + # auto_close_connection_pool only has an effect if connection_pool is + # None. It is assumed that if connection_pool is not None, the user + # wants to manage the connection pool themselves. if auto_close_connection_pool is not None: warnings.warn( DeprecationWarning( @@ -531,13 +532,20 @@ async def __aexit__(self, exc_type, exc_value, traceback): _DEL_MESSAGE = "Unclosed Redis client" - def __del__(self, _warnings: Any = warnings) -> None: + # passing _warnings and _grl as argument default since they may be gone + # by the time __del__ is called at shutdown + def __del__( + self, + _warn: Any = warnings.warn, + _grl: Any = asyncio.get_running_loop, + ) -> None: if hasattr(self, "connection") and (self.connection is not None): - _warnings.warn( - f"Unclosed client session {self!r}", ResourceWarning, source=self - ) - context = {"client": self, "message": self._DEL_MESSAGE} - asyncio.get_running_loop().call_exception_handler(context) + _warn(f"Unclosed client session {self!r}", ResourceWarning, source=self) + try: + context = {"client": self, "message": self._DEL_MESSAGE} + _grl().call_exception_handler(context) + except RuntimeError: + pass async def aclose(self, close_connection_pool: Optional[bool] = None) -> None: """ @@ -786,7 +794,7 @@ async def aclose(self): async with self._lock: if self.connection: await self.connection.disconnect() - self.connection.clear_connect_callbacks() + self.connection._deregister_connect_callback(self.on_connect) await self.connection_pool.release(self.connection) self.connection = None self.channels = {} @@ -849,7 +857,7 @@ async def connect(self): ) # register a callback that re-subscribes to any channels we # were listening to when we were disconnected - self.connection.register_connect_callback(self.on_connect) + self.connection._register_connect_callback(self.on_connect) else: await self.connection.connect() if self.push_handler_func is not None and not HIREDIS_AVAILABLE: diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index f4f031580d..15634de81f 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -433,14 +433,18 @@ def __await__(self) -> Generator[Any, None, "RedisCluster"]: _DEL_MESSAGE = "Unclosed RedisCluster client" - def __del__(self) -> None: + def __del__( + self, + _warn: Any = warnings.warn, + _grl: Any = asyncio.get_running_loop, + ) -> None: if hasattr(self, "_initialize") and not self._initialize: - warnings.warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self) + _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self) try: context = {"client": self, "message": self._DEL_MESSAGE} - asyncio.get_running_loop().call_exception_handler(context) + _grl().call_exception_handler(context) except RuntimeError: - ... + pass async def on_connect(self, connection: Connection) -> None: await connection.on_connect() @@ -969,17 +973,20 @@ def __eq__(self, obj: Any) -> bool: _DEL_MESSAGE = "Unclosed ClusterNode object" - def __del__(self) -> None: + def __del__( + self, + _warn: Any = warnings.warn, + _grl: Any = asyncio.get_running_loop, + ) -> None: for connection in self._connections: if connection.is_connected: - warnings.warn( - f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self - ) + _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self) + try: context = {"client": self, "message": self._DEL_MESSAGE} - asyncio.get_running_loop().call_exception_handler(context) + _grl().call_exception_handler(context) except RuntimeError: - ... + pass break async def disconnect(self) -> None: diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 71d0e92002..f36b4bf79b 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -216,11 +216,16 @@ def repr_pieces(self): def is_connected(self): return self._reader is not None and self._writer is not None - def register_connect_callback(self, callback): - self._connect_callbacks.append(weakref.WeakMethod(callback)) + def _register_connect_callback(self, callback): + wm = weakref.WeakMethod(callback) + if wm not in self._connect_callbacks: + self._connect_callbacks.append(wm) - def clear_connect_callbacks(self): - self._connect_callbacks = [] + def _deregister_connect_callback(self, callback): + try: + self._connect_callbacks.remove(weakref.WeakMethod(callback)) + except ValueError: + pass def set_parser(self, parser_class: Type[BaseParser]) -> None: """ @@ -263,6 +268,8 @@ async def connect(self): # run any user callbacks. right now the only internal callback # is for pubsub channel/pattern resubscription + # first, remove any dead weakrefs + self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()] for ref in self._connect_callbacks: callback = ref() task = callback(self) diff --git a/redis/client.py b/redis/client.py index 1e1ff57605..cf6dbf1eed 100755 --- a/redis/client.py +++ b/redis/client.py @@ -690,7 +690,7 @@ def __del__(self): def reset(self): if self.connection: self.connection.disconnect() - self.connection.clear_connect_callbacks() + self.connection._deregister_connect_callback(self.on_connect) self.connection_pool.release(self.connection) self.connection = None self.health_check_response_counter = 0 @@ -748,7 +748,7 @@ def execute_command(self, *args): ) # register a callback that re-subscribes to any channels we # were listening to when we were disconnected - self.connection.register_connect_callback(self.on_connect) + self.connection._register_connect_callback(self.on_connect) if self.push_handler_func is not None and not HIREDIS_AVAILABLE: self.connection._parser.set_push_handler(self.push_handler_func) connection = self.connection diff --git a/redis/cluster.py b/redis/cluster.py index 2ce9c54f85..ee3e1a865d 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1775,7 +1775,7 @@ def execute_command(self, *args): ) # register a callback that re-subscribes to any channels we # were listening to when we were disconnected - self.connection.register_connect_callback(self.on_connect) + self.connection._register_connect_callback(self.on_connect) if self.push_handler_func is not None and not HIREDIS_AVAILABLE: self.connection._parser.set_push_handler(self.push_handler_func) connection = self.connection diff --git a/redis/connection.py b/redis/connection.py index 45ecd2a370..f5266d7dce 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -237,11 +237,16 @@ def _construct_command_packer(self, packer): else: return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode) - def register_connect_callback(self, callback): - self._connect_callbacks.append(weakref.WeakMethod(callback)) + def _register_connect_callback(self, callback): + wm = weakref.WeakMethod(callback) + if wm not in self._connect_callbacks: + self._connect_callbacks.append(wm) - def clear_connect_callbacks(self): - self._connect_callbacks = [] + def _deregister_connect_callback(self, callback): + try: + self._connect_callbacks.remove(weakref.WeakMethod(callback)) + except ValueError: + pass def set_parser(self, parser_class): """ @@ -279,6 +284,8 @@ def connect(self): # run any user callbacks. right now the only internal callback # is for pubsub channel/pattern resubscription + # first, remove any dead weakrefs + self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()] for ref in self._connect_callbacks: callback = ref() if callback: diff --git a/tests/test_asyncio/test_sentinel_managed_connection.py b/tests/test_asyncio/test_sentinel_managed_connection.py index e784690c77..711b3ee733 100644 --- a/tests/test_asyncio/test_sentinel_managed_connection.py +++ b/tests/test_asyncio/test_sentinel_managed_connection.py @@ -10,11 +10,11 @@ pytestmark = pytest.mark.asyncio -async def test_connect_retry_on_timeout_error(): +async def test_connect_retry_on_timeout_error(connect_args): """Test that the _connect function is retried in case of a timeout""" connection_pool = mock.AsyncMock() connection_pool.get_master_address = mock.AsyncMock( - return_value=("localhost", 6379) + return_value=(connect_args["host"], connect_args["port"]) ) conn = SentinelManagedConnection( retry_on_timeout=True,