Skip to content

Commit

Permalink
Type hint improvements (#2952)
Browse files Browse the repository at this point in the history
* Some type hints

* fixed callable[T]

* con

* more connectios

* restoring dev reqs

* Update redis/commands/search/suggestion.py

Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>

* Update redis/commands/core.py

Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>

* Update redis/commands/search/suggestion.py

Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>

* Update redis/commands/search/commands.py

Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>

* Update redis/client.py

Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>

* Update redis/commands/search/suggestion.py

Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>

* Update redis/connection.py

Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>

* Update redis/connection.py

Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>

* Update redis/connection.py

Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>

* Update redis/connection.py

Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>

* Update redis/client.py

Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>

* Update redis/client.py

Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>

* linters

---------

Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>
  • Loading branch information
chayim and dvora-h authored Sep 21, 2023
1 parent 56b254e commit 2ee7c3c
Show file tree
Hide file tree
Showing 32 changed files with 289 additions and 276 deletions.
1 change: 0 additions & 1 deletion redis/_parsers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@


class BaseParser(ABC):

EXCEPTION_CLASSES = {
"ERR": {
"max number of clients reached": ConnectionError,
Expand Down
6 changes: 2 additions & 4 deletions redis/_parsers/resp3.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,8 @@ async def _read_response(
]
res = self.push_handler_func(response)
if not push_request:
return await (
self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
return await self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return res
Expand Down
1 change: 0 additions & 1 deletion redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,7 +1155,6 @@ def __init__(
queue_class: Type[asyncio.Queue] = asyncio.LifoQueue, # deprecated
**connection_kwargs,
):

super().__init__(
connection_class=connection_class,
max_connections=max_connections,
Expand Down
118 changes: 68 additions & 50 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import time
import warnings
from itertools import chain
from typing import Optional, Type
from typing import Any, Callable, Dict, List, Optional, Type, Union

from redis._parsers.encoders import Encoder
from redis._parsers.helpers import (
_RedisCallbacks,
_RedisCallbacksRESP2,
Expand Down Expand Up @@ -49,7 +50,7 @@
class CaseInsensitiveDict(dict):
"Case insensitive dict implementation. Assumes string keys only."

def __init__(self, data):
def __init__(self, data: Dict[str, str]) -> None:
for k, v in data.items():
self[k.upper()] = v

Expand Down Expand Up @@ -93,7 +94,7 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
"""

@classmethod
def from_url(cls, url, **kwargs):
def from_url(cls, url: str, **kwargs) -> None:
"""
Return a Redis client object configured from the given URL
Expand Down Expand Up @@ -202,7 +203,7 @@ def __init__(
redis_connect_func=None,
credential_provider: Optional[CredentialProvider] = None,
protocol: Optional[int] = 2,
):
) -> None:
"""
Initialize a new Redis client.
To specify a retry policy for specific errors, first set
Expand Down Expand Up @@ -309,14 +310,14 @@ def __init__(
else:
self.response_callbacks.update(_RedisCallbacksRESP2)

def __repr__(self):
def __repr__(self) -> str:
return f"{type(self).__name__}<{repr(self.connection_pool)}>"

def get_encoder(self):
def get_encoder(self) -> "Encoder":
"""Get the connection pool's encoder"""
return self.connection_pool.get_encoder()

def get_connection_kwargs(self):
def get_connection_kwargs(self) -> Dict:
"""Get the connection's key-word arguments"""
return self.connection_pool.connection_kwargs

Expand All @@ -327,11 +328,11 @@ def set_retry(self, retry: "Retry") -> None:
self.get_connection_kwargs().update({"retry": retry})
self.connection_pool.set_retry(retry)

def set_response_callback(self, command, callback):
def set_response_callback(self, command: str, callback: Callable) -> None:
"""Set a custom Response Callback"""
self.response_callbacks[command] = callback

def load_external_module(self, funcname, func):
def load_external_module(self, funcname, func) -> None:
"""
This function can be used to add externally defined redis modules,
and their namespaces to the redis client.
Expand All @@ -354,7 +355,7 @@ def load_external_module(self, funcname, func):
"""
setattr(self, funcname, func)

def pipeline(self, transaction=True, shard_hint=None):
def pipeline(self, transaction=True, shard_hint=None) -> "Pipeline":
"""
Return a new pipeline object that can queue multiple commands for
later execution. ``transaction`` indicates whether all commands
Expand All @@ -366,7 +367,9 @@ def pipeline(self, transaction=True, shard_hint=None):
self.connection_pool, self.response_callbacks, transaction, shard_hint
)

def transaction(self, func, *watches, **kwargs):
def transaction(
self, func: Callable[["Pipeline"], None], *watches, **kwargs
) -> None:
"""
Convenience method for executing the callable `func` as a transaction
while watching all keys specified in `watches`. The 'func' callable
Expand All @@ -390,13 +393,13 @@ def transaction(self, func, *watches, **kwargs):

def lock(
self,
name,
timeout=None,
sleep=0.1,
blocking=True,
blocking_timeout=None,
lock_class=None,
thread_local=True,
name: str,
timeout: Optional[float] = None,
sleep: float = 0.1,
blocking: bool = True,
blocking_timeout: Optional[float] = None,
lock_class: Union[None, Any] = None,
thread_local: bool = True,
):
"""
Return a new Lock object using key ``name`` that mimics
Expand Down Expand Up @@ -648,9 +651,9 @@ def __init__(
self,
connection_pool,
shard_hint=None,
ignore_subscribe_messages=False,
encoder=None,
push_handler_func=None,
ignore_subscribe_messages: bool = False,
encoder: Optional["Encoder"] = None,
push_handler_func: Union[None, Callable[[str], None]] = None,
):
self.connection_pool = connection_pool
self.shard_hint = shard_hint
Expand All @@ -672,13 +675,13 @@ def __init__(
_set_info_logger()
self.reset()

def __enter__(self):
def __enter__(self) -> "PubSub":
return self

def __exit__(self, exc_type, exc_value, traceback):
def __exit__(self, exc_type, exc_value, traceback) -> None:
self.reset()

def __del__(self):
def __del__(self) -> None:
try:
# if this object went out of scope prior to shutting down
# subscriptions, close the connection manually before
Expand All @@ -687,7 +690,7 @@ def __del__(self):
except Exception:
pass

def reset(self):
def reset(self) -> None:
if self.connection:
self.connection.disconnect()
self.connection._deregister_connect_callback(self.on_connect)
Expand All @@ -702,10 +705,10 @@ def reset(self):
self.pending_unsubscribe_patterns = set()
self.subscribed_event.clear()

def close(self):
def close(self) -> None:
self.reset()

def on_connect(self, connection):
def on_connect(self, connection) -> None:
"Re-subscribe to any channels and patterns previously subscribed to"
# NOTE: for python3, we can't pass bytestrings as keyword arguments
# so we need to decode channel/pattern names back to unicode strings
Expand All @@ -731,7 +734,7 @@ def on_connect(self, connection):
self.ssubscribe(**shard_channels)

@property
def subscribed(self):
def subscribed(self) -> bool:
"""Indicates if there are subscriptions to any channels or patterns"""
return self.subscribed_event.is_set()

Expand All @@ -757,7 +760,7 @@ def execute_command(self, *args):
self.clean_health_check_responses()
self._execute(connection, connection.send_command, *args, **kwargs)

def clean_health_check_responses(self):
def clean_health_check_responses(self) -> None:
"""
If any health check responses are present, clean them
"""
Expand All @@ -775,7 +778,7 @@ def clean_health_check_responses(self):
)
ttl -= 1

def _disconnect_raise_connect(self, conn, error):
def _disconnect_raise_connect(self, conn, error) -> None:
"""
Close the connection and raise an exception
if retry_on_timeout is not set or the error
Expand Down Expand Up @@ -826,7 +829,7 @@ def try_read():
return None
return response

def is_health_check_response(self, response):
def is_health_check_response(self, response) -> bool:
"""
Check if the response is a health check response.
If there are no subscriptions redis responds to PING command with a
Expand All @@ -837,7 +840,7 @@ def is_health_check_response(self, response):
self.health_check_response_b, # If there wasn't
]

def check_health(self):
def check_health(self) -> None:
conn = self.connection
if conn is None:
raise RuntimeError(
Expand All @@ -849,7 +852,7 @@ def check_health(self):
conn.send_command("PING", self.HEALTH_CHECK_MESSAGE, check_health=False)
self.health_check_response_counter += 1

def _normalize_keys(self, data):
def _normalize_keys(self, data) -> Dict:
"""
normalize channel/pattern names to be either bytes or strings
based on whether responses are automatically decoded. this saves us
Expand Down Expand Up @@ -983,7 +986,9 @@ def listen(self):
if response is not None:
yield response

def get_message(self, ignore_subscribe_messages=False, timeout=0.0):
def get_message(
self, ignore_subscribe_messages: bool = False, timeout: float = 0.0
):
"""
Get the next message if one is available, otherwise None.
Expand Down Expand Up @@ -1012,7 +1017,7 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0.0):

get_sharded_message = get_message

def ping(self, message=None):
def ping(self, message: Union[str, None] = None) -> bool:
"""
Ping the Redis server
"""
Expand Down Expand Up @@ -1093,7 +1098,12 @@ def handle_message(self, response, ignore_subscribe_messages=False):

return message

def run_in_thread(self, sleep_time=0, daemon=False, exception_handler=None):
def run_in_thread(
self,
sleep_time: int = 0,
daemon: bool = False,
exception_handler: Optional[Callable] = None,
) -> "PubSubWorkerThread":
for channel, handler in self.channels.items():
if handler is None:
raise PubSubError(f"Channel: '{channel}' has no handler registered")
Expand All @@ -1114,15 +1124,23 @@ def run_in_thread(self, sleep_time=0, daemon=False, exception_handler=None):


class PubSubWorkerThread(threading.Thread):
def __init__(self, pubsub, sleep_time, daemon=False, exception_handler=None):
def __init__(
self,
pubsub,
sleep_time: float,
daemon: bool = False,
exception_handler: Union[
Callable[[Exception, "PubSub", "PubSubWorkerThread"], None], None
] = None,
):
super().__init__()
self.daemon = daemon
self.pubsub = pubsub
self.sleep_time = sleep_time
self.exception_handler = exception_handler
self._running = threading.Event()

def run(self):
def run(self) -> None:
if self._running.is_set():
return
self._running.set()
Expand All @@ -1137,7 +1155,7 @@ def run(self):
self.exception_handler(e, pubsub, self)
pubsub.close()

def stop(self):
def stop(self) -> None:
# trip the flag so the run loop exits. the run loop will
# close the pubsub connection, which disconnects the socket
# and returns the connection to the pool.
Expand Down Expand Up @@ -1175,7 +1193,7 @@ def __init__(self, connection_pool, response_callbacks, transaction, shard_hint)
self.watching = False
self.reset()

def __enter__(self):
def __enter__(self) -> "Pipeline":
return self

def __exit__(self, exc_type, exc_value, traceback):
Expand All @@ -1187,14 +1205,14 @@ def __del__(self):
except Exception:
pass

def __len__(self):
def __len__(self) -> int:
return len(self.command_stack)

def __bool__(self):
def __bool__(self) -> bool:
"""Pipeline instances should always evaluate to True"""
return True

def reset(self):
def reset(self) -> None:
self.command_stack = []
self.scripts = set()
# make sure to reset the connection state in the event that we were
Expand All @@ -1217,11 +1235,11 @@ def reset(self):
self.connection_pool.release(self.connection)
self.connection = None

def close(self):
def close(self) -> None:
"""Close the pipeline"""
self.reset()

def multi(self):
def multi(self) -> None:
"""
Start a transactional block of the pipeline after WATCH commands
are issued. End the transactional block with `execute`.
Expand All @@ -1239,7 +1257,7 @@ def execute_command(self, *args, **kwargs):
return self.immediate_execute_command(*args, **kwargs)
return self.pipeline_execute_command(*args, **kwargs)

def _disconnect_reset_raise(self, conn, error):
def _disconnect_reset_raise(self, conn, error) -> None:
"""
Close the connection, reset watching state and
raise an exception if we were watching,
Expand Down Expand Up @@ -1282,7 +1300,7 @@ def immediate_execute_command(self, *args, **options):
lambda error: self._disconnect_reset_raise(conn, error),
)

def pipeline_execute_command(self, *args, **options):
def pipeline_execute_command(self, *args, **options) -> "Pipeline":
"""
Stage a command to be executed when execute() is next called
Expand All @@ -1297,7 +1315,7 @@ def pipeline_execute_command(self, *args, **options):
self.command_stack.append((args, options))
return self

def _execute_transaction(self, connection, commands, raise_on_error):
def _execute_transaction(self, connection, commands, raise_on_error) -> List:
cmds = chain([(("MULTI",), {})], commands, [(("EXEC",), {})])
all_cmds = connection.pack_commands(
[args for args, options in cmds if EMPTY_RESPONSE not in options]
Expand Down Expand Up @@ -1415,7 +1433,7 @@ def load_scripts(self):
if not exist:
s.sha = immediate("SCRIPT LOAD", s.script)

def _disconnect_raise_reset(self, conn, error):
def _disconnect_raise_reset(self, conn: Redis, error: Exception) -> None:
"""
Close the connection, raise an exception if we were watching,
and raise an exception if TimeoutError is not part of retry_on_error,
Expand Down Expand Up @@ -1477,6 +1495,6 @@ def watch(self, *names):
raise RedisError("Cannot issue a WATCH after a MULTI")
return self.execute_command("WATCH", *names)

def unwatch(self):
def unwatch(self) -> bool:
"""Unwatches all previously specified keys"""
return self.watching and self.execute_command("UNWATCH") or True
Loading

0 comments on commit 2ee7c3c

Please sign in to comment.