Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(_redis.py): allow all supported arguments for redis cluster #5554

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 7 additions & 27 deletions litellm/_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
import os
from typing import List, Optional
from typing import List

import redis # type: ignore
import redis.asyncio as async_redis # type: ignore
Expand Down Expand Up @@ -55,19 +55,6 @@ def _get_redis_url_kwargs(client=None):
return available_args


def _get_redis_cluster_kwargs(client=None):
if client is None:
client = redis.Redis.from_url
arg_spec = inspect.getfullargspec(redis.RedisCluster)

# Only allow primitive arguments
exclude_args = {"self", "connection_pool", "retry", "host", "port", "startup_nodes"}

available_args = [x for x in arg_spec.args if x not in exclude_args]

return available_args


def _get_redis_env_kwarg_mapping():
PREFIX = "REDIS_"

Expand Down Expand Up @@ -143,13 +130,10 @@ def get_redis_client(**env_overrides):
return redis.Redis.from_url(**url_kwargs)

if "startup_nodes" in redis_kwargs:
from redis.cluster import ClusterNode
from redis.cluster import ClusterNode, cleanup_kwargs

args = _get_redis_cluster_kwargs()
cluster_kwargs = {}
for arg in redis_kwargs:
if arg in args:
cluster_kwargs[arg] = redis_kwargs[arg]
# Only allow primitive arguments
cluster_kwargs = cleanup_kwargs(**redis_kwargs)

new_startup_nodes: List[ClusterNode] = []

Expand Down Expand Up @@ -177,14 +161,10 @@ def get_redis_async_client(**env_overrides):
return async_redis.Redis.from_url(**url_kwargs)

if "startup_nodes" in redis_kwargs:
from redis.cluster import ClusterNode

args = _get_redis_cluster_kwargs()
cluster_kwargs = {}
for arg in redis_kwargs:
if arg in args:
cluster_kwargs[arg] = redis_kwargs[arg]
from redis.cluster import ClusterNode, cleanup_kwargs

# Only allow primitive arguments
cluster_kwargs = cleanup_kwargs(**redis_kwargs)
new_startup_nodes: List[ClusterNode] = []

for item in redis_kwargs["startup_nodes"]:
Expand Down
2 changes: 1 addition & 1 deletion litellm/tests/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,7 @@ async def test_redis_cache_cluster_init_unit_test():
assert isinstance(resp.redis_client, RedisCluster)
assert isinstance(resp.init_async_client(), AsyncRedisCluster)

resp = litellm.Cache(type="redis", redis_startup_nodes=startup_nodes)
resp = litellm.Cache(type="redis", redis_startup_nodes=startup_nodes, password="my-cluster-password")

assert isinstance(resp.cache, RedisCache)
assert isinstance(resp.cache.redis_client, RedisCluster)
Expand Down