From 2f8715c252a8ea7ff4494bf58c557650bbd2d6a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 10 Mar 2024 17:00:25 +0000 Subject: [PATCH] Cancellable (#31) * Add the cancelling() ctxtmgr for futures and tasks * Improve cancelling tests * export cancelling to top * make cancel generic * Add CAwaitable (cancellable) for typing * add eagec_ctx() * reverse version macros * document cancelling() and eager_ctx() * fix typing for 38 and earlier * use explicit sys.version checks for mypy * explicit version checks * fix Cancellable to be compatible with old python versions --- README.md | 30 +++++++++++++ src/asynkit/__init__.py | 1 + src/asynkit/compat.py | 18 ++++---- src/asynkit/coroutine.py | 55 +++++++++++++++++------- src/asynkit/experimental/priority.py | 12 ++++-- src/asynkit/tools.py | 42 +++++++++++++++---- tests/test_coro.py | 24 ++++++++++- tests/test_tools.py | 63 ++++++++++++++++++++++++++++ 8 files changed, 210 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index 0011125..c240dba 100644 --- a/README.md +++ b/README.md @@ -98,6 +98,24 @@ Decorating your function makes sense if you __always__ intend To _await_ its result at some later point. Otherwise, just apply it at the point of invocation in each such case. +It may be prudent to ensure that the result of `eager()` does not continue running +if it will never be awaited, such as in the case of an error. You can use the `cancelling()` +context manager for this: + +```python +with cancelling(eager(my_method())) as v: + await some_method_which_may_raise() + await v +``` + +As a convenience, `eager_ctx()` will perform the above: + +```python +with eager_ctx(my_method()) as v: + await some_method_which_may_raise() + await v +``` + ### `coro_eager()`, `func_eager()` `coro_eager()` is the magic coroutine wrapper providing the __eager__ behaviour: @@ -343,6 +361,18 @@ is provided for completeness. This helper function turns a coroutine function into an iterator. It is primarily intended to be used by the [`awaitmethod_iter()`](#awaitmethod_iter) function decorator. +### `cancelling()` + +This context manager automatically calls the `cancel()`method on its target when the scope +exits. This is convenient to make sure that a task is not left running if it never to +be awaited: + +```python +with cancelling(asyncio.Task(foo())) as t: + function_which_may_fail() + return await t +``` + ### Monitors and Generators #### Monitor diff --git a/src/asynkit/__init__.py b/src/asynkit/__init__.py index ee5702d..0a1fb21 100644 --- a/src/asynkit/__init__.py +++ b/src/asynkit/__init__.py @@ -3,3 +3,4 @@ from .loop.eventloop import * from .monitor import * from .scheduling import * +from .tools import cancelling diff --git a/src/asynkit/compat.py b/src/asynkit/compat.py index 07ddcdc..e0485fc 100644 --- a/src/asynkit/compat.py +++ b/src/asynkit/compat.py @@ -8,9 +8,9 @@ """Compatibility routines for earlier asyncio versions""" -# 3.8 or earlier -PYTHON_38 = sys.version_info[:2] <= (3, 8) -PYTHON_39 = sys.version_info[:2] <= (3, 9) +# Pyton version checks +PY_39 = sys.version_info >= (3, 9) +PY_311 = sys.version_info >= (3, 11) T = TypeVar("T") @@ -27,7 +27,10 @@ # create_task() got the name argument in 3.8 -if PYTHON_38: # pragma: no cover +if sys.version_info >= (3, 9): # pragma: no cover + create_task = asyncio.create_task + +else: # pragma: no cover def create_task( coro: Coroutine[Any, Any, T], @@ -36,12 +39,9 @@ def create_task( ) -> _TaskAny: return asyncio.create_task(coro) -else: # pragma: no cover - create_task = asyncio.create_task # type: ignore - # loop.call_soon got the context argument in 3.9.10 and 3.10.2 -if PYTHON_39: # pragma: no cover +if sys.version_info >= (3, 9): # pragma: no cover def call_soon( loop: AbstractEventLoop, @@ -62,7 +62,7 @@ def call_soon( return loop.call_soon(callback, *args) -if not PYTHON_39: # pragma: no cover +if sys.version_info >= (3, 10): # pragma: no cover from asyncio.mixins import _LoopBoundMixin # type: ignore[import] LoopBoundMixin = _LoopBoundMixin diff --git a/src/asynkit/coroutine.py b/src/asynkit/coroutine.py index c4e5f04..0c24762 100644 --- a/src/asynkit/coroutine.py +++ b/src/asynkit/coroutine.py @@ -1,7 +1,9 @@ import asyncio import functools import inspect +import sys import types +from asyncio import Future from contextvars import Context, copy_context from types import FrameType from typing import ( @@ -10,6 +12,7 @@ AsyncIterable, Awaitable, Callable, + ContextManager, Coroutine, Generator, Iterator, @@ -22,9 +25,9 @@ overload, ) -from typing_extensions import ParamSpec +from typing_extensions import ParamSpec, Protocol, TypeAlias -from .tools import create_task +from .tools import Cancellable, cancelling, create_task __all__ = [ "CoroStart", @@ -34,6 +37,7 @@ "coro_eager", "func_eager", "eager", + "eager_ctx", "coro_get_frame", "coro_is_new", "coro_is_suspended", @@ -53,6 +57,18 @@ T_co = TypeVar("T_co", covariant=True) Suspendable = Union[Coroutine, Generator, AsyncGenerator] + +class CAwaitable(Awaitable[T_co], Cancellable, Protocol): + pass + + +# must use explicit sys.version_info check because of mypy +if sys.version_info >= (3, 9): # pragma: no cover + Future_Type: TypeAlias = Future +else: # pragma: no cover + Future_Type: TypeAlias = CAwaitable + + """ Tools and utilities for advanced management of coroutines """ @@ -334,7 +350,7 @@ async def as_coroutine(self) -> T_co: """ return await self - def as_future(self) -> Awaitable[T_co]: + def as_future(self) -> Future_Type[T_co]: """ if `done()` convert the result of the coroutine into a `Future` and return it. Otherwise raise a `RuntimeError` @@ -353,7 +369,7 @@ def as_future(self) -> Awaitable[T_co]: def as_awaitable(self) -> Awaitable[T_co]: """ - If `done()`, return `as_future()`, else `as_coroutine()`. + If `done()`, return `as_future()`, else return self. This is a convenience function for use when the instance is to be passed directly to methods such as `asyncio.gather()`. In such cases, we want to avoid a `done()` instance to cause @@ -381,8 +397,8 @@ async def coro_await( def coro_eager( coro: Coroutine[Any, Any, T], *, - task_factory: Optional[Callable[[Coroutine[Any, Any, T]], Awaitable[T]]] = None, -) -> Awaitable[T]: + task_factory: Optional[Callable[[Coroutine[Any, Any, T]], CAwaitable[T]]] = None, +) -> CAwaitable[T]: """ Make the coroutine "eager": Start the coroutine. If it blocks, create a task to continue @@ -405,7 +421,7 @@ def coro_eager( def func_eager( func: Callable[P, Coroutine[Any, Any, T]], *, - task_factory: Optional[Callable[[Coroutine[Any, Any, T]], Awaitable[T]]] = None, + task_factory: Optional[Callable[[Coroutine[Any, Any, T]], CAwaitable[T]]] = None, ) -> Callable[P, Awaitable[T]]: """ Decorator to automatically apply the `coro_eager` to the @@ -413,7 +429,7 @@ def func_eager( """ @functools.wraps(func) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> Awaitable[T]: + def wrapper(*args: P.args, **kwargs: P.kwargs) -> CAwaitable[T]: return coro_eager(func(*args, **kwargs), task_factory=task_factory) return wrapper @@ -423,8 +439,8 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> Awaitable[T]: def eager( arg: Coroutine[Any, Any, T], *, - task_factory: Optional[Callable[[Coroutine[Any, Any, T]], Awaitable[T]]] = None, -) -> Awaitable[T]: + task_factory: Optional[Callable[[Coroutine[Any, Any, T]], CAwaitable[T]]] = None, +) -> CAwaitable[T]: ... @@ -432,16 +448,16 @@ def eager( def eager( arg: Callable[P, Coroutine[Any, Any, T]], *, - task_factory: Optional[Callable[[Coroutine[Any, Any, T]], Awaitable[T]]] = None, -) -> Callable[P, Awaitable[T]]: + task_factory: Optional[Callable[[Coroutine[Any, Any, T]], CAwaitable[T]]] = None, +) -> Callable[P, CAwaitable[T]]: ... def eager( arg: Union[Coroutine[Any, Any, T], Callable[P, Coroutine[Any, Any, T]]], *, - task_factory: Optional[Callable[[Coroutine[Any, Any, T]], Awaitable[T]]] = None, -) -> Union[Awaitable[T], Callable[P, Awaitable[T]]]: + task_factory: Optional[Callable[[Coroutine[Any, Any, T]], CAwaitable[T]]] = None, +) -> Union[CAwaitable[T], Callable[P, CAwaitable[T]]]: """ Convenience function invoking either `coro_eager` or `func_eager` to either decorate an async function or convert a coroutine returned by @@ -455,6 +471,17 @@ def eager( raise TypeError("need coroutine or function") +def eager_ctx( + coro: Coroutine[Any, Any, T], + *, + task_factory: Optional[Callable[[Coroutine[Any, Any, T]], CAwaitable[T]]] = None, + msg: Optional[str] = None, +) -> ContextManager[CAwaitable[T]]: + """Create an eager task and return a context manager that will cancel it on exit.""" + e = coro_eager(coro, task_factory=task_factory) + return cancelling(e, msg=msg) + + def coro_iter(coro: Coroutine[Any, Any, T]) -> Generator[Any, Any, T]: """ Helper to turn a coroutine into an iterator, which can then diff --git a/src/asynkit/experimental/priority.py b/src/asynkit/experimental/priority.py index 1759fc3..079a8ac 100644 --- a/src/asynkit/experimental/priority.py +++ b/src/asynkit/experimental/priority.py @@ -26,7 +26,11 @@ from typing_extensions import Literal -from asynkit.compat import PYTHON_38, FutureBool, LockHelper, ReferenceTypeTaskAny +from asynkit.compat import ( + FutureBool, + LockHelper, + ReferenceTypeTaskAny, +) from asynkit.loop.default import task_from_handle from asynkit.loop.schedulingloop import AbstractSchedulingLoop from asynkit.loop.types import TaskAny @@ -614,7 +618,7 @@ def reschedule_all(self) -> None: class EventLoopLike(Protocol): # pragma: no cover - if not PYTHON_38: + if sys.version_info >= (3, 9): # pragma: no cover def call_soon( self, @@ -626,7 +630,7 @@ def call_soon( else: - def call_soon( # type: ignore[misc] + def call_soon( self, callback: Callable[..., Any], *args: Any, @@ -680,7 +684,7 @@ def call_pos( This is effectively the same as calling `call_soon()`, `queue_remove()` and `queue_insert_pos()` in turn. """ - if not PYTHON_38: # pragma: no cover + if sys.version_info >= (3, 9): # pragma: no cover handle = self.call_soon(callback, *args, context=context) else: # pragma: no cover handle = self.call_soon(callback, *args) diff --git a/src/asynkit/tools.py b/src/asynkit/tools.py index 838361d..f6d7082 100644 --- a/src/asynkit/tools.py +++ b/src/asynkit/tools.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextlib import heapq import sys from typing import ( @@ -14,13 +15,11 @@ Iterable, Iterator, Optional, - Protocol, Tuple, TypeVar, ) -# 3.8 or earlier -PYTHON_38 = sys.version_info[:2] <= (3, 8) +from typing_extensions import Protocol T = TypeVar("T") @@ -40,7 +39,9 @@ def __lt__(self: T, other: T) -> bool: else: _TaskAny = asyncio.Task -if PYTHON_38: # pragma: no cover +if sys.version_info >= (3, 9): # pragma: no cover + create_task = asyncio.create_task +else: # pragma: no cover def create_task( coro: Coroutine[Any, Any, T], @@ -49,9 +50,6 @@ def create_task( ) -> _TaskAny: return asyncio.create_task(coro) -else: # pragma: no cover - create_task = asyncio.create_task # type: ignore - def deque_pop(d: Deque[T], pos: int = -1) -> T: """ @@ -298,3 +296,33 @@ def __lt__(self, other: PriEntry[P, T]) -> bool: return (self.priority < other.priority) or ( not (other.priority < self.priority) and self.sequence < other.sequence ) + + +class Cancellable(Protocol): + if sys.version_info >= (3, 9): # pragma: no cover + + def cancel(self, msg: Optional[str] = None) -> Any: + ... + + else: # pragma: no cover + + def cancel(self) -> Any: + ... + + def cancelled(self) -> bool: + ... + + +CA = TypeVar("CA", bound=Cancellable) + + +@contextlib.contextmanager +def cancelling(target: CA, msg: Optional[str] = None) -> Generator[CA, None, None]: + """Ensure that the target is cancelled""" + try: + yield target + finally: # pragma: no cover + if sys.version_info >= (3, 9): + target.cancel(msg) + else: + target.cancel() diff --git a/tests/test_coro.py b/tests/test_coro.py index 5b03c83..b210267 100644 --- a/tests/test_coro.py +++ b/tests/test_coro.py @@ -3,7 +3,7 @@ import types from contextlib import asynccontextmanager from contextvars import ContextVar, copy_context -from typing import Any +from typing import Any, List from unittest.mock import Mock import pytest @@ -165,6 +165,28 @@ def test_eager_invalid(self, block): with pytest.raises(TypeError): asynkit.eager(self) + async def test_eager_ctx(self, block): + log = [] + coro, expect = self.get_coro1(block) + with asynkit.eager_ctx(coro(log)) as c: + log.append("a") + await c + + assert log == expect + + async def test_eager_ctx_noawait(self, block: bool) -> None: + log: List[Any] = [] + coro, expect = self.get_coro1(block) + with asynkit.eager_ctx(coro(log)) as c: + log.append("a") + + if block: + with pytest.raises(asyncio.CancelledError): + await c + assert c.cancelled() + else: + assert log == expect + @pytest.mark.parametrize("block", [True, False], ids=["block", "noblock"]) class TestCoroStart: diff --git a/tests/test_tools.py b/tests/test_tools.py index f1b2593..17d2da8 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1,3 +1,4 @@ +import asyncio import random from collections import deque from contextlib import closing @@ -5,6 +6,7 @@ import pytest import asynkit.tools +from asynkit.compat import PY_39, PY_311 from asynkit.tools import PriorityQueue from .conftest import SchedulingEventLoopPolicy @@ -343,3 +345,64 @@ def test_refresh(self): q1.refresh() assert list(q1.ordereditems()) == list(q2.ordereditems()) + + +class TestCancelling: + async def test_future(self): + f = asyncio.Future() + assert not f.cancelled() + with asynkit.tools.cancelling(f, "hello") as c: + assert not c.cancelled() + assert f.cancelled() + with pytest.raises(asyncio.CancelledError) as e: + await f + if PY_39: + assert e.value.args == ("hello",) + + f = asyncio.Future() + with pytest.raises(ValueError): + with asynkit.cancelling(f) as c: + assert not c.cancelled() + raise ValueError + assert f.cancelled() + + # it is ok to exit cancelling block with a finished future + f = asyncio.Future() + with asynkit.tools.cancelling(f) as c: + assert not c.cancelled() + f.set_result(None) + assert f.result() is None + assert not f.cancelled() + + async def test_task(self): + async def coro(): + await asyncio.sleep(0.1) + + # task is cancelled if cancelling block is exited + # without awaiting the task + t = asyncio.create_task(coro()) + assert not t.cancelled() + with asynkit.tools.cancelling(t, "hello") as c: + assert not c.cancelled() + with pytest.raises(asyncio.CancelledError) as e: + await t + if PY_311: + assert e.match("hello") + assert t.cancelled() + + # task is cancelled if an exception is raised + t = asyncio.create_task(coro()) + with pytest.raises(ValueError): + with asynkit.cancelling(t) as c: + assert not c.cancelled() + raise ValueError + with pytest.raises(asyncio.CancelledError): + await t + assert t.cancelled() + + # it is ok to exit cancelling block with a finished task + t = asyncio.create_task(coro()) + with asynkit.tools.cancelling(t) as c: + assert not c.cancelled() + await t + assert not t.cancelled()