diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py index 501e234c3c..5ed924096c 100644 --- a/redis/asyncio/sentinel.py +++ b/redis/asyncio/sentinel.py @@ -335,11 +335,15 @@ def master_for( kwargs["is_master"] = True connection_kwargs = dict(self.connection_kwargs) connection_kwargs.update(kwargs) - return redis_class( - connection_pool=connection_pool_class( - service_name, self, **connection_kwargs - ) + + connection_pool = connection_pool_class(service_name, self, **connection_kwargs) + # The Redis object "owns" the pool + auto_close_connection_pool = True + client = redis_class( + connection_pool=connection_pool, ) + client.auto_close_connection_pool = auto_close_connection_pool + return client def slave_for( self, @@ -368,8 +372,12 @@ def slave_for( kwargs["is_master"] = False connection_kwargs = dict(self.connection_kwargs) connection_kwargs.update(kwargs) - return redis_class( - connection_pool=connection_pool_class( - service_name, self, **connection_kwargs - ) + + connection_pool = connection_pool_class(service_name, self, **connection_kwargs) + # The Redis object "owns" the pool + auto_close_connection_pool = True + client = redis_class( + connection_pool=connection_pool, ) + client.auto_close_connection_pool = auto_close_connection_pool + return client diff --git a/tests/test_asyncio/test_sentinel.py b/tests/test_asyncio/test_sentinel.py index 2091f2cb87..a2d52f17b7 100644 --- a/tests/test_asyncio/test_sentinel.py +++ b/tests/test_asyncio/test_sentinel.py @@ -1,4 +1,5 @@ import socket +from unittest import mock import pytest import pytest_asyncio @@ -239,3 +240,28 @@ async def test_flushconfig(cluster, sentinel): async def test_reset(cluster, sentinel): cluster.master["is_odown"] = True assert await sentinel.sentinel_reset("mymaster") + + +@pytest.mark.onlynoncluster +@pytest.mark.parametrize("method_name", ["master_for", "slave_for"]) +async def test_auto_close_pool(cluster, sentinel, method_name): + """ + Check that the connection pool created by the sentinel client is + automatically closed + """ + + method = getattr(sentinel, method_name) + client = method("mymaster", db=9) + pool = client.connection_pool + assert client.auto_close_connection_pool is True + calls = 0 + + async def mock_disconnect(): + nonlocal calls + calls += 1 + + with mock.patch.object(pool, "disconnect", mock_disconnect): + await client.close() + + assert calls == 1 + await pool.disconnect()