Skip to content

Commit

Permalink
Copy over strategies as async
Browse files Browse the repository at this point in the history
  • Loading branch information
hasier committed Jun 3, 2024
1 parent 3c5f788 commit 927965d
Showing 1 changed file with 25 additions and 15 deletions.
40 changes: 25 additions & 15 deletions tenacity/asyncio/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import inspect
import typing

from tenacity import _utils
from tenacity import retry_base
from tenacity import retry_if_exception as _retry_if_exception
from tenacity import retry_if_result as _retry_if_result

if typing.TYPE_CHECKING:
from tenacity import RetryCallState
Expand Down Expand Up @@ -54,35 +51,48 @@ def __ror__( # type: ignore[misc,override]
return retry_any(other, self)


class async_predicate_mixin:
async def __call__(self, retry_state: "RetryCallState") -> bool:
result = super().__call__(retry_state) # type: ignore[misc]
if inspect.isawaitable(result):
result = await result
return typing.cast(bool, result)


RetryBaseT = typing.Union[
async_retry_base, typing.Callable[["RetryCallState"], typing.Awaitable[bool]]
]


class retry_if_exception(async_predicate_mixin, _retry_if_exception, async_retry_base): # type: ignore[misc]
class retry_if_exception(async_retry_base):
"""Retry strategy that retries if an exception verifies a predicate."""

def __init__(
self, predicate: typing.Callable[[BaseException], typing.Awaitable[bool]]
) -> None:
super().__init__(predicate) # type: ignore[arg-type]
self.predicate = predicate

async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override]
if retry_state.outcome is None:
raise RuntimeError("__call__() called before outcome was set")

if retry_state.outcome.failed:
exception = retry_state.outcome.exception()
if exception is None:
raise RuntimeError("outcome failed but the exception is None")
return await self.predicate(exception)
else:
return False

class retry_if_result(async_predicate_mixin, _retry_if_result, async_retry_base): # type: ignore[misc]

class retry_if_result(async_retry_base):
"""Retries if the result verifies a predicate."""

def __init__(
self, predicate: typing.Callable[[typing.Any], typing.Awaitable[bool]]
) -> None:
super().__init__(predicate) # type: ignore[arg-type]
self.predicate = predicate

async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override]
if retry_state.outcome is None:
raise RuntimeError("__call__() called before outcome was set")

if not retry_state.outcome.failed:
return await self.predicate(retry_state.outcome.result())
else:
return False


class retry_any(async_retry_base):
Expand Down

0 comments on commit 927965d

Please sign in to comment.