Skip to content

Commit

Permalink
Cancellable (#31)
Browse files Browse the repository at this point in the history
* Add the cancelling() ctxtmgr for futures and tasks

* Improve cancelling tests

* export cancelling to top

* make cancel generic

* Add CAwaitable (cancellable) for typing

* add eagec_ctx()

* reverse version macros

* document cancelling() and eager_ctx()

* fix typing for 38 and earlier

* use explicit sys.version checks for mypy

* explicit version checks

* fix Cancellable to be compatible with old python versions
  • Loading branch information
kristjanvalur authored Mar 10, 2024
1 parent 0062ed0 commit 2f8715c
Show file tree
Hide file tree
Showing 8 changed files with 210 additions and 35 deletions.
30 changes: 30 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,24 @@ Decorating your function makes sense if you __always__ intend
To _await_ its result at some later point. Otherwise, just apply it at the point
of invocation in each such case.

It may be prudent to ensure that the result of `eager()` does not continue running
if it will never be awaited, such as in the case of an error. You can use the `cancelling()`
context manager for this:

```python
with cancelling(eager(my_method())) as v:
await some_method_which_may_raise()
await v
```

As a convenience, `eager_ctx()` will perform the above:

```python
with eager_ctx(my_method()) as v:
await some_method_which_may_raise()
await v
```

### `coro_eager()`, `func_eager()`

`coro_eager()` is the magic coroutine wrapper providing the __eager__ behaviour:
Expand Down Expand Up @@ -343,6 +361,18 @@ is provided for completeness.
This helper function turns a coroutine function into an iterator. It is primarily
intended to be used by the [`awaitmethod_iter()`](#awaitmethod_iter) function decorator.

### `cancelling()`

This context manager automatically calls the `cancel()`method on its target when the scope
exits. This is convenient to make sure that a task is not left running if it never to
be awaited:

```python
with cancelling(asyncio.Task(foo())) as t:
function_which_may_fail()
return await t
```

### Monitors and Generators

#### Monitor
Expand Down
1 change: 1 addition & 0 deletions src/asynkit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .loop.eventloop import *
from .monitor import *
from .scheduling import *
from .tools import cancelling
18 changes: 9 additions & 9 deletions src/asynkit/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

"""Compatibility routines for earlier asyncio versions"""

# 3.8 or earlier
PYTHON_38 = sys.version_info[:2] <= (3, 8)
PYTHON_39 = sys.version_info[:2] <= (3, 9)
# Pyton version checks
PY_39 = sys.version_info >= (3, 9)
PY_311 = sys.version_info >= (3, 11)

T = TypeVar("T")

Expand All @@ -27,7 +27,10 @@

# create_task() got the name argument in 3.8

if PYTHON_38: # pragma: no cover
if sys.version_info >= (3, 9): # pragma: no cover
create_task = asyncio.create_task

else: # pragma: no cover

def create_task(
coro: Coroutine[Any, Any, T],
Expand All @@ -36,12 +39,9 @@ def create_task(
) -> _TaskAny:
return asyncio.create_task(coro)

else: # pragma: no cover
create_task = asyncio.create_task # type: ignore


# loop.call_soon got the context argument in 3.9.10 and 3.10.2
if PYTHON_39: # pragma: no cover
if sys.version_info >= (3, 9): # pragma: no cover

def call_soon(
loop: AbstractEventLoop,
Expand All @@ -62,7 +62,7 @@ def call_soon(
return loop.call_soon(callback, *args)


if not PYTHON_39: # pragma: no cover
if sys.version_info >= (3, 10): # pragma: no cover
from asyncio.mixins import _LoopBoundMixin # type: ignore[import]

LoopBoundMixin = _LoopBoundMixin
Expand Down
55 changes: 41 additions & 14 deletions src/asynkit/coroutine.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asyncio
import functools
import inspect
import sys
import types
from asyncio import Future
from contextvars import Context, copy_context
from types import FrameType
from typing import (
Expand All @@ -10,6 +12,7 @@
AsyncIterable,
Awaitable,
Callable,
ContextManager,
Coroutine,
Generator,
Iterator,
Expand All @@ -22,9 +25,9 @@
overload,
)

from typing_extensions import ParamSpec
from typing_extensions import ParamSpec, Protocol, TypeAlias

from .tools import create_task
from .tools import Cancellable, cancelling, create_task

__all__ = [
"CoroStart",
Expand All @@ -34,6 +37,7 @@
"coro_eager",
"func_eager",
"eager",
"eager_ctx",
"coro_get_frame",
"coro_is_new",
"coro_is_suspended",
Expand All @@ -53,6 +57,18 @@
T_co = TypeVar("T_co", covariant=True)
Suspendable = Union[Coroutine, Generator, AsyncGenerator]


class CAwaitable(Awaitable[T_co], Cancellable, Protocol):
pass


# must use explicit sys.version_info check because of mypy
if sys.version_info >= (3, 9): # pragma: no cover
Future_Type: TypeAlias = Future
else: # pragma: no cover
Future_Type: TypeAlias = CAwaitable


"""
Tools and utilities for advanced management of coroutines
"""
Expand Down Expand Up @@ -334,7 +350,7 @@ async def as_coroutine(self) -> T_co:
"""
return await self

def as_future(self) -> Awaitable[T_co]:
def as_future(self) -> Future_Type[T_co]:
"""
if `done()` convert the result of the coroutine into a `Future`
and return it. Otherwise raise a `RuntimeError`
Expand All @@ -353,7 +369,7 @@ def as_future(self) -> Awaitable[T_co]:

def as_awaitable(self) -> Awaitable[T_co]:
"""
If `done()`, return `as_future()`, else `as_coroutine()`.
If `done()`, return `as_future()`, else return self.
This is a convenience function for use when the instance
is to be passed directly to methods such as `asyncio.gather()`.
In such cases, we want to avoid a `done()` instance to cause
Expand Down Expand Up @@ -381,8 +397,8 @@ async def coro_await(
def coro_eager(
coro: Coroutine[Any, Any, T],
*,
task_factory: Optional[Callable[[Coroutine[Any, Any, T]], Awaitable[T]]] = None,
) -> Awaitable[T]:
task_factory: Optional[Callable[[Coroutine[Any, Any, T]], CAwaitable[T]]] = None,
) -> CAwaitable[T]:
"""
Make the coroutine "eager":
Start the coroutine. If it blocks, create a task to continue
Expand All @@ -405,15 +421,15 @@ def coro_eager(
def func_eager(
func: Callable[P, Coroutine[Any, Any, T]],
*,
task_factory: Optional[Callable[[Coroutine[Any, Any, T]], Awaitable[T]]] = None,
task_factory: Optional[Callable[[Coroutine[Any, Any, T]], CAwaitable[T]]] = None,
) -> Callable[P, Awaitable[T]]:
"""
Decorator to automatically apply the `coro_eager` to the
coroutine generated by invoking the given async function
"""

@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> Awaitable[T]:
def wrapper(*args: P.args, **kwargs: P.kwargs) -> CAwaitable[T]:
return coro_eager(func(*args, **kwargs), task_factory=task_factory)

return wrapper
Expand All @@ -423,25 +439,25 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> Awaitable[T]:
def eager(
arg: Coroutine[Any, Any, T],
*,
task_factory: Optional[Callable[[Coroutine[Any, Any, T]], Awaitable[T]]] = None,
) -> Awaitable[T]:
task_factory: Optional[Callable[[Coroutine[Any, Any, T]], CAwaitable[T]]] = None,
) -> CAwaitable[T]:
...


@overload
def eager(
arg: Callable[P, Coroutine[Any, Any, T]],
*,
task_factory: Optional[Callable[[Coroutine[Any, Any, T]], Awaitable[T]]] = None,
) -> Callable[P, Awaitable[T]]:
task_factory: Optional[Callable[[Coroutine[Any, Any, T]], CAwaitable[T]]] = None,
) -> Callable[P, CAwaitable[T]]:
...


def eager(
arg: Union[Coroutine[Any, Any, T], Callable[P, Coroutine[Any, Any, T]]],
*,
task_factory: Optional[Callable[[Coroutine[Any, Any, T]], Awaitable[T]]] = None,
) -> Union[Awaitable[T], Callable[P, Awaitable[T]]]:
task_factory: Optional[Callable[[Coroutine[Any, Any, T]], CAwaitable[T]]] = None,
) -> Union[CAwaitable[T], Callable[P, CAwaitable[T]]]:
"""
Convenience function invoking either `coro_eager` or `func_eager`
to either decorate an async function or convert a coroutine returned by
Expand All @@ -455,6 +471,17 @@ def eager(
raise TypeError("need coroutine or function")


def eager_ctx(
coro: Coroutine[Any, Any, T],
*,
task_factory: Optional[Callable[[Coroutine[Any, Any, T]], CAwaitable[T]]] = None,
msg: Optional[str] = None,
) -> ContextManager[CAwaitable[T]]:
"""Create an eager task and return a context manager that will cancel it on exit."""
e = coro_eager(coro, task_factory=task_factory)
return cancelling(e, msg=msg)


def coro_iter(coro: Coroutine[Any, Any, T]) -> Generator[Any, Any, T]:
"""
Helper to turn a coroutine into an iterator, which can then
Expand Down
12 changes: 8 additions & 4 deletions src/asynkit/experimental/priority.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@

from typing_extensions import Literal

from asynkit.compat import PYTHON_38, FutureBool, LockHelper, ReferenceTypeTaskAny
from asynkit.compat import (
FutureBool,
LockHelper,
ReferenceTypeTaskAny,
)
from asynkit.loop.default import task_from_handle
from asynkit.loop.schedulingloop import AbstractSchedulingLoop
from asynkit.loop.types import TaskAny
Expand Down Expand Up @@ -614,7 +618,7 @@ def reschedule_all(self) -> None:


class EventLoopLike(Protocol): # pragma: no cover
if not PYTHON_38:
if sys.version_info >= (3, 9): # pragma: no cover

def call_soon(
self,
Expand All @@ -626,7 +630,7 @@ def call_soon(

else:

def call_soon( # type: ignore[misc]
def call_soon(
self,
callback: Callable[..., Any],
*args: Any,
Expand Down Expand Up @@ -680,7 +684,7 @@ def call_pos(
This is effectively the same as calling
`call_soon()`, `queue_remove()` and `queue_insert_pos()` in turn.
"""
if not PYTHON_38: # pragma: no cover
if sys.version_info >= (3, 9): # pragma: no cover
handle = self.call_soon(callback, *args, context=context)
else: # pragma: no cover
handle = self.call_soon(callback, *args)
Expand Down
42 changes: 35 additions & 7 deletions src/asynkit/tools.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import contextlib
import heapq
import sys
from typing import (
Expand All @@ -14,13 +15,11 @@
Iterable,
Iterator,
Optional,
Protocol,
Tuple,
TypeVar,
)

# 3.8 or earlier
PYTHON_38 = sys.version_info[:2] <= (3, 8)
from typing_extensions import Protocol

T = TypeVar("T")

Expand All @@ -40,7 +39,9 @@ def __lt__(self: T, other: T) -> bool:
else:
_TaskAny = asyncio.Task

if PYTHON_38: # pragma: no cover
if sys.version_info >= (3, 9): # pragma: no cover
create_task = asyncio.create_task
else: # pragma: no cover

def create_task(
coro: Coroutine[Any, Any, T],
Expand All @@ -49,9 +50,6 @@ def create_task(
) -> _TaskAny:
return asyncio.create_task(coro)

else: # pragma: no cover
create_task = asyncio.create_task # type: ignore


def deque_pop(d: Deque[T], pos: int = -1) -> T:
"""
Expand Down Expand Up @@ -298,3 +296,33 @@ def __lt__(self, other: PriEntry[P, T]) -> bool:
return (self.priority < other.priority) or (
not (other.priority < self.priority) and self.sequence < other.sequence
)


class Cancellable(Protocol):
if sys.version_info >= (3, 9): # pragma: no cover

def cancel(self, msg: Optional[str] = None) -> Any:
...

else: # pragma: no cover

def cancel(self) -> Any:
...

def cancelled(self) -> bool:
...


CA = TypeVar("CA", bound=Cancellable)


@contextlib.contextmanager
def cancelling(target: CA, msg: Optional[str] = None) -> Generator[CA, None, None]:
"""Ensure that the target is cancelled"""
try:
yield target
finally: # pragma: no cover
if sys.version_info >= (3, 9):
target.cancel(msg)
else:
target.cancel()
Loading

0 comments on commit 2f8715c

Please sign in to comment.