Skip to content

Commit

Permalink
Fix RetryInvoker (#2553)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel J. Beutel <daniel@flower.dev>
  • Loading branch information
panh99 and danieljanes authored Nov 1, 2023
1 parent 3718f80 commit e2583bb
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 24 deletions.
33 changes: 17 additions & 16 deletions src/py/flwr/common/retry_invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class RetryInvoker:
Parameters
----------
wait_strategy: Generator[float, None, None]
wait_factory: Callable[[], Generator[float, None, None]]
A generator yielding successive wait times in seconds. If the generator
is finite, the giveup event will be triggered when the generator raises
`StopIteration`.
Expand All @@ -129,11 +129,11 @@ class RetryInvoker:
data class object detailing the invocation.
on_giveup: Optional[Callable[[RetryState], None]] (default: None)
A callable to be executed in the event that `max_tries` or `max_time` is
exceeded, `should_giveup` returns True, or `wait_strategy` generator raises
exceeded, `should_giveup` returns True, or `wait_factory()` generator raises
`StopInteration`. The parameter is a data class object detailing the
invocation.
jitter: Optional[Callable[[float], float]] (default: full_jitter)
A function of the value yielded by `wait_strategy` returning the actual time
A function of the value yielded by `wait_factory()` returning the actual time
to wait. This function helps distribute wait times stochastically to avoid
timing collisions across concurrent clients. Wait times are jittered by
default using the `full_jitter` function. To disable jittering, pass
Expand All @@ -145,20 +145,20 @@ class RetryInvoker:
Examples
--------
Initialize a `RetryInvoker` with exponential backoff and call a function:
Initialize a `RetryInvoker` with exponential backoff and invoke a function:
>>> invoker = RetryInvoker(
>>> exponential(),
>>> grpc.RpcError,
>>> max_tries=3,
>>> max_time=None,
>>> )
... exponential, # Or use `lambda: exponential(3, 2)` to pass arguments
... grpc.RpcError,
... max_tries=3,
... max_time=None,
... )
>>> invoker.invoke(my_func, arg1, arg2, kw1=kwarg1)
"""

def __init__(
self,
wait_strategy: Generator[float, None, None],
wait_factory: Callable[[], Generator[float, None, None]],
recoverable_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]],
max_tries: Optional[int],
max_time: Optional[float],
Expand All @@ -169,7 +169,7 @@ def __init__(
jitter: Optional[Callable[[float], float]] = full_jitter,
should_giveup: Optional[Callable[[Exception], bool]] = None,
) -> None:
self.wait_strategy = wait_strategy
self.wait_factory = wait_factory
self.recoverable_exceptions = recoverable_exceptions
self.max_tries = max_tries
self.max_time = max_time
Expand All @@ -183,8 +183,8 @@ def __init__(
def invoke(
self,
target: Callable[..., Any],
*args: Tuple[Any, ...],
**kwargs: Dict[str, Any],
*args: Any,
**kwargs: Any,
) -> Any:
"""Safely invoke the provided callable with retry mechanisms.
Expand Down Expand Up @@ -212,12 +212,12 @@ def invoke(
------
Exception
If the number of tries exceeds `max_tries`, if the total time
exceeds `max_time`, if `wait_strategy` generator raises `StopInteration`,
exceeds `max_time`, if `wait_factory()` generator raises `StopInteration`,
or if the `should_giveup` returns True for a raised exception.
Notes
-----
The time between retries is determined by the provided `wait_strategy`
The time between retries is determined by the provided `wait_factory()`
generator and can optionally be jittered using the `jitter` function.
The recoverable exceptions that trigger a retry, as well as conditions to
stop retries, are also determined by the class's initialization parameters.
Expand All @@ -230,6 +230,7 @@ def try_call_event_handler(
handler(cast(RetryState, ref_state[0]))

try_cnt = 0
wait_generator = self.wait_factory()
start = time.time()
ref_state: List[Optional[RetryState]] = [None]

Expand Down Expand Up @@ -265,7 +266,7 @@ def giveup_check(_exception: Exception) -> bool:
raise

try:
wait_time = next(self.wait_strategy)
wait_time = next(wait_generator)
if self.jitter is not None:
wait_time = self.jitter(wait_time)
if self.max_time is not None:
Expand Down
22 changes: 14 additions & 8 deletions src/py/flwr/common/retry_invoker_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_successful_invocation() -> None:
backoff_handler = Mock()
giveup_handler = Mock()
invoker = RetryInvoker(
constant(0.1),
lambda: constant(0.1),
ValueError,
max_tries=None,
max_time=None,
Expand All @@ -77,7 +77,7 @@ def test_failure() -> None:
"""Check termination when unexpected exception is raised."""
# Prepare
# `constant([0.1])` generator will raise `StopIteration` after one iteration.
invoker = RetryInvoker(constant(0.1), TypeError, None, None)
invoker = RetryInvoker(lambda: constant(0.1), TypeError, None, None)

# Execute and Assert
with pytest.raises(ValueError):
Expand All @@ -88,7 +88,11 @@ def test_failure_two_exceptions(mock_sleep: MagicMock) -> None:
"""Verify one retry on a specified iterable of exceptions."""
# Prepare
invoker = RetryInvoker(
constant(0.1), (TypeError, ValueError), max_tries=2, max_time=None, jitter=None
lambda: constant(0.1),
(TypeError, ValueError),
max_tries=2,
max_time=None,
jitter=None,
)

# Execute and Assert
Expand All @@ -101,7 +105,7 @@ def test_backoff_on_failure(mock_sleep: MagicMock) -> None:
"""Verify one retry on specified exception."""
# Prepare
# `constant([0.1])` generator will raise `StopIteration` after one iteration.
invoker = RetryInvoker(constant([0.1]), ValueError, None, None, jitter=None)
invoker = RetryInvoker(lambda: constant([0.1]), ValueError, None, None, jitter=None)

# Execute and Assert
with pytest.raises(ValueError):
Expand All @@ -114,7 +118,7 @@ def test_max_tries(mock_sleep: MagicMock) -> None:
# Prepare
# Disable `jitter` to ensure 0.1s wait time.
invoker = RetryInvoker(
constant(0.1), ValueError, max_tries=2, max_time=None, jitter=None
lambda: constant(0.1), ValueError, max_tries=2, max_time=None, jitter=None
)

# Execute and Assert
Expand All @@ -132,7 +136,9 @@ def test_max_time(mock_time: MagicMock, mock_sleep: MagicMock) -> None:
0.0,
3.0,
]
invoker = RetryInvoker(constant(2), ValueError, max_tries=None, max_time=2.5)
invoker = RetryInvoker(
lambda: constant(2), ValueError, max_tries=None, max_time=2.5
)

# Execute and Assert
with pytest.raises(ValueError):
Expand All @@ -148,7 +154,7 @@ def test_event_handlers() -> None:
backoff_handler = Mock()
giveup_handler = Mock()
invoker = RetryInvoker(
constant(0.1),
lambda: constant(0.1),
ValueError,
max_tries=2,
max_time=None,
Expand All @@ -173,7 +179,7 @@ def should_give_up(exc: Exception) -> bool:
return isinstance(exc, ValueError)

invoker = RetryInvoker(
constant(0.1), ValueError, None, None, should_giveup=should_give_up
lambda: constant(0.1), ValueError, None, None, should_giveup=should_give_up
)

# Execute and Assert
Expand Down

0 comments on commit e2583bb

Please sign in to comment.