Skip to content

Commit

Permalink
add support for custom connection pool class in NodesManager (#2547)
Browse files Browse the repository at this point in the history
Co-authored-by: zach.lee <zach.lee@sendbird.com>
  • Loading branch information
zakaf and zach-iee authored Jan 11, 2023
1 parent bae6385 commit 7dd73a3
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
5 changes: 4 additions & 1 deletion redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def parse_cluster_shards(resp, **options):
"charset",
"connection_class",
"connection_pool",
"connection_pool_class",
"client_name",
"credential_provider",
"db",
Expand Down Expand Up @@ -1267,6 +1268,7 @@ def __init__(
require_full_coverage=False,
lock=None,
dynamic_startup_nodes=True,
connection_pool_class=ConnectionPool,
**kwargs,
):
self.nodes_cache = {}
Expand All @@ -1277,6 +1279,7 @@ def __init__(
self.from_url = from_url
self._require_full_coverage = require_full_coverage
self._dynamic_startup_nodes = dynamic_startup_nodes
self.connection_pool_class = connection_pool_class
self._moved_exception = None
self.connection_kwargs = kwargs
self.read_load_balancer = LoadBalancer()
Expand Down Expand Up @@ -1420,7 +1423,7 @@ def create_redis_node(self, host, port, **kwargs):
# Create a redis node with a costumed connection pool
kwargs.update({"host": host})
kwargs.update({"port": port})
r = Redis(connection_pool=ConnectionPool(**kwargs))
r = Redis(connection_pool=self.connection_pool_class(**kwargs))
else:
r = Redis(host=host, port=port, **kwargs)
return r
Expand Down
17 changes: 16 additions & 1 deletion tests/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
get_node_name,
)
from redis.commands import CommandsParser
from redis.connection import Connection
from redis.connection import BlockingConnectionPool, Connection, ConnectionPool
from redis.crc import key_slot
from redis.exceptions import (
AskError,
Expand Down Expand Up @@ -2496,6 +2496,21 @@ def test_init_slots_dynamic_startup_nodes(self, dynamic_startup_nodes):
else:
assert startup_nodes == ["my@DNS.com:7000"]

@pytest.mark.parametrize(
"connection_pool_class", [ConnectionPool, BlockingConnectionPool]
)
def test_connection_pool_class(self, connection_pool_class):
rc = get_mocked_redis_client(
url="redis://my@DNS.com:7000",
cluster_slots=default_cluster_slots,
connection_pool_class=connection_pool_class,
)

for node in rc.nodes_manager.nodes_cache.values():
assert isinstance(
node.redis_connection.connection_pool, connection_pool_class
)


@pytest.mark.onlycluster
class TestClusterPubSubObject:
Expand Down

0 comments on commit 7dd73a3

Please sign in to comment.