From af601ac29615f1ce61872a3c76c373f2c6e10a0b Mon Sep 17 00:00:00 2001 From: jakkdl Date: Mon, 3 Jul 2023 16:53:58 +0200 Subject: [PATCH 01/21] lots of typing improvements --- trio/_abc.py | 28 ++++---- trio/_core/_mock_clock.py | 26 +++---- trio/_core/_parking_lot.py | 36 ++++++---- trio/_core/_run.py | 132 +++++++++++++++++++--------------- trio/_highlevel_generic.py | 23 +++--- trio/_highlevel_socket.py | 21 +++--- trio/_socket.py | 43 ++++++----- trio/_sync.py | 132 ++++++++++++++++++---------------- trio/_tests/verify_types.json | 126 +++++--------------------------- 9 files changed, 266 insertions(+), 301 deletions(-) diff --git a/trio/_abc.py b/trio/_abc.py index c085c82b89..403e07a1de 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import ABCMeta, abstractmethod from typing import Generic, TypeVar import trio @@ -11,7 +13,7 @@ class Clock(metaclass=ABCMeta): __slots__ = () @abstractmethod - def start_clock(self): + def start_clock(self) -> None: """Do any setup this clock might need. Called at the beginning of the run. @@ -19,7 +21,7 @@ def start_clock(self): """ @abstractmethod - def current_time(self): + def current_time(self) -> float: """Return the current time, according to this clock. This is used to implement functions like :func:`trio.current_time` and @@ -31,7 +33,7 @@ def current_time(self): """ @abstractmethod - def deadline_to_sleep_time(self, deadline): + def deadline_to_sleep_time(self, deadline: float) -> float: """Compute the real time until the given deadline. This is called before we enter a system-specific wait function like @@ -224,7 +226,7 @@ class AsyncResource(metaclass=ABCMeta): __slots__ = () @abstractmethod - async def aclose(self): + async def aclose(self) -> None: """Close this resource, possibly blocking. IMPORTANT: This method may block in order to perform a "graceful" @@ -252,10 +254,10 @@ async def aclose(self): """ - async def __aenter__(self): + async def __aenter__(self) -> AsyncResource: return self - async def __aexit__(self, *args): + async def __aexit__(self, *args: object) -> None: await self.aclose() @@ -278,7 +280,7 @@ class SendStream(AsyncResource): __slots__ = () @abstractmethod - async def send_all(self, data): + async def send_all(self, data: bytes | bytearray | memoryview) -> None: """Sends the given data through the stream, blocking if necessary. Args: @@ -304,7 +306,7 @@ async def send_all(self, data): """ @abstractmethod - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: """Block until it's possible that :meth:`send_all` might not block. This method may return early: it's possible that after it returns, @@ -384,7 +386,7 @@ class ReceiveStream(AsyncResource): __slots__ = () @abstractmethod - async def receive_some(self, max_bytes=None): + async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray: """Wait until there is data available on this stream, and then return some of it. @@ -412,10 +414,10 @@ async def receive_some(self, max_bytes=None): """ - def __aiter__(self): + def __aiter__(self) -> ReceiveStream: return self - async def __anext__(self): + async def __anext__(self) -> bytes | bytearray: data = await self.receive_some() if not data: raise StopAsyncIteration @@ -445,7 +447,7 @@ class HalfCloseableStream(Stream): __slots__ = () @abstractmethod - async def send_eof(self): + async def send_eof(self) -> None: """Send an end-of-file indication on this stream, if possible. The difference between :meth:`send_eof` and @@ -631,7 +633,7 @@ async def receive(self) -> ReceiveType: """ - def __aiter__(self): + def __aiter__(self) -> ReceiveChannel[ReceiveType]: return self async def __anext__(self) -> ReceiveType: diff --git a/trio/_core/_mock_clock.py b/trio/_core/_mock_clock.py index 0e95e4e5c5..a9db49584f 100644 --- a/trio/_core/_mock_clock.py +++ b/trio/_core/_mock_clock.py @@ -62,7 +62,7 @@ class MockClock(Clock, metaclass=Final): """ - def __init__(self, rate=0.0, autojump_threshold=inf): + def __init__(self, rate: float = 0.0, autojump_threshold: float = inf): # when the real clock said 'real_base', the virtual time was # 'virtual_base', and since then it's advanced at 'rate' virtual # seconds per real second. @@ -77,17 +77,17 @@ def __init__(self, rate=0.0, autojump_threshold=inf): self.rate = rate self.autojump_threshold = autojump_threshold - def __repr__(self): + def __repr__(self) -> str: return "".format( self.current_time(), self._rate, id(self) ) @property - def rate(self): + def rate(self) -> float: return self._rate @rate.setter - def rate(self, new_rate): + def rate(self, new_rate: float) -> None: if new_rate < 0: raise ValueError("rate must be >= 0") else: @@ -98,11 +98,11 @@ def rate(self, new_rate): self._rate = float(new_rate) @property - def autojump_threshold(self): + def autojump_threshold(self) -> float: return self._autojump_threshold @autojump_threshold.setter - def autojump_threshold(self, new_autojump_threshold): + def autojump_threshold(self, new_autojump_threshold: float) -> None: self._autojump_threshold = float(new_autojump_threshold) self._try_resync_autojump_threshold() @@ -112,7 +112,7 @@ def autojump_threshold(self, new_autojump_threshold): # API. Discussion: # # https://github.com/python-trio/trio/issues/1587 - def _try_resync_autojump_threshold(self): + def _try_resync_autojump_threshold(self) -> None: try: runner = GLOBAL_RUN_CONTEXT.runner if runner.is_guest: @@ -124,24 +124,24 @@ def _try_resync_autojump_threshold(self): # Invoked by the run loop when runner.clock_autojump_threshold is # exceeded. - def _autojump(self): + def _autojump(self) -> None: statistics = _core.current_statistics() jump = statistics.seconds_to_next_deadline if 0 < jump < inf: self.jump(jump) - def _real_to_virtual(self, real): + def _real_to_virtual(self, real: float) -> float: real_offset = real - self._real_base virtual_offset = self._rate * real_offset return self._virtual_base + virtual_offset - def start_clock(self): + def start_clock(self) -> None: self._try_resync_autojump_threshold() - def current_time(self): + def current_time(self) -> float: return self._real_to_virtual(self._real_clock()) - def deadline_to_sleep_time(self, deadline): + def deadline_to_sleep_time(self, deadline: float) -> float: virtual_timeout = deadline - self.current_time() if virtual_timeout <= 0: return 0 @@ -150,7 +150,7 @@ def deadline_to_sleep_time(self, deadline): else: return 999999999 - def jump(self, seconds): + def jump(self, seconds) -> None: """Manually advance the clock by the given number of seconds. Args: diff --git a/trio/_core/_parking_lot.py b/trio/_core/_parking_lot.py index f38123540f..42fc60614c 100644 --- a/trio/_core/_parking_lot.py +++ b/trio/_core/_parking_lot.py @@ -70,16 +70,24 @@ # # See: https://github.com/python-trio/trio/issues/53 +from __future__ import annotations + import attr from collections import OrderedDict +from typing import TYPE_CHECKING, cast +from collections.abc import Iterator +import math from .. import _core from .._util import Final +if TYPE_CHECKING: + from ._run import Task + @attr.s(frozen=True, slots=True) class _ParkingLotStatistics: - tasks_waiting = attr.ib() + tasks_waiting: int = attr.ib() @attr.s(eq=False, hash=False, slots=True) @@ -98,13 +106,13 @@ class ParkingLot(metaclass=Final): # {task: None}, we just want a deque where we can quickly delete random # items - _parked = attr.ib(factory=OrderedDict, init=False) + _parked: OrderedDict[Task, None] = attr.ib(factory=OrderedDict, init=False) - def __len__(self): + def __len__(self) -> int: """Returns the number of parked tasks.""" return len(self._parked) - def __bool__(self): + def __bool__(self) -> bool: """True if there are parked tasks, False otherwise.""" return bool(self._parked) @@ -113,7 +121,7 @@ def __bool__(self): # line (for false wakeups), then we could have it return a ticket that # abstracts the "place in line" concept. @_core.enable_ki_protection - async def park(self): + async def park(self) -> None: """Park the current task until woken by a call to :meth:`unpark` or :meth:`unpark_all`. @@ -128,13 +136,15 @@ def abort_fn(_): await _core.wait_task_rescheduled(abort_fn) - def _pop_several(self, count): - for _ in range(min(count, len(self._parked))): + def _pop_several(self, count: int | float) -> Iterator[Task]: + if isinstance(count, float): + assert math.isinf(count) + for _ in range(cast(int, min(count, len(self._parked)))): task, _ = self._parked.popitem(last=False) yield task @_core.enable_ki_protection - def unpark(self, *, count=1): + def unpark(self, *, count: int | float) -> list[Task]: """Unpark one or more tasks. This wakes up ``count`` tasks that are blocked in :meth:`park`. If @@ -142,7 +152,7 @@ def unpark(self, *, count=1): are available and then returns successfully. Args: - count (int): the number of tasks to unpark. + count (int | math.inf): the number of tasks to unpark. """ tasks = list(self._pop_several(count)) @@ -150,12 +160,12 @@ def unpark(self, *, count=1): _core.reschedule(task) return tasks - def unpark_all(self): + def unpark_all(self) -> list[Task]: """Unpark all parked tasks.""" return self.unpark(count=len(self)) @_core.enable_ki_protection - def repark(self, new_lot, *, count=1): + def repark(self, new_lot: ParkingLot, *, count: int = 1) -> None: """Move parked tasks from one :class:`ParkingLot` object to another. This dequeues ``count`` tasks from one lot, and requeues them on @@ -194,7 +204,7 @@ async def main(): new_lot._parked[task] = None task.custom_sleep_data = new_lot - def repark_all(self, new_lot): + def repark_all(self, new_lot: ParkingLot) -> None: """Move all parked tasks from one :class:`ParkingLot` object to another. @@ -203,7 +213,7 @@ def repark_all(self, new_lot): """ return self.repark(new_lot, count=len(self)) - def statistics(self): + def statistics(self) -> _ParkingLotStatistics: """Return an object containing debugging information. Currently the following fields are defined: diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 2727fe1e89..0664cdf577 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -50,6 +50,10 @@ if TYPE_CHECKING: # An unfortunate name collision here with trio._util.Final from typing_extensions import Final as FinalT + from collections.abc import Coroutine + from types import FrameType + import outcome + import contextvars DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: FinalT = 1000 @@ -274,7 +278,7 @@ class CancelStatus: # Our associated cancel scope. Can be any object with attributes # `deadline`, `shield`, and `cancel_called`, but in current usage # is always a CancelScope object. Must not be None. - _scope = attr.ib() + _scope: CancelScope = attr.ib() # True iff the tasks in self._tasks should receive cancellations # when they checkpoint. Always True when scope.cancel_called is True; @@ -284,31 +288,31 @@ class CancelStatus: # effectively cancelled due to the cancel scope two levels out # becoming cancelled, but then the cancel scope one level out # becomes shielded so we're not effectively cancelled anymore. - effectively_cancelled = attr.ib(default=False) + effectively_cancelled: bool = attr.ib(default=False) # The CancelStatus whose cancellations can propagate to us; we # become effectively cancelled when they do, unless scope.shield # is True. May be None (for the outermost CancelStatus in a call # to trio.run(), briefly during TaskStatus.started(), or during # recovery from mis-nesting of cancel scopes). - _parent = attr.ib(default=None, repr=False) + _parent: CancelStatus | None = attr.ib(default=None, repr=False) # All of the CancelStatuses that have this CancelStatus as their parent. - _children = attr.ib(factory=set, init=False, repr=False) + _children: set[CancelStatus] = attr.ib(factory=set, init=False, repr=False) # Tasks whose cancellation state is currently tied directly to # the cancellation state of this CancelStatus object. Don't modify # this directly; instead, use Task._activate_cancel_status(). # Invariant: all(task._cancel_status is self for task in self._tasks) - _tasks = attr.ib(factory=set, init=False, repr=False) + _tasks: set[Task] = attr.ib(factory=set, init=False, repr=False) # Set to True on still-active cancel statuses that are children # of a cancel status that's been closed. This is used to permit # recovery from mis-nested cancel scopes (well, at least enough # recovery to show a useful traceback). - abandoned_by_misnesting = attr.ib(default=False, init=False, repr=False) + abandoned_by_misnesting: bool = attr.ib(default=False, init=False, repr=False) - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: if self._parent is not None: self._parent._children.add(self) self.recalculate() @@ -316,11 +320,11 @@ def __attrs_post_init__(self): # parent/children/tasks accessors are used by TaskStatus.started() @property - def parent(self): + def parent(self) -> CancelStatus | None: return self._parent @parent.setter - def parent(self, parent): + def parent(self, parent: CancelStatus) -> None: if self._parent is not None: self._parent._children.remove(self) self._parent = parent @@ -329,14 +333,14 @@ def parent(self, parent): self.recalculate() @property - def children(self): + def children(self) -> frozenset[CancelStatus]: return frozenset(self._children) @property - def tasks(self): + def tasks(self) -> frozenset[Task]: return frozenset(self._tasks) - def encloses(self, other): + def encloses(self, other: CancelStatus | None) -> bool: """Returns true if this cancel status is a direct or indirect parent of cancel status *other*, or if *other* is *self*. """ @@ -346,7 +350,7 @@ def encloses(self, other): other = other.parent return False - def close(self): + def close(self) -> None: self.parent = None # now we're not a child of self.parent anymore if self._tasks or self._children: # Cancel scopes weren't exited in opposite order of being @@ -406,7 +410,7 @@ def _mark_abandoned(self): for child in self._children: child._mark_abandoned() - def effective_deadline(self): + def effective_deadline(self) -> float: if self.effectively_cancelled: return -inf if self._parent is None or self._scope.shield: @@ -854,7 +858,7 @@ class NurseryManager: """ - strict_exception_groups = attr.ib(default=False) + strict_exception_groups: bool = attr.ib(default=False) @enable_ki_protection async def __aenter__(self): @@ -866,7 +870,12 @@ async def __aenter__(self): return self._nursery @enable_ki_protection - async def __aexit__(self, etype, exc, tb): + async def __aexit__( + self, + etype: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> bool: new_exc = await self._nursery._nested_child_finished(exc) # Tracebacks show the 'raise' line below out of context, so let's give # this variable a name that makes sense out of context. @@ -889,18 +898,18 @@ async def __aexit__(self, etype, exc, tb): # see test_simple_cancel_scope_usage_doesnt_create_cyclic_garbage del _, combined_error_from_nursery, value, new_exc - def __enter__(self): + def __enter__(self) -> None: raise RuntimeError( "use 'async with open_nursery(...)', not 'with open_nursery(...)'" ) - def __exit__(self): # pragma: no cover + def __exit__(self) -> None: # pragma: no cover assert False, """Never called, but should be defined""" def open_nursery( strict_exception_groups: bool | None = None, -) -> AbstractAsyncContextManager[Nursery]: +) -> NurseryManager: """Returns an async context manager which must be used to create a new `Nursery`. @@ -941,7 +950,12 @@ class Nursery(metaclass=NoPublicConstructor): in response to some external event. """ - def __init__(self, parent_task, cancel_scope, strict_exception_groups): + def __init__( + self, + parent_task: Task, + cancel_scope: CancelScope, + strict_exception_groups: bool, + ): self._parent_task = parent_task self._strict_exception_groups = strict_exception_groups parent_task._child_nurseries.append(self) @@ -952,8 +966,8 @@ def __init__(self, parent_task, cancel_scope, strict_exception_groups): # children. self.cancel_scope = cancel_scope assert self.cancel_scope._cancel_status is self._cancel_status - self._children = set() - self._pending_excs = [] + self._children: set[Task] = set() + self._pending_excs: list[BaseException] = [] # The "nested child" is how this code refers to the contents of the # nursery's 'async with' block, which acts like a child Task in all # the ways we can make it. @@ -963,17 +977,17 @@ def __init__(self, parent_task, cancel_scope, strict_exception_groups): self._closed = False @property - def child_tasks(self): + def child_tasks(self) -> frozenset[Task]: """(`frozenset`): Contains all the child :class:`~trio.lowlevel.Task` objects which are still running.""" return frozenset(self._children) @property - def parent_task(self): + def parent_task(self) -> Task: "(`~trio.lowlevel.Task`): The Task that opened this nursery." return self._parent_task - def _add_exc(self, exc): + def _add_exc(self, exc: BaseException) -> None: self._pending_excs.append(exc) self.cancel_scope.cancel() @@ -1135,7 +1149,7 @@ async def async_fn(arg1, arg2, *, task_status=trio.TASK_STATUS_IGNORED): self._pending_starts -= 1 self._check_nursery_closed() - def __del__(self): + def __del__(self) -> None: assert not self._children @@ -1146,12 +1160,11 @@ def __del__(self): @attr.s(eq=False, hash=False, repr=False, slots=True) class Task(metaclass=NoPublicConstructor): - _parent_nursery = attr.ib() - coro = attr.ib() + _parent_nursery: Nursery | None = attr.ib() + coro: Coroutine[Any, outcome.Outcome[object], Any] = attr.ib() _runner = attr.ib() - name = attr.ib() - # PEP 567 contextvars context - context = attr.ib() + name: str = attr.ib() + context: contextvars.Context = attr.ib() _counter: int = attr.ib(init=False, factory=itertools.count().__next__) # Invariant: @@ -1167,24 +1180,27 @@ class Task(metaclass=NoPublicConstructor): # Tasks start out unscheduled. _next_send_fn = attr.ib(default=None) _next_send = attr.ib(default=None) - _abort_func = attr.ib(default=None) - custom_sleep_data = attr.ib(default=None) + _abort_func: Callable[[Callable[[], NoReturn]], Abort] | None = attr.ib( + default=None + ) + # could possible be set with a TypeVar + custom_sleep_data: Any = attr.ib(default=None) # For introspection and nursery.start() - _child_nurseries = attr.ib(factory=list) - _eventual_parent_nursery = attr.ib(default=None) + _child_nurseries: list[Nursery] = attr.ib(factory=list) + _eventual_parent_nursery: Nursery | None = attr.ib(default=None) # these are counts of how many cancel/schedule points this task has # executed, for assert{_no,}_checkpoints # XX maybe these should be exposed as part of a statistics() method? - _cancel_points = attr.ib(default=0) - _schedule_points = attr.ib(default=0) + _cancel_points: int = attr.ib(default=0) + _schedule_points: int = attr.ib(default=0) - def __repr__(self): + def __repr__(self) -> str: return f"" @property - def parent_nursery(self): + def parent_nursery(self) -> Nursery | None: """The nursery this task is inside (or None if this is the "init" task). @@ -1195,7 +1211,7 @@ def parent_nursery(self): return self._parent_nursery @property - def eventual_parent_nursery(self): + def eventual_parent_nursery(self) -> Nursery | None: """The nursery this task will be inside after it calls ``task_status.started()``. @@ -1207,7 +1223,7 @@ def eventual_parent_nursery(self): return self._eventual_parent_nursery @property - def child_nurseries(self): + def child_nurseries(self) -> list[Nursery]: """The nurseries this task contains. This is a list, with outer nurseries before inner nurseries. @@ -1215,7 +1231,7 @@ def child_nurseries(self): """ return list(self._child_nurseries) - def iter_await_frames(self): + def iter_await_frames(self) -> Iterator[tuple[FrameType, int]]: """Iterates recursively over the coroutine-like objects this task is waiting on, yielding the frame and line number at each frame. @@ -1240,11 +1256,11 @@ def print_stack_for_task(task): if hasattr(coro, "cr_frame"): # A real coroutine yield coro.cr_frame, coro.cr_frame.f_lineno - coro = coro.cr_await + coro = coro.cr_await # type: ignore # TODO elif hasattr(coro, "gi_frame"): # A generator decorated with @types.coroutine yield coro.gi_frame, coro.gi_frame.f_lineno - coro = coro.gi_yieldfrom + coro = coro.gi_yieldfrom # type: ignore # TODO elif coro.__class__.__name__ in [ "async_generator_athrow", "async_generator_asend", @@ -1268,9 +1284,9 @@ def print_stack_for_task(task): # The CancelStatus object that is currently active for this task. # Don't change this directly; instead, use _activate_cancel_status(). - _cancel_status = attr.ib(default=None, repr=False) + _cancel_status: CancelStatus = attr.ib(default=None, repr=False) - def _activate_cancel_status(self, cancel_status): + def _activate_cancel_status(self, cancel_status: CancelStatus): if self._cancel_status is not None: self._cancel_status._tasks.remove(self) self._cancel_status = cancel_status @@ -1279,12 +1295,14 @@ def _activate_cancel_status(self, cancel_status): if self._cancel_status.effectively_cancelled: self._attempt_delivery_of_any_pending_cancel() - def _attempt_abort(self, raise_cancel): + def _attempt_abort(self, raise_cancel: Callable[[], NoReturn]) -> None: # Either the abort succeeds, in which case we will reschedule the # task, or else it fails, in which case it will worry about # rescheduling itself (hopefully eventually calling reraise to raise # the given exception, but not necessarily). - success = self._abort_func(raise_cancel) + + # TODO: what happens if _abort_func is None? + success = self._abort_func(raise_cancel) # type: ignore if type(success) is not Abort: raise TrioInternalError("abort function must return Abort enum") # We only attempt to abort once per blocking call, regardless of @@ -1293,7 +1311,7 @@ def _attempt_abort(self, raise_cancel): if success is Abort.SUCCEEDED: self._runner.reschedule(self, capture(raise_cancel)) - def _attempt_delivery_of_any_pending_cancel(self): + def _attempt_delivery_of_any_pending_cancel(self) -> None: if self._abort_func is None: return if not self._cancel_status.effectively_cancelled: @@ -1304,12 +1322,12 @@ def raise_cancel(): self._attempt_abort(raise_cancel) - def _attempt_delivery_of_pending_ki(self): + def _attempt_delivery_of_pending_ki(self) -> None: assert self._runner.ki_pending if self._abort_func is None: return - def raise_cancel(): + def raise_cancel() -> NoReturn: self._runner.ki_pending = False raise KeyboardInterrupt @@ -2435,17 +2453,17 @@ def unrolled_run( class _TaskStatusIgnored: - def __repr__(self): + def __repr__(self) -> str: return "TASK_STATUS_IGNORED" - def started(self, value=None): + def started(self, value: None = None): pass TASK_STATUS_IGNORED: FinalT = _TaskStatusIgnored() -def current_task(): +def current_task() -> Task: """Return the :class:`Task` object representing the current task. Returns: @@ -2459,7 +2477,7 @@ def current_task(): raise RuntimeError("must be called from async context") from None -def current_effective_deadline(): +def current_effective_deadline() -> float: """Returns the current effective deadline for the current task. This function examines all the cancellation scopes that are currently in @@ -2486,7 +2504,7 @@ def current_effective_deadline(): return current_task()._cancel_status.effective_deadline() -async def checkpoint(): +async def checkpoint() -> None: """A pure :ref:`checkpoint `. This checks for cancellation and allows other tasks to be scheduled, @@ -2513,7 +2531,7 @@ async def checkpoint(): await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED) -async def checkpoint_if_cancelled(): +async def checkpoint_if_cancelled() -> None: """Issue a :ref:`checkpoint ` if the calling context has been cancelled. diff --git a/trio/_highlevel_generic.py b/trio/_highlevel_generic.py index c31b4fdbf3..191c2a76d4 100644 --- a/trio/_highlevel_generic.py +++ b/trio/_highlevel_generic.py @@ -1,12 +1,17 @@ +from __future__ import annotations import attr import trio from .abc import HalfCloseableStream from trio._util import Final +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from .abc import SendStream, ReceiveStream, AsyncResource -async def aclose_forcefully(resource): + +async def aclose_forcefully(resource: AsyncResource) -> None: """Close an async resource or async generator immediately, without blocking to do any graceful cleanup. @@ -72,18 +77,18 @@ class StapledStream(HalfCloseableStream, metaclass=Final): """ - send_stream = attr.ib() - receive_stream = attr.ib() + send_stream: SendStream = attr.ib() + receive_stream: ReceiveStream = attr.ib() - async def send_all(self, data): + async def send_all(self, data: bytes | bytearray | memoryview) -> None: """Calls ``self.send_stream.send_all``.""" return await self.send_stream.send_all(data) - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: """Calls ``self.send_stream.wait_send_all_might_not_block``.""" return await self.send_stream.wait_send_all_might_not_block() - async def send_eof(self): + async def send_eof(self) -> None: """Shuts down the send side of the stream. If ``self.send_stream.send_eof`` exists, then calls it. Otherwise, @@ -91,15 +96,15 @@ async def send_eof(self): """ if hasattr(self.send_stream, "send_eof"): - return await self.send_stream.send_eof() + return await self.send_stream.send_eof() # type: ignore else: return await self.send_stream.aclose() - async def receive_some(self, max_bytes=None): + async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray: """Calls ``self.receive_stream.receive_some``.""" return await self.receive_stream.receive_some(max_bytes) - async def aclose(self): + async def aclose(self) -> None: """Calls ``aclose`` on both underlying streams.""" try: await self.send_stream.aclose() diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py index 1e8dc16ebc..5bd6b68f1c 100644 --- a/trio/_highlevel_socket.py +++ b/trio/_highlevel_socket.py @@ -1,13 +1,18 @@ # "High-level" networking interface +from __future__ import annotations import errno from contextlib import contextmanager +from typing import TYPE_CHECKING import trio from . import socket as tsocket from ._util import ConflictDetector, Final from .abc import HalfCloseableStream, Listener +if TYPE_CHECKING: + from ._socket import _SocketType as SocketType + # XX TODO: this number was picked arbitrarily. We should do experiments to # tune it. (Or make it dynamic -- one idea is to start small and increase it # if we observe single reads filling up the whole buffer, at least within some @@ -57,7 +62,7 @@ class SocketStream(HalfCloseableStream, metaclass=Final): """ - def __init__(self, socket): + def __init__(self, socket: SocketType): if not isinstance(socket, tsocket.SocketType): raise TypeError("SocketStream requires a Trio socket object") if socket.type != tsocket.SOCK_STREAM: @@ -108,14 +113,14 @@ async def send_all(self, data): sent = await self.socket.send(remaining) total_sent += sent - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: with self._send_conflict_detector: if self.socket.fileno() == -1: raise trio.ClosedResourceError with _translate_socket_errors_to_stream_errors(): await self.socket.wait_writable() - async def send_eof(self): + async def send_eof(self) -> None: with self._send_conflict_detector: await trio.lowlevel.checkpoint() # On macOS, calling shutdown a second time raises ENOTCONN, but @@ -125,7 +130,7 @@ async def send_eof(self): with _translate_socket_errors_to_stream_errors(): self.socket.shutdown(tsocket.SHUT_WR) - async def receive_some(self, max_bytes=None): + async def receive_some(self, max_bytes: int | None = None): if max_bytes is None: max_bytes = DEFAULT_RECEIVE_SIZE if max_bytes < 1: @@ -133,7 +138,7 @@ async def receive_some(self, max_bytes=None): with _translate_socket_errors_to_stream_errors(): return await self.socket.recv(max_bytes) - async def aclose(self): + async def aclose(self) -> None: self.socket.close() await trio.lowlevel.checkpoint() @@ -330,7 +335,7 @@ class SocketListener(Listener[SocketStream], metaclass=Final): """ - def __init__(self, socket): + def __init__(self, socket: SocketType): if not isinstance(socket, tsocket.SocketType): raise TypeError("SocketListener requires a Trio socket object") if socket.type != tsocket.SOCK_STREAM: @@ -346,7 +351,7 @@ def __init__(self, socket): self.socket = socket - async def accept(self): + async def accept(self) -> SocketStream: """Accept an incoming connection. Returns: @@ -374,7 +379,7 @@ async def accept(self): else: return SocketStream(sock) - async def aclose(self): + async def aclose(self) -> None: """Close this listener and its underlying socket.""" self.socket.close() await trio.lowlevel.checkpoint() diff --git a/trio/_socket.py b/trio/_socket.py index 2889f48113..38a6fa0019 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -1,15 +1,21 @@ +from __future__ import annotations + import os import sys import select import socket as _stdlib_socket from functools import wraps as _wraps -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import idna as _idna import trio from . import _core +if TYPE_CHECKING: + from collections.abc import Iterable + from typing_extensions import Self + # Usage: # @@ -429,7 +435,7 @@ def __init__(self): class _SocketType(SocketType): - def __init__(self, sock): + def __init__(self, sock: _stdlib_socket.socket): if type(sock) is not _stdlib_socket.socket: # For example, ssl.SSLSocket subclasses socket.socket, but we # certainly don't want to blindly wrap one of those. @@ -472,44 +478,45 @@ def __getattr__(self, name): return getattr(self._sock, name) raise AttributeError(name) - def __dir__(self): - return super().__dir__() + list(self._forward) + def __dir__(self) -> Iterable[str]: + yield from super().__dir__() + yield from self._forward - def __enter__(self): + def __enter__(self) -> Self: return self - def __exit__(self, *exc_info): + def __exit__(self, *exc_info: object) -> None: return self._sock.__exit__(*exc_info) @property - def family(self): + def family(self) -> _stdlib_socket.AddressFamily: return self._sock.family @property - def type(self): + def type(self) -> _stdlib_socket.SocketKind: return self._sock.type @property - def proto(self): + def proto(self) -> int: return self._sock.proto @property - def did_shutdown_SHUT_WR(self): + def did_shutdown_SHUT_WR(self) -> bool: return self._did_shutdown_SHUT_WR - def __repr__(self): + def __repr__(self) -> str: return repr(self._sock).replace("socket.socket", "trio.socket.socket") - def dup(self): + def dup(self) -> _SocketType: """Same as :meth:`socket.socket.dup`.""" return _SocketType(self._sock.dup()) - def close(self): + def close(self) -> None: if self._sock.fileno() != -1: trio.lowlevel.notify_closing(self._sock) self._sock.close() - async def bind(self, address): + async def bind(self, address: tuple[Any, ...] | str | bytes) -> None: address = await self._resolve_address_nocp(address, local=True) if ( hasattr(_stdlib_socket, "AF_UNIX") @@ -518,7 +525,7 @@ async def bind(self, address): ): # Use a thread for the filesystem traversal (unless it's an # abstract domain socket) - return await trio.to_thread.run_sync(self._sock.bind, address) + return await trio.to_thread.run_sync(self._sock.bind, address) # type: ignore[no-any-return] else: # POSIX actually says that bind can return EWOULDBLOCK and # complete asynchronously, like connect. But in practice AFAICT @@ -527,14 +534,14 @@ async def bind(self, address): await trio.lowlevel.checkpoint() return self._sock.bind(address) - def shutdown(self, flag): + def shutdown(self, flag: int) -> None: # no need to worry about return value b/c always returns None: self._sock.shutdown(flag) # only do this if the call succeeded: if flag in [_stdlib_socket.SHUT_WR, _stdlib_socket.SHUT_RDWR]: self._did_shutdown_SHUT_WR = True - def is_readable(self): + def is_readable(self) -> bool: # use select.select on Windows, and select.poll everywhere else if sys.platform == "win32": rready, _, _ = select.select([self._sock], [], [], 0) @@ -543,7 +550,7 @@ def is_readable(self): p.register(self._sock, select.POLLIN) return bool(p.poll(0)) - async def wait_writable(self): + async def wait_writable(self) -> None: await _core.wait_writable(self._sock) async def _resolve_address_nocp(self, address, *, local): diff --git a/trio/_sync.py b/trio/_sync.py index 8d2fdc0a2d..08eae512d2 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -1,3 +1,6 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + import math import attr @@ -8,10 +11,14 @@ from ._core import enable_ki_protection, ParkingLot from ._util import Final +if TYPE_CHECKING: + from ._core import Task + from ._core._parking_lot import _ParkingLotStatistics + @attr.s(frozen=True) class _EventStatistics: - tasks_waiting = attr.ib() + tasks_waiting: int = attr.ib() @attr.s(repr=False, eq=False, hash=False, slots=True) @@ -41,15 +48,15 @@ class Event(metaclass=Final): """ - _tasks = attr.ib(factory=set, init=False) - _flag = attr.ib(default=False, init=False) + _tasks: set[Task] = attr.ib(factory=set, init=False) + _flag: bool = attr.ib(default=False, init=False) - def is_set(self): + def is_set(self) -> bool: """Return the current value of the internal flag.""" return self._flag @enable_ki_protection - def set(self): + def set(self) -> None: """Set the internal flag value to True, and wake any waiting tasks.""" if not self._flag: self._flag = True @@ -57,7 +64,7 @@ def set(self): _core.reschedule(task) self._tasks.clear() - async def wait(self): + async def wait(self) -> None: """Block until the internal flag value becomes True. If it's already True, then this method returns immediately. @@ -75,7 +82,7 @@ def abort_fn(_): await _core.wait_task_rescheduled(abort_fn) - def statistics(self): + def statistics(self) -> _EventStatistics: """Return an object containing debugging information. Currently the following fields are defined: @@ -89,20 +96,20 @@ def statistics(self): class AsyncContextManagerMixin: @enable_ki_protection - async def __aenter__(self): - await self.acquire() + async def __aenter__(self) -> None: + await self.acquire() # type: ignore[attr-defined] @enable_ki_protection - async def __aexit__(self, *args): - self.release() + async def __aexit__(self, *args: object) -> None: + self.release() # type: ignore[attr-defined] @attr.s(frozen=True) class _CapacityLimiterStatistics: - borrowed_tokens = attr.ib() - total_tokens = attr.ib() - borrowers = attr.ib() - tasks_waiting = attr.ib() + borrowed_tokens: int = attr.ib() + total_tokens: int | float = attr.ib() + borrowers: list[Task] = attr.ib() + tasks_waiting: int = attr.ib() class CapacityLimiter(AsyncContextManagerMixin, metaclass=Final): @@ -159,22 +166,23 @@ class CapacityLimiter(AsyncContextManagerMixin, metaclass=Final): """ - def __init__(self, total_tokens): + # total_tokens would ideally be int|Literal[math.inf] - but that's not valid typing + def __init__(self, total_tokens: int | float): self._lot = ParkingLot() - self._borrowers = set() + self._borrowers: set[Task] = set() # Maps tasks attempting to acquire -> borrower, to handle on-behalf-of - self._pending_borrowers = {} + self._pending_borrowers: dict[Task, Task] = {} # invoke the property setter for validation - self.total_tokens = total_tokens + self.total_tokens: int | float = total_tokens assert self._total_tokens == total_tokens - def __repr__(self): + def __repr__(self) -> str: return "".format( id(self), len(self._borrowers), self._total_tokens, len(self._lot) ) @property - def total_tokens(self): + def total_tokens(self) -> int | float: """The total capacity available. You can change :attr:`total_tokens` by assigning to this attribute. If @@ -189,7 +197,7 @@ def total_tokens(self): return self._total_tokens @total_tokens.setter - def total_tokens(self, new_total_tokens): + def total_tokens(self, new_total_tokens: int | float) -> None: if not isinstance(new_total_tokens, int) and new_total_tokens != math.inf: raise TypeError("total_tokens must be an int or math.inf") if new_total_tokens < 1: @@ -197,23 +205,23 @@ def total_tokens(self, new_total_tokens): self._total_tokens = new_total_tokens self._wake_waiters() - def _wake_waiters(self): + def _wake_waiters(self) -> None: available = self._total_tokens - len(self._borrowers) for woken in self._lot.unpark(count=available): self._borrowers.add(self._pending_borrowers.pop(woken)) @property - def borrowed_tokens(self): + def borrowed_tokens(self) -> int: """The amount of capacity that's currently in use.""" return len(self._borrowers) @property - def available_tokens(self): + def available_tokens(self) -> int | float: """The amount of capacity that's available to use.""" return self.total_tokens - self.borrowed_tokens @enable_ki_protection - def acquire_nowait(self): + def acquire_nowait(self) -> None: """Borrow a token from the sack, without blocking. Raises: @@ -225,7 +233,7 @@ def acquire_nowait(self): self.acquire_on_behalf_of_nowait(trio.lowlevel.current_task()) @enable_ki_protection - def acquire_on_behalf_of_nowait(self, borrower): + def acquire_on_behalf_of_nowait(self, borrower: Task) -> None: """Borrow a token from the sack on behalf of ``borrower``, without blocking. @@ -253,7 +261,7 @@ def acquire_on_behalf_of_nowait(self, borrower): raise trio.WouldBlock @enable_ki_protection - async def acquire(self): + async def acquire(self) -> None: """Borrow a token from the sack, blocking if necessary. Raises: @@ -264,7 +272,7 @@ async def acquire(self): await self.acquire_on_behalf_of(trio.lowlevel.current_task()) @enable_ki_protection - async def acquire_on_behalf_of(self, borrower): + async def acquire_on_behalf_of(self, borrower: Task) -> None: """Borrow a token from the sack on behalf of ``borrower``, blocking if necessary. @@ -293,7 +301,7 @@ async def acquire_on_behalf_of(self, borrower): await trio.lowlevel.cancel_shielded_checkpoint() @enable_ki_protection - def release(self): + def release(self) -> None: """Put a token back into the sack. Raises: @@ -304,7 +312,7 @@ def release(self): self.release_on_behalf_of(trio.lowlevel.current_task()) @enable_ki_protection - def release_on_behalf_of(self, borrower): + def release_on_behalf_of(self, borrower: Task) -> None: """Put a token back into the sack on behalf of ``borrower``. Raises: @@ -319,7 +327,7 @@ def release_on_behalf_of(self, borrower): self._borrowers.remove(borrower) self._wake_waiters() - def statistics(self): + def statistics(self) -> _CapacityLimiterStatistics: """Return an object containing debugging information. Currently the following fields are defined: @@ -373,7 +381,7 @@ class Semaphore(AsyncContextManagerMixin, metaclass=Final): """ - def __init__(self, initial_value, *, max_value=None): + def __init__(self, initial_value: int, *, max_value: int | None = None): if not isinstance(initial_value, int): raise TypeError("initial_value must be an int") if initial_value < 0: @@ -391,7 +399,7 @@ def __init__(self, initial_value, *, max_value=None): self._value = initial_value self._max_value = max_value - def __repr__(self): + def __repr__(self) -> str: if self._max_value is None: max_value_str = "" else: @@ -401,17 +409,17 @@ def __repr__(self): ) @property - def value(self): + def value(self) -> int: """The current value of the semaphore.""" return self._value @property - def max_value(self): + def max_value(self) -> int | None: """The maximum allowed value. May be None to indicate no limit.""" return self._max_value @enable_ki_protection - def acquire_nowait(self): + def acquire_nowait(self) -> None: """Attempt to decrement the semaphore value, without blocking. Raises: @@ -425,7 +433,7 @@ def acquire_nowait(self): raise trio.WouldBlock @enable_ki_protection - async def acquire(self): + async def acquire(self) -> None: """Decrement the semaphore value, blocking if necessary to avoid letting it drop below zero. @@ -439,7 +447,7 @@ async def acquire(self): await trio.lowlevel.cancel_shielded_checkpoint() @enable_ki_protection - def release(self): + def release(self) -> None: """Increment the semaphore value, possibly waking a task blocked in :meth:`acquire`. @@ -456,7 +464,7 @@ def release(self): raise ValueError("semaphore released too many times") self._value += 1 - def statistics(self): + def statistics(self) -> _ParkingLotStatistics: """Return an object containing debugging information. Currently the following fields are defined: @@ -470,17 +478,17 @@ def statistics(self): @attr.s(frozen=True) class _LockStatistics: - locked = attr.ib() - owner = attr.ib() - tasks_waiting = attr.ib() + locked: bool = attr.ib() + owner: Task | None = attr.ib() + tasks_waiting: int = attr.ib() @attr.s(eq=False, hash=False, repr=False) class _LockImpl(AsyncContextManagerMixin): - _lot = attr.ib(factory=ParkingLot, init=False) - _owner = attr.ib(default=None, init=False) + _lot: ParkingLot = attr.ib(factory=ParkingLot, init=False) + _owner: Task | None = attr.ib(default=None, init=False) - def __repr__(self): + def __repr__(self) -> str: if self.locked(): s1 = "locked" s2 = f" with {len(self._lot)} waiters" @@ -491,7 +499,7 @@ def __repr__(self): s1, self.__class__.__name__, id(self), s2 ) - def locked(self): + def locked(self) -> bool: """Check whether the lock is currently held. Returns: @@ -501,7 +509,7 @@ def locked(self): return self._owner is not None @enable_ki_protection - def acquire_nowait(self): + def acquire_nowait(self) -> None: """Attempt to acquire the lock, without blocking. Raises: @@ -519,7 +527,7 @@ def acquire_nowait(self): raise trio.WouldBlock @enable_ki_protection - async def acquire(self): + async def acquire(self) -> None: """Acquire the lock, blocking if necessary.""" await trio.lowlevel.checkpoint_if_cancelled() try: @@ -533,7 +541,7 @@ async def acquire(self): await trio.lowlevel.cancel_shielded_checkpoint() @enable_ki_protection - def release(self): + def release(self) -> None: """Release the lock. Raises: @@ -548,7 +556,7 @@ def release(self): else: self._owner = None - def statistics(self): + def statistics(self) -> _LockStatistics: """Return an object containing debugging information. Currently the following fields are defined: @@ -644,8 +652,8 @@ class StrictFIFOLock(_LockImpl, metaclass=Final): @attr.s(frozen=True) class _ConditionStatistics: - tasks_waiting = attr.ib() - lock_statistics = attr.ib() + tasks_waiting: int = attr.ib() + lock_statistics: _LockStatistics = attr.ib() class Condition(AsyncContextManagerMixin, metaclass=Final): @@ -663,7 +671,7 @@ class Condition(AsyncContextManagerMixin, metaclass=Final): """ - def __init__(self, lock=None): + def __init__(self, lock: Lock | None = None): if lock is None: lock = Lock() if not type(lock) is Lock: @@ -671,7 +679,7 @@ def __init__(self, lock=None): self._lock = lock self._lot = trio.lowlevel.ParkingLot() - def locked(self): + def locked(self) -> bool: """Check whether the underlying lock is currently held. Returns: @@ -680,7 +688,7 @@ def locked(self): """ return self._lock.locked() - def acquire_nowait(self): + def acquire_nowait(self) -> None: """Attempt to acquire the underlying lock, without blocking. Raises: @@ -689,16 +697,16 @@ def acquire_nowait(self): """ return self._lock.acquire_nowait() - async def acquire(self): + async def acquire(self) -> None: """Acquire the underlying lock, blocking if necessary.""" await self._lock.acquire() - def release(self): + def release(self) -> None: """Release the underlying lock.""" self._lock.release() @enable_ki_protection - async def wait(self): + async def wait(self) -> None: """Wait for another task to call :meth:`notify` or :meth:`notify_all`. @@ -733,7 +741,7 @@ async def wait(self): await self.acquire() raise - def notify(self, n=1): + def notify(self, n: int = 1) -> None: """Wake one or more tasks that are blocked in :meth:`wait`. Args: @@ -747,7 +755,7 @@ def notify(self, n=1): raise RuntimeError("must hold the lock to notify") self._lot.repark(self._lock._lot, count=n) - def notify_all(self): + def notify_all(self) -> None: """Wake all tasks that are currently blocked in :meth:`wait`. Raises: @@ -758,7 +766,7 @@ def notify_all(self): raise RuntimeError("must hold the lock to notify") self._lot.repark_all(self._lock._lot) - def statistics(self): + def statistics(self) -> _ConditionStatistics: r"""Return an object containing debugging information. Currently the following fields are defined: diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index e54af12444..53dc90d5d9 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -7,11 +7,11 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.8317152103559871, + "completenessScore": 0.8737864077669902, "exportedSymbolCounts": { "withAmbiguousType": 1, - "withKnownType": 514, - "withUnknownType": 103 + "withKnownType": 540, + "withUnknownType": 77 }, "ignoreUnknownTypesFromImports": true, "missingClassDocStringCount": 1, @@ -45,91 +45,36 @@ } ], "otherSymbolCounts": { - "withAmbiguousType": 14, - "withKnownType": 244, - "withUnknownType": 224 + "withAmbiguousType": 8, + "withKnownType": 441, + "withUnknownType": 143 }, "packageName": "trio", "symbols": [ "trio.run", - "trio.current_effective_deadline", - "trio._core._run._TaskStatusIgnored.__repr__", "trio._core._run._TaskStatusIgnored.started", "trio.current_time", - "trio._core._run.Nursery.__init__", - "trio._core._run.Nursery.child_tasks", - "trio._core._run.Nursery.parent_task", "trio._core._run.Nursery.start_soon", "trio._core._run.Nursery.start", - "trio._core._run.Nursery.__del__", - "trio._sync.Event.is_set", - "trio._sync.Event.wait", - "trio._sync.Event.statistics", - "trio._sync.CapacityLimiter.__init__", - "trio._sync.CapacityLimiter.__repr__", - "trio._sync.CapacityLimiter.total_tokens", - "trio._sync.CapacityLimiter.borrowed_tokens", - "trio._sync.CapacityLimiter.available_tokens", - "trio._sync.CapacityLimiter.statistics", - "trio._sync.Semaphore.__init__", - "trio._sync.Semaphore.__repr__", - "trio._sync.Semaphore.value", - "trio._sync.Semaphore.max_value", - "trio._sync.Semaphore.statistics", - "trio._sync.Lock", - "trio._sync._LockImpl.__repr__", - "trio._sync._LockImpl.locked", - "trio._sync._LockImpl.statistics", - "trio._sync.StrictFIFOLock", - "trio._sync.Condition.__init__", - "trio._sync.Condition.locked", - "trio._sync.Condition.acquire_nowait", - "trio._sync.Condition.acquire", - "trio._sync.Condition.release", - "trio._sync.Condition.notify", - "trio._sync.Condition.notify_all", - "trio._sync.Condition.statistics", - "trio.aclose_forcefully", - "trio._highlevel_generic.StapledStream", - "trio._highlevel_generic.StapledStream.send_stream", - "trio._highlevel_generic.StapledStream.receive_stream", - "trio._highlevel_generic.StapledStream.send_all", - "trio._highlevel_generic.StapledStream.wait_send_all_might_not_block", - "trio._highlevel_generic.StapledStream.send_eof", - "trio._highlevel_generic.StapledStream.receive_some", - "trio._highlevel_generic.StapledStream.aclose", - "trio._abc.HalfCloseableStream", - "trio._abc.HalfCloseableStream.send_eof", - "trio._abc.Stream", - "trio._abc.SendStream", - "trio._abc.SendStream.send_all", - "trio._abc.SendStream.wait_send_all_might_not_block", - "trio._abc.AsyncResource.aclose", - "trio._abc.AsyncResource.__aenter__", - "trio._abc.AsyncResource.__aexit__", - "trio._abc.ReceiveStream", - "trio._abc.ReceiveStream.receive_some", - "trio._abc.ReceiveStream.__aiter__", - "trio._abc.ReceiveStream.__anext__", - "trio._channel.MemorySendChannel", - "trio._abc.SendChannel", - "trio._channel.MemoryReceiveChannel", - "trio._abc.ReceiveChannel", - "trio._abc.ReceiveChannel.__aiter__", - "trio._highlevel_socket.SocketStream", "trio._highlevel_socket.SocketStream.__init__", + "trio._socket._SocketType.__getattr__", + "trio._socket._SocketType.accept", + "trio._socket._SocketType.connect", + "trio._socket._SocketType.recv", + "trio._socket._SocketType.recv_into", + "trio._socket._SocketType.recvfrom", + "trio._socket._SocketType.recvfrom_into", + "trio._socket._SocketType.recvmsg", + "trio._socket._SocketType.recvmsg_into", + "trio._socket._SocketType.send", + "trio._socket._SocketType.sendto", + "trio._socket._SocketType.sendmsg", "trio._highlevel_socket.SocketStream.send_all", - "trio._highlevel_socket.SocketStream.wait_send_all_might_not_block", - "trio._highlevel_socket.SocketStream.send_eof", "trio._highlevel_socket.SocketStream.receive_some", - "trio._highlevel_socket.SocketStream.aclose", "trio._highlevel_socket.SocketStream.setsockopt", "trio._highlevel_socket.SocketStream.getsockopt", "trio._highlevel_socket.SocketListener", "trio._highlevel_socket.SocketListener.__init__", - "trio._highlevel_socket.SocketListener.accept", - "trio._highlevel_socket.SocketListener.aclose", - "trio._abc.Listener", "trio._abc.Listener.accept", "trio.open_file", "trio.wrap_file", @@ -147,7 +92,6 @@ "trio._path.AsyncAutoWrapperType.generate_wraps", "trio._path.AsyncAutoWrapperType.generate_magic", "trio._path.AsyncAutoWrapperType.generate_iter", - "trio._subprocess.Process", "trio._subprocess.Process.encoding", "trio._subprocess.Process.errors", "trio._subprocess.Process.__init__", @@ -163,7 +107,6 @@ "trio._subprocess.Process.args", "trio._subprocess.Process.pid", "trio.run_process", - "trio._ssl.SSLStream", "trio._ssl.SSLStream.__init__", "trio._ssl.SSLStream.__getattr__", "trio._ssl.SSLStream.__setattr__", @@ -188,7 +131,6 @@ "trio._dtls.DTLSEndpoint.connect", "trio._dtls.DTLSEndpoint.socket", "trio._dtls.DTLSEndpoint.incoming_packets_buffer", - "trio._dtls.DTLSChannel", "trio._dtls.DTLSChannel.__init__", "trio._dtls.DTLSChannel.close", "trio._dtls.DTLSChannel.__enter__", @@ -200,7 +142,6 @@ "trio._dtls.DTLSChannel.set_ciphertext_mtu", "trio._dtls.DTLSChannel.get_cleartext_mtu", "trio._dtls.DTLSChannel.statistics", - "trio._abc.Channel", "trio.serve_listeners", "trio.open_tcp_stream", "trio.open_tcp_listeners", @@ -210,9 +151,6 @@ "trio.open_ssl_over_tcp_listeners", "trio.serve_ssl_over_tcp", "trio.__deprecated_attributes__", - "trio._abc.Clock.start_clock", - "trio._abc.Clock.current_time", - "trio._abc.Clock.deadline_to_sleep_time", "trio._abc.Instrument.before_run", "trio._abc.Instrument.after_run", "trio._abc.Instrument.task_spawned", @@ -229,22 +167,6 @@ "trio.from_thread.run_sync", "trio.lowlevel.cancel_shielded_checkpoint", "trio.lowlevel.currently_ki_protected", - "trio._core._run.Task.coro", - "trio._core._run.Task.name", - "trio._core._run.Task.context", - "trio._core._run.Task.custom_sleep_data", - "trio._core._run.Task.__repr__", - "trio._core._run.Task.parent_nursery", - "trio._core._run.Task.eventual_parent_nursery", - "trio._core._run.Task.child_nurseries", - "trio._core._run.Task.iter_await_frames", - "trio.lowlevel.checkpoint", - "trio.lowlevel.current_task", - "trio._core._parking_lot.ParkingLot.__len__", - "trio._core._parking_lot.ParkingLot.__bool__", - "trio._core._parking_lot.ParkingLot.unpark_all", - "trio._core._parking_lot.ParkingLot.repark_all", - "trio._core._parking_lot.ParkingLot.statistics", "trio._core._unbounded_queue.UnboundedQueue.__repr__", "trio._core._unbounded_queue.UnboundedQueue.qsize", "trio._core._unbounded_queue.UnboundedQueue.empty", @@ -268,7 +190,6 @@ "trio.lowlevel.add_instrument", "trio.lowlevel.current_clock", "trio.lowlevel.current_root_task", - "trio.lowlevel.checkpoint_if_cancelled", "trio.lowlevel.spawn_system_task", "trio.lowlevel.wait_readable", "trio.lowlevel.wait_writable", @@ -276,7 +197,6 @@ "trio.lowlevel.start_thread_soon", "trio.lowlevel.start_guest_run", "trio.lowlevel.open_process", - "trio._unix_pipes.FdStream", "trio.socket.fromfd", "trio.socket.from_stdlib_socket", "trio.socket.getprotobyname", @@ -287,14 +207,6 @@ "trio.socket.set_custom_hostname_resolver", "trio.socket.set_custom_socket_factory", "trio.testing.wait_all_tasks_blocked", - "trio._core._mock_clock.MockClock", - "trio._core._mock_clock.MockClock.__init__", - "trio._core._mock_clock.MockClock.__repr__", - "trio._core._mock_clock.MockClock.rate", - "trio._core._mock_clock.MockClock.autojump_threshold", - "trio._core._mock_clock.MockClock.start_clock", - "trio._core._mock_clock.MockClock.current_time", - "trio._core._mock_clock.MockClock.deadline_to_sleep_time", "trio._core._mock_clock.MockClock.jump", "trio.testing.trio_test", "trio.testing.assert_checkpoints", @@ -302,7 +214,6 @@ "trio.testing.check_one_way_stream", "trio.testing.check_two_way_stream", "trio.testing.check_half_closeable_stream", - "trio.testing._memory_streams.MemorySendStream", "trio.testing._memory_streams.MemorySendStream.__init__", "trio.testing._memory_streams.MemorySendStream.send_all", "trio.testing._memory_streams.MemorySendStream.wait_send_all_might_not_block", @@ -313,7 +224,6 @@ "trio.testing._memory_streams.MemorySendStream.send_all_hook", "trio.testing._memory_streams.MemorySendStream.wait_send_all_might_not_block_hook", "trio.testing._memory_streams.MemorySendStream.close_hook", - "trio.testing._memory_streams.MemoryReceiveStream", "trio.testing._memory_streams.MemoryReceiveStream.__init__", "trio.testing._memory_streams.MemoryReceiveStream.receive_some", "trio.testing._memory_streams.MemoryReceiveStream.close", From 0050ca2a658b934dfdabec4668318dbf283de48f Mon Sep 17 00:00:00 2001 From: jakkdl Date: Tue, 4 Jul 2023 15:29:07 +0200 Subject: [PATCH 02/21] changes after review --- trio/_abc.py | 9 ++++++--- trio/_core/_multierror.py | 18 ++++++++++++++---- trio/_core/_parking_lot.py | 6 +++--- trio/_core/_run.py | 11 ++++++++--- trio/_dtls.py | 31 +++++++++++++++++++++++++------ trio/_highlevel_socket.py | 4 ++-- trio/_socket.py | 10 ++++++++-- trio/_sync.py | 30 +++++++++++++++--------------- trio/_tests/verify_types.json | 9 ++------- trio/_util.py | 10 +++++++++- trio/testing/_fake_net.py | 13 +++++++++++-- 11 files changed, 103 insertions(+), 48 deletions(-) diff --git a/trio/_abc.py b/trio/_abc.py index 403e07a1de..036a20b567 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -1,9 +1,12 @@ from __future__ import annotations from abc import ABCMeta, abstractmethod -from typing import Generic, TypeVar +from typing import Generic, TypeVar, TYPE_CHECKING import trio +if TYPE_CHECKING: + from typing_extensions import Self + # We use ABCMeta instead of ABC, plus set __slots__=(), so as not to force a # __dict__ onto subclasses. @@ -414,7 +417,7 @@ async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray: """ - def __aiter__(self) -> ReceiveStream: + def __aiter__(self) -> Self: return self async def __anext__(self) -> bytes | bytearray: @@ -633,7 +636,7 @@ async def receive(self) -> ReceiveType: """ - def __aiter__(self) -> ReceiveChannel[ReceiveType]: + def __aiter__(self) -> Self: return self async def __anext__(self) -> ReceiveType: diff --git a/trio/_core/_multierror.py b/trio/_core/_multierror.py index 9e69928162..6e964ba038 100644 --- a/trio/_core/_multierror.py +++ b/trio/_core/_multierror.py @@ -1,15 +1,19 @@ +from __future__ import annotations import sys import warnings import attr from trio._deprecate import warn_deprecated +from typing import TYPE_CHECKING if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup, ExceptionGroup, print_exception else: from traceback import print_exception +if TYPE_CHECKING: + from types import TracebackType ################################################################ # MultiError ################################################################ @@ -130,11 +134,16 @@ class MultiErrorCatcher: def __enter__(self): pass - def __exit__(self, etype, exc, tb): - if exc is not None: - filtered_exc = _filter_impl(self._handler, exc) + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + if exc_val is not None: + filtered_exc = _filter_impl(self._handler, exc_val) - if filtered_exc is exc: + if filtered_exc is exc_val: # Let the interpreter re-raise it return False if filtered_exc is None: @@ -154,6 +163,7 @@ def __exit__(self, etype, exc, tb): # delete references from locals to avoid creating cycles # see test_MultiError_catch_doesnt_create_cyclic_garbage del _, filtered_exc, value + return False class MultiError(BaseExceptionGroup): diff --git a/trio/_core/_parking_lot.py b/trio/_core/_parking_lot.py index 42fc60614c..5ef22ca247 100644 --- a/trio/_core/_parking_lot.py +++ b/trio/_core/_parking_lot.py @@ -86,7 +86,7 @@ @attr.s(frozen=True, slots=True) -class _ParkingLotStatistics: +class ParkingLotStatistics: tasks_waiting: int = attr.ib() @@ -213,7 +213,7 @@ def repark_all(self, new_lot: ParkingLot) -> None: """ return self.repark(new_lot, count=len(self)) - def statistics(self) -> _ParkingLotStatistics: + def statistics(self) -> ParkingLotStatistics: """Return an object containing debugging information. Currently the following fields are defined: @@ -222,4 +222,4 @@ def statistics(self) -> _ParkingLotStatistics: :meth:`park` method. """ - return _ParkingLotStatistics(tasks_waiting=len(self._parked)) + return ParkingLotStatistics(tasks_waiting=len(self._parked)) diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 0664cdf577..b7a8359698 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -903,13 +903,18 @@ def __enter__(self) -> None: "use 'async with open_nursery(...)', not 'with open_nursery(...)'" ) - def __exit__(self) -> None: # pragma: no cover + def __exit__( + self, # pragma: no cover + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool: assert False, """Never called, but should be defined""" def open_nursery( strict_exception_groups: bool | None = None, -) -> NurseryManager: +) -> AbstractAsyncContextManager[Nursery]: """Returns an async context manager which must be used to create a new `Nursery`. @@ -2456,7 +2461,7 @@ class _TaskStatusIgnored: def __repr__(self) -> str: return "TASK_STATUS_IGNORED" - def started(self, value: None = None): + def started(self, value: Any = None): pass diff --git a/trio/_dtls.py b/trio/_dtls.py index 910637455a..9cb8a5d03c 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -6,6 +6,7 @@ # Hopefully they fix this before implementing DTLS 1.3, because it's a very different # protocol, and it's probably impossible to pull tricks like we do here. +from __future__ import annotations import struct import hmac import os @@ -14,12 +15,16 @@ import weakref import errno import warnings +from typing import TYPE_CHECKING import attr import trio from trio._util import NoPublicConstructor, Final +if TYPE_CHECKING: + from types import TracebackType + MAX_UDP_PACKET_SIZE = 65527 @@ -809,7 +814,7 @@ def _check_replaced(self): # DTLS where packets are all independent and can be lost anyway. We do at least need # to handle receiving it properly though, which might be easier if we send it... - def close(self): + def close(self) -> None: """Close this connection. `DTLSChannel`\\s don't actually own any OS-level resources – the @@ -833,8 +838,13 @@ def close(self): def __enter__(self): return self - def __exit__(self, *args): - self.close() + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + return self.close() async def aclose(self): """Close this connection, but asynchronously. @@ -1167,12 +1177,16 @@ def __del__(self): f"unclosed DTLS endpoint {self!r}", ResourceWarning, source=self ) - def close(self): + def close(self) -> None: """Close this socket, and all associated DTLS connections. This object can also be used as a context manager. """ + # Do nothing if this object was never fully constructed + if self.socket is None: + return + self._closed = True self.socket.close() for stream in list(self._streams.values()): @@ -1182,8 +1196,13 @@ def close(self): def __enter__(self): return self - def __exit__(self, *args): - self.close() + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + return self.close() def _check_closed(self): if self._closed: diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py index 5bd6b68f1c..b3641eb683 100644 --- a/trio/_highlevel_socket.py +++ b/trio/_highlevel_socket.py @@ -3,7 +3,7 @@ import errno from contextlib import contextmanager -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import trio from . import socket as tsocket @@ -130,7 +130,7 @@ async def send_eof(self) -> None: with _translate_socket_errors_to_stream_errors(): self.socket.shutdown(tsocket.SHUT_WR) - async def receive_some(self, max_bytes: int | None = None): + async def receive_some(self, max_bytes: int | None = None) -> Any: if max_bytes is None: max_bytes = DEFAULT_RECEIVE_SIZE if max_bytes < 1: diff --git a/trio/_socket.py b/trio/_socket.py index 38a6fa0019..e25e035f06 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -15,6 +15,7 @@ if TYPE_CHECKING: from collections.abc import Iterable from typing_extensions import Self + from types import TracebackType # Usage: @@ -485,8 +486,13 @@ def __dir__(self) -> Iterable[str]: def __enter__(self) -> Self: return self - def __exit__(self, *exc_info: object) -> None: - return self._sock.__exit__(*exc_info) + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + return self._sock.__exit__(exc_type, exc_val, exc_tb) @property def family(self) -> _stdlib_socket.AddressFamily: diff --git a/trio/_sync.py b/trio/_sync.py index 08eae512d2..54b4c64c47 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -13,11 +13,11 @@ if TYPE_CHECKING: from ._core import Task - from ._core._parking_lot import _ParkingLotStatistics + from ._core._parking_lot import ParkingLotStatistics @attr.s(frozen=True) -class _EventStatistics: +class EventStatistics: tasks_waiting: int = attr.ib() @@ -82,7 +82,7 @@ def abort_fn(_): await _core.wait_task_rescheduled(abort_fn) - def statistics(self) -> _EventStatistics: + def statistics(self) -> EventStatistics: """Return an object containing debugging information. Currently the following fields are defined: @@ -91,7 +91,7 @@ def statistics(self) -> _EventStatistics: :meth:`wait` method. """ - return _EventStatistics(tasks_waiting=len(self._tasks)) + return EventStatistics(tasks_waiting=len(self._tasks)) class AsyncContextManagerMixin: @@ -105,7 +105,7 @@ async def __aexit__(self, *args: object) -> None: @attr.s(frozen=True) -class _CapacityLimiterStatistics: +class CapacityLimiterStatistics: borrowed_tokens: int = attr.ib() total_tokens: int | float = attr.ib() borrowers: list[Task] = attr.ib() @@ -327,7 +327,7 @@ def release_on_behalf_of(self, borrower: Task) -> None: self._borrowers.remove(borrower) self._wake_waiters() - def statistics(self) -> _CapacityLimiterStatistics: + def statistics(self) -> CapacityLimiterStatistics: """Return an object containing debugging information. Currently the following fields are defined: @@ -344,7 +344,7 @@ def statistics(self) -> _CapacityLimiterStatistics: :meth:`acquire_on_behalf_of` methods. """ - return _CapacityLimiterStatistics( + return CapacityLimiterStatistics( borrowed_tokens=len(self._borrowers), total_tokens=self._total_tokens, # Use a list instead of a frozenset just in case we start to allow @@ -464,7 +464,7 @@ def release(self) -> None: raise ValueError("semaphore released too many times") self._value += 1 - def statistics(self) -> _ParkingLotStatistics: + def statistics(self) -> ParkingLotStatistics: """Return an object containing debugging information. Currently the following fields are defined: @@ -477,7 +477,7 @@ def statistics(self) -> _ParkingLotStatistics: @attr.s(frozen=True) -class _LockStatistics: +class LockStatistics: locked: bool = attr.ib() owner: Task | None = attr.ib() tasks_waiting: int = attr.ib() @@ -556,7 +556,7 @@ def release(self) -> None: else: self._owner = None - def statistics(self) -> _LockStatistics: + def statistics(self) -> LockStatistics: """Return an object containing debugging information. Currently the following fields are defined: @@ -568,7 +568,7 @@ def statistics(self) -> _LockStatistics: :meth:`acquire` method. """ - return _LockStatistics( + return LockStatistics( locked=self.locked(), owner=self._owner, tasks_waiting=len(self._lot) ) @@ -651,9 +651,9 @@ class StrictFIFOLock(_LockImpl, metaclass=Final): @attr.s(frozen=True) -class _ConditionStatistics: +class ConditionStatistics: tasks_waiting: int = attr.ib() - lock_statistics: _LockStatistics = attr.ib() + lock_statistics: LockStatistics = attr.ib() class Condition(AsyncContextManagerMixin, metaclass=Final): @@ -766,7 +766,7 @@ def notify_all(self) -> None: raise RuntimeError("must hold the lock to notify") self._lot.repark_all(self._lock._lot) - def statistics(self) -> _ConditionStatistics: + def statistics(self) -> ConditionStatistics: r"""Return an object containing debugging information. Currently the following fields are defined: @@ -777,6 +777,6 @@ def statistics(self) -> _ConditionStatistics: :class:`Lock`\s :meth:`~Lock.statistics` method. """ - return _ConditionStatistics( + return ConditionStatistics( tasks_waiting=len(self._lot), lock_statistics=self._lock.statistics() ) diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index 53dc90d5d9..791c416e49 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -46,8 +46,8 @@ ], "otherSymbolCounts": { "withAmbiguousType": 8, - "withKnownType": 441, - "withUnknownType": 143 + "withKnownType": 430, + "withUnknownType": 138 }, "packageName": "trio", "symbols": [ @@ -70,7 +70,6 @@ "trio._socket._SocketType.sendto", "trio._socket._SocketType.sendmsg", "trio._highlevel_socket.SocketStream.send_all", - "trio._highlevel_socket.SocketStream.receive_some", "trio._highlevel_socket.SocketStream.setsockopt", "trio._highlevel_socket.SocketStream.getsockopt", "trio._highlevel_socket.SocketListener", @@ -124,17 +123,13 @@ "trio._ssl.SSLListener.aclose", "trio._dtls.DTLSEndpoint.__init__", "trio._dtls.DTLSEndpoint.__del__", - "trio._dtls.DTLSEndpoint.close", "trio._dtls.DTLSEndpoint.__enter__", - "trio._dtls.DTLSEndpoint.__exit__", "trio._dtls.DTLSEndpoint.serve", "trio._dtls.DTLSEndpoint.connect", "trio._dtls.DTLSEndpoint.socket", "trio._dtls.DTLSEndpoint.incoming_packets_buffer", "trio._dtls.DTLSChannel.__init__", - "trio._dtls.DTLSChannel.close", "trio._dtls.DTLSChannel.__enter__", - "trio._dtls.DTLSChannel.__exit__", "trio._dtls.DTLSChannel.aclose", "trio._dtls.DTLSChannel.do_handshake", "trio._dtls.DTLSChannel.send", diff --git a/trio/_util.py b/trio/_util.py index 89a2dea7de..090204f6e3 100644 --- a/trio/_util.py +++ b/trio/_util.py @@ -12,6 +12,9 @@ import trio +if t.TYPE_CHECKING: + from types import TracebackType + # Equivalent to the C function raise(), which Python doesn't wrap if os.name == "nt": # On Windows, os.kill exists but is really weird. @@ -188,7 +191,12 @@ def __enter__(self): else: self._held = True - def __exit__(self, *args): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: self._held = False diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py index 9df5ab5b6c..26348f2de2 100644 --- a/trio/testing/_fake_net.py +++ b/trio/testing/_fake_net.py @@ -6,15 +6,19 @@ # - TCP # - UDP broadcast +from __future__ import annotations import trio import attr import ipaddress import errno import os -from typing import Union, Optional +from typing import Union, Optional, TYPE_CHECKING from trio._util import Final, NoPublicConstructor +if TYPE_CHECKING: + from types import TracebackType + IPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] @@ -337,7 +341,12 @@ def setsockopt(self, level, item, value): def __enter__(self): return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: self.close() async def send(self, data, flags=0): From 8b8383ba39b553f479e8ea213cc983b0dd3a2856 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Tue, 4 Jul 2023 15:47:06 +0200 Subject: [PATCH 03/21] fix tests, reverting accidentally removing a default param --- trio/_core/_parking_lot.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/trio/_core/_parking_lot.py b/trio/_core/_parking_lot.py index 5ef22ca247..ecf22652d8 100644 --- a/trio/_core/_parking_lot.py +++ b/trio/_core/_parking_lot.py @@ -69,7 +69,6 @@ # unpark is called. # # See: https://github.com/python-trio/trio/issues/53 - from __future__ import annotations import attr @@ -138,13 +137,13 @@ def abort_fn(_): def _pop_several(self, count: int | float) -> Iterator[Task]: if isinstance(count, float): - assert math.isinf(count) + assert math.isinf(count), "Cannot pop a non-integer number of tasks." for _ in range(cast(int, min(count, len(self._parked)))): task, _ = self._parked.popitem(last=False) yield task @_core.enable_ki_protection - def unpark(self, *, count: int | float) -> list[Task]: + def unpark(self, *, count: int | float = 1) -> list[Task]: """Unpark one or more tasks. This wakes up ``count`` tasks that are blocked in :meth:`park`. If From b1f9f6f6440a09cb9c9c231a18ba2bdab75bb93a Mon Sep 17 00:00:00 2001 From: jakkdl Date: Tue, 4 Jul 2023 15:55:17 +0200 Subject: [PATCH 04/21] fix all __aexit --- trio/_abc.py | 8 +++++++- trio/_socket.py | 11 ++++++++--- trio/_sync.py | 8 +++++++- trio/testing/_check_streams.py | 12 +++++++++++- 4 files changed, 33 insertions(+), 6 deletions(-) diff --git a/trio/_abc.py b/trio/_abc.py index 036a20b567..21577310d6 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -6,6 +6,7 @@ if TYPE_CHECKING: from typing_extensions import Self + from types import TracebackType # We use ABCMeta instead of ABC, plus set __slots__=(), so as not to force a @@ -260,7 +261,12 @@ async def aclose(self) -> None: async def __aenter__(self) -> AsyncResource: return self - async def __aexit__(self, *args: object) -> None: + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: await self.aclose() diff --git a/trio/_socket.py b/trio/_socket.py index e25e035f06..8e98ce8fa3 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -39,8 +39,13 @@ def _is_blocking_io_error(self, exc): async def __aenter__(self): await trio.lowlevel.checkpoint_if_cancelled() - async def __aexit__(self, etype, value, tb): - if value is not None and self._is_blocking_io_error(value): + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool: + if exc_val is not None and self._is_blocking_io_error(exc_val): # Discard the exception and fall through to the code below the # block return True @@ -522,7 +527,7 @@ def close(self) -> None: trio.lowlevel.notify_closing(self._sock) self._sock.close() - async def bind(self, address: tuple[Any, ...] | str | bytes) -> None: + async def bind(self, address: tuple[object, ...] | str | bytes) -> None: address = await self._resolve_address_nocp(address, local=True) if ( hasattr(_stdlib_socket, "AF_UNIX") diff --git a/trio/_sync.py b/trio/_sync.py index 54b4c64c47..164444771f 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -14,6 +14,7 @@ if TYPE_CHECKING: from ._core import Task from ._core._parking_lot import ParkingLotStatistics + from types import TracebackType @attr.s(frozen=True) @@ -100,7 +101,12 @@ async def __aenter__(self) -> None: await self.acquire() # type: ignore[attr-defined] @enable_ki_protection - async def __aexit__(self, *args: object) -> None: + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: self.release() # type: ignore[attr-defined] diff --git a/trio/testing/_check_streams.py b/trio/testing/_check_streams.py index 0206f1f737..6d8da15940 100644 --- a/trio/testing/_check_streams.py +++ b/trio/testing/_check_streams.py @@ -1,4 +1,5 @@ # Generic stream tests +from __future__ import annotations from contextlib import contextmanager import random @@ -7,6 +8,10 @@ from .._highlevel_generic import aclose_forcefully from .._abc import SendStream, ReceiveStream, Stream, HalfCloseableStream from ._checkpoints import assert_checkpoints +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from types import TracebackType class _ForceCloseBoth: @@ -16,7 +21,12 @@ def __init__(self, both): async def __aenter__(self): return self._both - async def __aexit__(self, *args): + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: try: await aclose_forcefully(self._both[0]) finally: From ba6456557e5b3c88fd79d082d1508dacc15b1fb1 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Tue, 4 Jul 2023 16:03:10 +0200 Subject: [PATCH 05/21] rename exc_val to exc_value to match official data model --- trio/_abc.py | 2 +- trio/_channel.py | 4 ++-- trio/_core/_multierror.py | 8 ++++---- trio/_core/_run.py | 2 +- trio/_dtls.py | 4 ++-- trio/_socket.py | 8 ++++---- trio/_sync.py | 2 +- trio/_util.py | 2 +- trio/testing/_check_streams.py | 2 +- trio/testing/_fake_net.py | 2 +- 10 files changed, 18 insertions(+), 18 deletions(-) diff --git a/trio/_abc.py b/trio/_abc.py index 21577310d6..c20774dcad 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -264,7 +264,7 @@ async def __aenter__(self) -> AsyncResource: async def __aexit__( self, exc_type: type[BaseException] | None, - exc_val: BaseException | None, + exc_value: BaseException | None, exc_tb: TracebackType | None, ) -> None: await self.aclose() diff --git a/trio/_channel.py b/trio/_channel.py index 3ad08b7109..f2e63363b4 100644 --- a/trio/_channel.py +++ b/trio/_channel.py @@ -249,7 +249,7 @@ def __enter__(self: SelfT) -> SelfT: def __exit__( self, exc_type: type[BaseException] | None, - exc_val: BaseException | None, + exc_value: BaseException | None, exc_tb: TracebackType | None, ) -> None: self.close() @@ -395,7 +395,7 @@ def __enter__(self: SelfT) -> SelfT: def __exit__( self, exc_type: type[BaseException] | None, - exc_val: BaseException | None, + exc_value: BaseException | None, exc_tb: TracebackType | None, ) -> None: self.close() diff --git a/trio/_core/_multierror.py b/trio/_core/_multierror.py index 6e964ba038..8abd9730fb 100644 --- a/trio/_core/_multierror.py +++ b/trio/_core/_multierror.py @@ -137,13 +137,13 @@ def __enter__(self): def __exit__( self, exc_type: type[BaseException] | None, - exc_val: BaseException | None, + exc_value: BaseException | None, exc_tb: TracebackType | None, ) -> bool | None: - if exc_val is not None: - filtered_exc = _filter_impl(self._handler, exc_val) + if exc_value is not None: + filtered_exc = _filter_impl(self._handler, exc_value) - if filtered_exc is exc_val: + if filtered_exc is exc_value: # Let the interpreter re-raise it return False if filtered_exc is None: diff --git a/trio/_core/_run.py b/trio/_core/_run.py index b7a8359698..03799c1d2c 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -906,7 +906,7 @@ def __enter__(self) -> None: def __exit__( self, # pragma: no cover exc_type: type[BaseException] | None, - exc_val: BaseException | None, + exc_value: BaseException | None, exc_tb: TracebackType | None, ) -> bool: assert False, """Never called, but should be defined""" diff --git a/trio/_dtls.py b/trio/_dtls.py index 9cb8a5d03c..693759ef7a 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -841,7 +841,7 @@ def __enter__(self): def __exit__( self, exc_type: type[BaseException] | None, - exc_val: BaseException | None, + exc_value: BaseException | None, exc_tb: TracebackType | None, ) -> None: return self.close() @@ -1199,7 +1199,7 @@ def __enter__(self): def __exit__( self, exc_type: type[BaseException] | None, - exc_val: BaseException | None, + exc_value: BaseException | None, exc_tb: TracebackType | None, ) -> None: return self.close() diff --git a/trio/_socket.py b/trio/_socket.py index 8e98ce8fa3..ff57a447cd 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -42,10 +42,10 @@ async def __aenter__(self): async def __aexit__( self, exc_type: type[BaseException] | None, - exc_val: BaseException | None, + exc_value: BaseException | None, exc_tb: TracebackType | None, ) -> bool: - if exc_val is not None and self._is_blocking_io_error(exc_val): + if exc_value is not None and self._is_blocking_io_error(exc_value): # Discard the exception and fall through to the code below the # block return True @@ -494,10 +494,10 @@ def __enter__(self) -> Self: def __exit__( self, exc_type: type[BaseException] | None, - exc_val: BaseException | None, + exc_value: BaseException | None, exc_tb: TracebackType | None, ) -> None: - return self._sock.__exit__(exc_type, exc_val, exc_tb) + return self._sock.__exit__(exc_type, exc_value, exc_tb) @property def family(self) -> _stdlib_socket.AddressFamily: diff --git a/trio/_sync.py b/trio/_sync.py index 164444771f..75898b5ef2 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -104,7 +104,7 @@ async def __aenter__(self) -> None: async def __aexit__( self, exc_type: type[BaseException] | None, - exc_val: BaseException | None, + exc_value: BaseException | None, exc_tb: TracebackType | None, ) -> None: self.release() # type: ignore[attr-defined] diff --git a/trio/_util.py b/trio/_util.py index 090204f6e3..93dd1e38d0 100644 --- a/trio/_util.py +++ b/trio/_util.py @@ -194,7 +194,7 @@ def __enter__(self): def __exit__( self, exc_type: type[BaseException] | None, - exc_val: BaseException | None, + exc_value: BaseException | None, exc_tb: TracebackType | None, ) -> None: self._held = False diff --git a/trio/testing/_check_streams.py b/trio/testing/_check_streams.py index 6d8da15940..a9a7125a48 100644 --- a/trio/testing/_check_streams.py +++ b/trio/testing/_check_streams.py @@ -24,7 +24,7 @@ async def __aenter__(self): async def __aexit__( self, exc_type: type[BaseException] | None, - exc_val: BaseException | None, + exc_value: BaseException | None, exc_tb: TracebackType | None, ) -> None: try: diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py index 26348f2de2..3b340d2a93 100644 --- a/trio/testing/_fake_net.py +++ b/trio/testing/_fake_net.py @@ -344,7 +344,7 @@ def __enter__(self): def __exit__( self, exc_type: type[BaseException] | None, - exc_val: BaseException | None, + exc_value: BaseException | None, exc_tb: TracebackType | None, ) -> None: self.close() From adb87822380812675b3ba77888e93867e1bb7807 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Tue, 4 Jul 2023 16:04:23 +0200 Subject: [PATCH 06/21] rename exc_tb to traceback to match python data model --- trio/_abc.py | 2 +- trio/_channel.py | 4 ++-- trio/_core/_multierror.py | 2 +- trio/_core/_run.py | 2 +- trio/_dtls.py | 4 ++-- trio/_socket.py | 6 +++--- trio/_sync.py | 2 +- trio/_util.py | 2 +- trio/testing/_check_streams.py | 2 +- trio/testing/_fake_net.py | 2 +- 10 files changed, 14 insertions(+), 14 deletions(-) diff --git a/trio/_abc.py b/trio/_abc.py index c20774dcad..3a0ece749a 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -265,7 +265,7 @@ async def __aexit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, - exc_tb: TracebackType | None, + traceback: TracebackType | None, ) -> None: await self.aclose() diff --git a/trio/_channel.py b/trio/_channel.py index f2e63363b4..a7c13dc9bd 100644 --- a/trio/_channel.py +++ b/trio/_channel.py @@ -250,7 +250,7 @@ def __exit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, - exc_tb: TracebackType | None, + traceback: TracebackType | None, ) -> None: self.close() @@ -396,7 +396,7 @@ def __exit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, - exc_tb: TracebackType | None, + traceback: TracebackType | None, ) -> None: self.close() diff --git a/trio/_core/_multierror.py b/trio/_core/_multierror.py index 8abd9730fb..584179f595 100644 --- a/trio/_core/_multierror.py +++ b/trio/_core/_multierror.py @@ -138,7 +138,7 @@ def __exit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, - exc_tb: TracebackType | None, + traceback: TracebackType | None, ) -> bool | None: if exc_value is not None: filtered_exc = _filter_impl(self._handler, exc_value) diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 03799c1d2c..8f2264f52a 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -907,7 +907,7 @@ def __exit__( self, # pragma: no cover exc_type: type[BaseException] | None, exc_value: BaseException | None, - exc_tb: TracebackType | None, + traceback: TracebackType | None, ) -> bool: assert False, """Never called, but should be defined""" diff --git a/trio/_dtls.py b/trio/_dtls.py index 693759ef7a..ad6e492b97 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -842,7 +842,7 @@ def __exit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, - exc_tb: TracebackType | None, + traceback: TracebackType | None, ) -> None: return self.close() @@ -1200,7 +1200,7 @@ def __exit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, - exc_tb: TracebackType | None, + traceback: TracebackType | None, ) -> None: return self.close() diff --git a/trio/_socket.py b/trio/_socket.py index ff57a447cd..6e9abe0cac 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -43,7 +43,7 @@ async def __aexit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, - exc_tb: TracebackType | None, + traceback: TracebackType | None, ) -> bool: if exc_value is not None and self._is_blocking_io_error(exc_value): # Discard the exception and fall through to the code below the @@ -495,9 +495,9 @@ def __exit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, - exc_tb: TracebackType | None, + traceback: TracebackType | None, ) -> None: - return self._sock.__exit__(exc_type, exc_value, exc_tb) + return self._sock.__exit__(exc_type, exc_value, traceback) @property def family(self) -> _stdlib_socket.AddressFamily: diff --git a/trio/_sync.py b/trio/_sync.py index 75898b5ef2..3fdc38b302 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -105,7 +105,7 @@ async def __aexit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, - exc_tb: TracebackType | None, + traceback: TracebackType | None, ) -> None: self.release() # type: ignore[attr-defined] diff --git a/trio/_util.py b/trio/_util.py index 93dd1e38d0..9631226e88 100644 --- a/trio/_util.py +++ b/trio/_util.py @@ -195,7 +195,7 @@ def __exit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, - exc_tb: TracebackType | None, + traceback: TracebackType | None, ) -> None: self._held = False diff --git a/trio/testing/_check_streams.py b/trio/testing/_check_streams.py index a9a7125a48..0e1a4f8f14 100644 --- a/trio/testing/_check_streams.py +++ b/trio/testing/_check_streams.py @@ -25,7 +25,7 @@ async def __aexit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, - exc_tb: TracebackType | None, + traceback: TracebackType | None, ) -> None: try: await aclose_forcefully(self._both[0]) diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py index 3b340d2a93..646b7af7a6 100644 --- a/trio/testing/_fake_net.py +++ b/trio/testing/_fake_net.py @@ -345,7 +345,7 @@ def __exit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, - exc_tb: TracebackType | None, + traceback: TracebackType | None, ) -> None: self.close() From 3c9583c82905d52763fd9a01548a8d204ee7d89e Mon Sep 17 00:00:00 2001 From: jakkdl Date: Tue, 4 Jul 2023 16:11:17 +0200 Subject: [PATCH 07/21] fix format check CI --- trio/_socket.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trio/_socket.py b/trio/_socket.py index 6e9abe0cac..0446e998ae 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -5,7 +5,7 @@ import select import socket as _stdlib_socket from functools import wraps as _wraps -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import idna as _idna From 0f29a1a34364e2588227e4128f9153efb0d07bd5 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 5 Jul 2023 17:58:04 +0200 Subject: [PATCH 08/21] changes after review from TeamSpen210 --- trio/_core/_run.py | 25 ++++++++++++++----------- trio/_highlevel_socket.py | 4 ++-- trio/_socket.py | 8 +++++++- trio/_tests/verify_types.json | 5 ++--- 4 files changed, 25 insertions(+), 17 deletions(-) diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 8f2264f52a..7fdd203f9c 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -898,18 +898,21 @@ async def __aexit__( # see test_simple_cancel_scope_usage_doesnt_create_cyclic_garbage del _, combined_error_from_nursery, value, new_exc - def __enter__(self) -> None: - raise RuntimeError( - "use 'async with open_nursery(...)', not 'with open_nursery(...)'" - ) + # make sure these raise errors in static analysis if called + if not TYPE_CHECKING: - def __exit__( - self, # pragma: no cover - exc_type: type[BaseException] | None, - exc_value: BaseException | None, - traceback: TracebackType | None, - ) -> bool: - assert False, """Never called, but should be defined""" + def __enter__(self) -> NoReturn: + raise RuntimeError( + "use 'async with open_nursery(...)', not 'with open_nursery(...)'" + ) + + def __exit__( + self, # pragma: no cover + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> NoReturn: + raise AssertionError("Never called, but should be defined") def open_nursery( diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py index b3641eb683..b63902d887 100644 --- a/trio/_highlevel_socket.py +++ b/trio/_highlevel_socket.py @@ -3,7 +3,7 @@ import errno from contextlib import contextmanager -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import trio from . import socket as tsocket @@ -130,7 +130,7 @@ async def send_eof(self) -> None: with _translate_socket_errors_to_stream_errors(): self.socket.shutdown(tsocket.SHUT_WR) - async def receive_some(self, max_bytes: int | None = None) -> Any: + async def receive_some(self, max_bytes: int | None = None) -> bytes: if max_bytes is None: max_bytes = DEFAULT_RECEIVE_SIZE if max_bytes < 1: diff --git a/trio/_socket.py b/trio/_socket.py index 0446e998ae..8c70a0a3a1 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -701,7 +701,13 @@ async def connect(self, address): # recv ################################################################ - recv = _make_simple_sock_method_wrapper("recv", _core.wait_readable) + if TYPE_CHECKING: + + async def recv(self, buffersize: int, flags: int = 0) -> bytes: + ... + + else: + recv = _make_simple_sock_method_wrapper("recv", _core.wait_readable) ################################################################ # recv_into diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index 791c416e49..a9bfc663c1 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -46,8 +46,8 @@ ], "otherSymbolCounts": { "withAmbiguousType": 8, - "withKnownType": 430, - "withUnknownType": 138 + "withKnownType": 431, + "withUnknownType": 137 }, "packageName": "trio", "symbols": [ @@ -60,7 +60,6 @@ "trio._socket._SocketType.__getattr__", "trio._socket._SocketType.accept", "trio._socket._SocketType.connect", - "trio._socket._SocketType.recv", "trio._socket._SocketType.recv_into", "trio._socket._SocketType.recvfrom", "trio._socket._SocketType.recvfrom_into", From c2012a79215c8f7fc8d1a9b61a5508b36ee67844 Mon Sep 17 00:00:00 2001 From: John Litborn <11260241+jakkdl@users.noreply.github.com> Date: Thu, 6 Jul 2023 15:16:13 +0200 Subject: [PATCH 09/21] Apply suggestions from code review Co-authored-by: Spencer Brown --- trio/_core/_parking_lot.py | 9 +++++++-- trio/_core/_run.py | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/trio/_core/_parking_lot.py b/trio/_core/_parking_lot.py index ecf22652d8..bfcf08a69a 100644 --- a/trio/_core/_parking_lot.py +++ b/trio/_core/_parking_lot.py @@ -137,8 +137,13 @@ def abort_fn(_): def _pop_several(self, count: int | float) -> Iterator[Task]: if isinstance(count, float): - assert math.isinf(count), "Cannot pop a non-integer number of tasks." - for _ in range(cast(int, min(count, len(self._parked)))): + if math.isinf(count): + count = len(self._parked) + else: + raise ValueError("Cannot pop a non-integer number of tasks.") + else: + count = min(count, len(self._parked)) + for _ in range(count): task, _ = self._parked.popitem(last=False) yield task diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 7fdd203f9c..a3fec21c6f 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -861,7 +861,7 @@ class NurseryManager: strict_exception_groups: bool = attr.ib(default=False) @enable_ki_protection - async def __aenter__(self): + async def __aenter__(self) -> Nursery: self._scope = CancelScope() self._scope.__enter__() self._nursery = Nursery._create( From 9a0c3e28281dd21485d3459ee47b2365ab1a1eb8 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Thu, 6 Jul 2023 15:20:00 +0200 Subject: [PATCH 10/21] fixes after review --- trio/_core/_run.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/trio/_core/_run.py b/trio/_core/_run.py index a3fec21c6f..ddf7221a6f 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -50,9 +50,7 @@ if TYPE_CHECKING: # An unfortunate name collision here with trio._util.Final from typing_extensions import Final as FinalT - from collections.abc import Coroutine from types import FrameType - import outcome import contextvars DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: FinalT = 1000 @@ -1169,7 +1167,10 @@ def __del__(self) -> None: @attr.s(eq=False, hash=False, repr=False, slots=True) class Task(metaclass=NoPublicConstructor): _parent_nursery: Nursery | None = attr.ib() - coro: Coroutine[Any, outcome.Outcome[object], Any] = attr.ib() + # could be typed as `Coroutine[Any, outcome.Outcome[object], Any]` but + # the code does a lot of dynamic introspection on it and as long as it passes + # those it's fine. + coro: Any = attr.ib() _runner = attr.ib() name: str = attr.ib() context: contextvars.Context = attr.ib() @@ -1191,8 +1192,10 @@ class Task(metaclass=NoPublicConstructor): _abort_func: Callable[[Callable[[], NoReturn]], Abort] | None = attr.ib( default=None ) - # could possible be set with a TypeVar - custom_sleep_data: Any = attr.ib(default=None) + # Typed as `object`, forcing users to do an isinstance check each time. Since + # anything touching the task could have set this, it's not really going to be + # safe to assume that this had the value you saw it with last. + custom_sleep_data: object = attr.ib(default=None) # For introspection and nursery.start() _child_nurseries: list[Nursery] = attr.ib(factory=list) @@ -1264,11 +1267,11 @@ def print_stack_for_task(task): if hasattr(coro, "cr_frame"): # A real coroutine yield coro.cr_frame, coro.cr_frame.f_lineno - coro = coro.cr_await # type: ignore # TODO + coro = coro.cr_await elif hasattr(coro, "gi_frame"): # A generator decorated with @types.coroutine yield coro.gi_frame, coro.gi_frame.f_lineno - coro = coro.gi_yieldfrom # type: ignore # TODO + coro = coro.gi_yieldfrom elif coro.__class__.__name__ in [ "async_generator_athrow", "async_generator_asend", From 3094d066e25ad60b690c7611c1cf1395ecfce168 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Thu, 6 Jul 2023 16:14:24 +0200 Subject: [PATCH 11/21] attempt to fix readthedocs build, (temp) ignore errors on custom_sleep_data --- trio/__init__.py | 4 ++++ trio/_channel.py | 8 ++++---- trio/_core/__init__.py | 2 +- trio/_core/_run.py | 3 ++- trio/_tests/verify_types.json | 4 ++-- trio/lowlevel.py | 1 + 6 files changed, 14 insertions(+), 8 deletions(-) diff --git a/trio/__init__.py b/trio/__init__.py index 40aa3c430d..728fe4e0d7 100644 --- a/trio/__init__.py +++ b/trio/__init__.py @@ -50,11 +50,15 @@ from ._sync import ( Event as Event, + EventStatistics as EventStatistics, CapacityLimiter as CapacityLimiter, Semaphore as Semaphore, Lock as Lock, StrictFIFOLock as StrictFIFOLock, Condition as Condition, + ConditionStatistics as ConditionStatistics, + CapacityLimiterStatistics as CapacityLimiterStatistics, + LockStatistics as LockStatistics, ) from ._highlevel_generic import ( diff --git a/trio/_channel.py b/trio/_channel.py index a7c13dc9bd..05a685c3c9 100644 --- a/trio/_channel.py +++ b/trio/_channel.py @@ -178,7 +178,7 @@ def send_nowait(self, value: SendType) -> None: if self._state.receive_tasks: assert not self._state.data task, _ = self._state.receive_tasks.popitem(last=False) - task.custom_sleep_data._tasks.remove(task) + task.custom_sleep_data._tasks.remove(task) # type: ignore[attr-defined] trio.lowlevel.reschedule(task, Value(value)) elif len(self._state.data) < self._state.max_buffer_size: self._state.data.append(value) @@ -278,7 +278,7 @@ def close(self) -> None: if self._state.open_send_channels == 0: assert not self._state.send_tasks for task in self._state.receive_tasks: - task.custom_sleep_data._tasks.remove(task) + task.custom_sleep_data._tasks.remove(task) # type: ignore[attr-defined] trio.lowlevel.reschedule(task, Error(trio.EndOfChannel())) self._state.receive_tasks.clear() @@ -315,7 +315,7 @@ def receive_nowait(self) -> ReceiveType: raise trio.ClosedResourceError if self._state.send_tasks: task, value = self._state.send_tasks.popitem(last=False) - task.custom_sleep_data._tasks.remove(task) + task.custom_sleep_data._tasks.remove(task) # type: ignore[attr-defined] trio.lowlevel.reschedule(task) self._state.data.append(value) # Fall through @@ -424,7 +424,7 @@ def close(self) -> None: if self._state.open_receive_channels == 0: assert not self._state.receive_tasks for task in self._state.send_tasks: - task.custom_sleep_data._tasks.remove(task) + task.custom_sleep_data._tasks.remove(task) # type: ignore[attr-defined] trio.lowlevel.reschedule(task, Error(trio.BrokenResourceError())) self._state.send_tasks.clear() self._state.data.clear() diff --git a/trio/_core/__init__.py b/trio/_core/__init__.py index f9919b8323..a57f2e7021 100644 --- a/trio/_core/__init__.py +++ b/trio/_core/__init__.py @@ -64,7 +64,7 @@ from ._entry_queue import TrioToken -from ._parking_lot import ParkingLot +from ._parking_lot import ParkingLot, ParkingLotStatistics from ._unbounded_queue import UnboundedQueue diff --git a/trio/_core/_run.py b/trio/_core/_run.py index ddf7221a6f..2e1905029f 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -47,10 +47,11 @@ if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup +from types import FrameType + if TYPE_CHECKING: # An unfortunate name collision here with trio._util.Final from typing_extensions import Final as FinalT - from types import FrameType import contextvars DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: FinalT = 1000 diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index a9bfc663c1..c1d4aab66a 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -7,10 +7,10 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.8737864077669902, + "completenessScore": 0.8747993579454254, "exportedSymbolCounts": { "withAmbiguousType": 1, - "withKnownType": 540, + "withKnownType": 545, "withUnknownType": 77 }, "ignoreUnknownTypesFromImports": true, diff --git a/trio/lowlevel.py b/trio/lowlevel.py index b7f4f3a725..55838fb8f6 100644 --- a/trio/lowlevel.py +++ b/trio/lowlevel.py @@ -25,6 +25,7 @@ checkpoint as checkpoint, current_task as current_task, ParkingLot as ParkingLot, + ParkingLotStatistics as ParkingLotStatistics, UnboundedQueue as UnboundedQueue, RunVar as RunVar, TrioToken as TrioToken, From f83373b9d2d984496ee2d61ec7c6cfd69386ae01 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Fri, 7 Jul 2023 12:29:54 +0200 Subject: [PATCH 12/21] revent custom_sleep_data to Any, add docstring to some Statistics to see if that makes readthedocs happy --- trio/_channel.py | 8 ++++---- trio/_core/_parking_lot.py | 11 ++++++++++- trio/_core/_run.py | 2 +- trio/_sync.py | 22 ++++++++++++++++++++++ 4 files changed, 37 insertions(+), 6 deletions(-) diff --git a/trio/_channel.py b/trio/_channel.py index 05a685c3c9..a7c13dc9bd 100644 --- a/trio/_channel.py +++ b/trio/_channel.py @@ -178,7 +178,7 @@ def send_nowait(self, value: SendType) -> None: if self._state.receive_tasks: assert not self._state.data task, _ = self._state.receive_tasks.popitem(last=False) - task.custom_sleep_data._tasks.remove(task) # type: ignore[attr-defined] + task.custom_sleep_data._tasks.remove(task) trio.lowlevel.reschedule(task, Value(value)) elif len(self._state.data) < self._state.max_buffer_size: self._state.data.append(value) @@ -278,7 +278,7 @@ def close(self) -> None: if self._state.open_send_channels == 0: assert not self._state.send_tasks for task in self._state.receive_tasks: - task.custom_sleep_data._tasks.remove(task) # type: ignore[attr-defined] + task.custom_sleep_data._tasks.remove(task) trio.lowlevel.reschedule(task, Error(trio.EndOfChannel())) self._state.receive_tasks.clear() @@ -315,7 +315,7 @@ def receive_nowait(self) -> ReceiveType: raise trio.ClosedResourceError if self._state.send_tasks: task, value = self._state.send_tasks.popitem(last=False) - task.custom_sleep_data._tasks.remove(task) # type: ignore[attr-defined] + task.custom_sleep_data._tasks.remove(task) trio.lowlevel.reschedule(task) self._state.data.append(value) # Fall through @@ -424,7 +424,7 @@ def close(self) -> None: if self._state.open_receive_channels == 0: assert not self._state.receive_tasks for task in self._state.send_tasks: - task.custom_sleep_data._tasks.remove(task) # type: ignore[attr-defined] + task.custom_sleep_data._tasks.remove(task) trio.lowlevel.reschedule(task, Error(trio.BrokenResourceError())) self._state.send_tasks.clear() self._state.data.clear() diff --git a/trio/_core/_parking_lot.py b/trio/_core/_parking_lot.py index bfcf08a69a..272dfa2e1c 100644 --- a/trio/_core/_parking_lot.py +++ b/trio/_core/_parking_lot.py @@ -73,7 +73,7 @@ import attr from collections import OrderedDict -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING from collections.abc import Iterator import math @@ -86,6 +86,15 @@ @attr.s(frozen=True, slots=True) class ParkingLotStatistics: + """An object containing debugging information for a ParkingLot. + + Currently the following fields are defined: + + * ``tasks_waiting`` (int): The number of tasks blocked on this lot's + :meth:`park` method. + + """ + tasks_waiting: int = attr.ib() diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 2e1905029f..12c1f8581e 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -1196,7 +1196,7 @@ class Task(metaclass=NoPublicConstructor): # Typed as `object`, forcing users to do an isinstance check each time. Since # anything touching the task could have set this, it's not really going to be # safe to assume that this had the value you saw it with last. - custom_sleep_data: object = attr.ib(default=None) + custom_sleep_data: Any = attr.ib(default=None) # For introspection and nursery.start() _child_nurseries: list[Nursery] = attr.ib(factory=list) diff --git a/trio/_sync.py b/trio/_sync.py index 3fdc38b302..e43e646065 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -484,6 +484,18 @@ def statistics(self) -> ParkingLotStatistics: @attr.s(frozen=True) class LockStatistics: + """An object containing debugging information for a Lock. + + Currently the following fields are defined: + + * ``locked`` (boolean): indicating whether the lock is held. + * ``owner``: the :class:`trio.lowlevel.Task` currently holding the lock, + or None if the lock is not held. + * ``tasks_waiting`` (int): The number of tasks blocked on this lock's + :meth:`acquire` method. + + """ + locked: bool = attr.ib() owner: Task | None = attr.ib() tasks_waiting: int = attr.ib() @@ -658,6 +670,16 @@ class StrictFIFOLock(_LockImpl, metaclass=Final): @attr.s(frozen=True) class ConditionStatistics: + r"""Return an object containing debugging information for a Condition. + + Currently the following fields are defined: + + * ``tasks_waiting`` (int): The number of tasks blocked on this condition's + :meth:`wait` method. + * ``lock_statistics``: The result of calling the underlying + :class:`Lock`\s :meth:`~Lock.statistics` method. + + """ tasks_waiting: int = attr.ib() lock_statistics: LockStatistics = attr.ib() From ac4b357a86c86cd8f129fe2cb3fbedf53a2b7249 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Fri, 7 Jul 2023 14:45:49 +0200 Subject: [PATCH 13/21] fix CI errors --- docs/source/conf.py | 12 ++++++++++++ docs/source/reference-core.rst | 12 ++++++++++++ docs/source/reference-lowlevel.rst | 2 ++ trio/_core/_parking_lot.py | 9 +++++---- trio/_sync.py | 19 ++++++++++--------- 5 files changed, 41 insertions(+), 13 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index cfac66576b..2d8e2e8111 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -62,10 +62,22 @@ ("py:obj", "trio._abc.SendType"), ("py:obj", "trio._abc.T"), ("py:obj", "trio._abc.T_resource"), + ("py:class", "types.FrameType"), ] autodoc_inherit_docstrings = False default_role = "obj" +# These have incorrect __module__ set in stdlib and give the error +# `py:class reference target not found` +# Some of the nitpick_ignore's above can probably be fixed with this. +# See https://github.com/sphinx-doc/sphinx/issues/8315#issuecomment-751335798 +autodoc_type_aliases = { + # aliasing doesn't actually fix the warning for types.FrameType, but displaying + # "types.FrameType" is more helpfun than just "frame" + "FrameType": "types.FrameType", +} + + # XX hack the RTD theme until # https://github.com/rtfd/sphinx_rtd_theme/pull/382 # is shipped (should be in the release after 0.2.4) diff --git a/docs/source/reference-core.rst b/docs/source/reference-core.rst index 922ae4680e..4f4f4d62b9 100644 --- a/docs/source/reference-core.rst +++ b/docs/source/reference-core.rst @@ -1096,6 +1096,8 @@ Broadcasting an event with :class:`Event` .. autoclass:: Event :members: +.. autoclass:: EventStatistics + :members: .. _channels: @@ -1456,6 +1458,16 @@ don't have any special access to Trio's internals.) .. autoclass:: Condition :members: +These primitives return statistics objects that can be inspected. + +.. autoclass:: CapacityLimiterStatistics + :members: + +.. autoclass:: LockStatistics + :members: + +.. autoclass:: ConditionStatistics + :members: .. _async-generators: diff --git a/docs/source/reference-lowlevel.rst b/docs/source/reference-lowlevel.rst index 815cff2ddf..bacebff5ad 100644 --- a/docs/source/reference-lowlevel.rst +++ b/docs/source/reference-lowlevel.rst @@ -378,6 +378,8 @@ Wait queue abstraction :members: :undoc-members: +.. autoclass:: ParkingLotStatistics + :members: Low-level checkpoint functions ------------------------------ diff --git a/trio/_core/_parking_lot.py b/trio/_core/_parking_lot.py index 272dfa2e1c..da10e7cf91 100644 --- a/trio/_core/_parking_lot.py +++ b/trio/_core/_parking_lot.py @@ -71,11 +71,12 @@ # See: https://github.com/python-trio/trio/issues/53 from __future__ import annotations -import attr +import math from collections import OrderedDict -from typing import TYPE_CHECKING from collections.abc import Iterator -import math +from typing import TYPE_CHECKING + +import attr from .. import _core from .._util import Final @@ -91,7 +92,7 @@ class ParkingLotStatistics: Currently the following fields are defined: * ``tasks_waiting`` (int): The number of tasks blocked on this lot's - :meth:`park` method. + :meth:`trio.lowlevel.ParkingLot.park` method. """ diff --git a/trio/_sync.py b/trio/_sync.py index e43e646065..feb5825e47 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -1,23 +1,24 @@ from __future__ import annotations -from typing import TYPE_CHECKING import math +from typing import TYPE_CHECKING import attr import trio from . import _core -from ._core import enable_ki_protection, ParkingLot +from ._core import ParkingLot, enable_ki_protection from ._util import Final if TYPE_CHECKING: + from types import TracebackType + from ._core import Task from ._core._parking_lot import ParkingLotStatistics - from types import TracebackType -@attr.s(frozen=True) +@attr.s(frozen=True, slots=True) class EventStatistics: tasks_waiting: int = attr.ib() @@ -110,7 +111,7 @@ async def __aexit__( self.release() # type: ignore[attr-defined] -@attr.s(frozen=True) +@attr.s(frozen=True, slots=True) class CapacityLimiterStatistics: borrowed_tokens: int = attr.ib() total_tokens: int | float = attr.ib() @@ -482,7 +483,7 @@ def statistics(self) -> ParkingLotStatistics: return self._lot.statistics() -@attr.s(frozen=True) +@attr.s(frozen=True, slots=True) class LockStatistics: """An object containing debugging information for a Lock. @@ -492,7 +493,7 @@ class LockStatistics: * ``owner``: the :class:`trio.lowlevel.Task` currently holding the lock, or None if the lock is not held. * ``tasks_waiting`` (int): The number of tasks blocked on this lock's - :meth:`acquire` method. + :meth:`trio.Lock.acquire` method. """ @@ -668,14 +669,14 @@ class StrictFIFOLock(_LockImpl, metaclass=Final): """ -@attr.s(frozen=True) +@attr.s(frozen=True, slots=True) class ConditionStatistics: r"""Return an object containing debugging information for a Condition. Currently the following fields are defined: * ``tasks_waiting`` (int): The number of tasks blocked on this condition's - :meth:`wait` method. + :meth:`trio.Condition.wait` method. * ``lock_statistics``: The result of calling the underlying :class:`Lock`\s :meth:`~Lock.statistics` method. From 77a370636dde766ec4f6d742c67ca5ef565d8db6 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Fri, 7 Jul 2023 14:51:27 +0200 Subject: [PATCH 14/21] don't require *Statistics to be Final --- trio/_tests/test_exports.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/trio/_tests/test_exports.py b/trio/_tests/test_exports.py index 3ab0016386..e51bbe31f5 100644 --- a/trio/_tests/test_exports.py +++ b/trio/_tests/test_exports.py @@ -492,4 +492,8 @@ def test_classes_are_final(): continue # ... insert other special cases here ... + # don't care about the *Statistics classes + if name.endswith("Statistics"): + continue + assert isinstance(class_, _util.Final) From 50ce72fb0b9b4536659eda2a23cb986ab25adeb7 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Fri, 7 Jul 2023 15:00:33 +0200 Subject: [PATCH 15/21] fix CI --- trio/_core/_multierror.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/trio/_core/_multierror.py b/trio/_core/_multierror.py index 584179f595..3c6ebb789f 100644 --- a/trio/_core/_multierror.py +++ b/trio/_core/_multierror.py @@ -1,11 +1,12 @@ from __future__ import annotations + import sys import warnings +from typing import TYPE_CHECKING import attr from trio._deprecate import warn_deprecated -from typing import TYPE_CHECKING if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup, ExceptionGroup, print_exception From f34787be6e286c2bcae3c70ac31b792eb2173eba Mon Sep 17 00:00:00 2001 From: jakkdl Date: Fri, 7 Jul 2023 15:25:47 +0200 Subject: [PATCH 16/21] fix codecov and formatting --- trio/_core/_run.py | 4 ++-- trio/_core/_tests/test_parking_lot.py | 3 +++ trio/_dtls.py | 2 +- trio/_util.py | 4 +--- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 1e9d309003..4ac3f93339 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -904,8 +904,8 @@ def __enter__(self) -> NoReturn: "use 'async with open_nursery(...)', not 'with open_nursery(...)'" ) - def __exit__( - self, # pragma: no cover + def __exit__( # pragma: no cover + self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, diff --git a/trio/_core/_tests/test_parking_lot.py b/trio/_core/_tests/test_parking_lot.py index db3fc76709..32442e5a0f 100644 --- a/trio/_core/_tests/test_parking_lot.py +++ b/trio/_core/_tests/test_parking_lot.py @@ -72,6 +72,9 @@ async def waiter(i, lot): ) lot.unpark_all() + with pytest.raises(TypeError): + lot.unpark(count=1.5) + async def cancellable_waiter(name, lot, scopes, record): with _core.CancelScope() as scope: diff --git a/trio/_dtls.py b/trio/_dtls.py index 57ce638a5d..da99481e06 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -1185,7 +1185,7 @@ def close(self) -> None: """ # Do nothing if this object was never fully constructed - if self.socket is None: + if self.socket is None: # pragma: no cover return self._closed = True diff --git a/trio/_util.py b/trio/_util.py index cbe892f41d..0a0795fc15 100644 --- a/trio/_util.py +++ b/trio/_util.py @@ -9,12 +9,10 @@ import typing as t from abc import ABCMeta from functools import update_wrapper +from types import TracebackType import trio -if t.TYPE_CHECKING: - from types import TracebackType - # Equivalent to the C function raise(), which Python doesn't wrap if os.name == "nt": # On Windows, os.kill exists but is really weird. From a8751bbc5f5dd345dfb8c88e50b0feff4c1b3cd4 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Fri, 7 Jul 2023 15:38:23 +0200 Subject: [PATCH 17/21] ValueError, not TypeError --- trio/_core/_tests/test_parking_lot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trio/_core/_tests/test_parking_lot.py b/trio/_core/_tests/test_parking_lot.py index 32442e5a0f..3f03fdbade 100644 --- a/trio/_core/_tests/test_parking_lot.py +++ b/trio/_core/_tests/test_parking_lot.py @@ -72,7 +72,7 @@ async def waiter(i, lot): ) lot.unpark_all() - with pytest.raises(TypeError): + with pytest.raises(ValueError): lot.unpark(count=1.5) From ed05e130bc4abdb6e10c0a7605b9041219628538 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Fri, 7 Jul 2023 16:06:55 +0200 Subject: [PATCH 18/21] fix pragma: no cover --- trio/_core/_run.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 4ac3f93339..49f55dee3e 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -904,12 +904,12 @@ def __enter__(self) -> NoReturn: "use 'async with open_nursery(...)', not 'with open_nursery(...)'" ) - def __exit__( # pragma: no cover + def __exit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, - ) -> NoReturn: + ) -> NoReturn: # pragma: no cover raise AssertionError("Never called, but should be defined") From 67166a5cc287fe9d32ce80f2d4340bacbfb64993 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Mon, 10 Jul 2023 12:31:19 +0200 Subject: [PATCH 19/21] fixes from CI review, mainly from Zac-HD --- docs/source/conf.py | 2 +- trio/_core/_run.py | 19 ++++++++++--------- trio/_highlevel_generic.py | 3 ++- trio/_socket.py | 4 ++-- 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 2d8e2e8111..68a5a22a81 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -73,7 +73,7 @@ # See https://github.com/sphinx-doc/sphinx/issues/8315#issuecomment-751335798 autodoc_type_aliases = { # aliasing doesn't actually fix the warning for types.FrameType, but displaying - # "types.FrameType" is more helpfun than just "frame" + # "types.FrameType" is more helpful than just "frame" "FrameType": "types.FrameType", } diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 49f55dee3e..fa7d651da1 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -10,7 +10,7 @@ import threading import warnings from collections import deque -from collections.abc import Callable, Iterator +from collections.abc import Callable, Coroutine, Iterator from contextlib import AbstractAsyncContextManager, contextmanager from contextvars import copy_context from heapq import heapify, heappop, heappush @@ -1167,10 +1167,7 @@ def __del__(self) -> None: @attr.s(eq=False, hash=False, repr=False, slots=True) class Task(metaclass=NoPublicConstructor): _parent_nursery: Nursery | None = attr.ib() - # could be typed as `Coroutine[Any, outcome.Outcome[object], Any]` but - # the code does a lot of dynamic introspection on it and as long as it passes - # those it's fine. - coro: Any = attr.ib() + coro: Coroutine[Any, Outcome[object], Any] = attr.ib() _runner = attr.ib() name: str = attr.ib() context: contextvars.Context = attr.ib() @@ -1262,7 +1259,8 @@ def print_stack_for_task(task): print("".join(ss.format())) """ - coro = self.coro + # ignore static typing as we're doing lots of dynamic introspection + coro: Any = self.coro while coro is not None: if hasattr(coro, "cr_frame"): # A real coroutine @@ -1297,7 +1295,7 @@ def print_stack_for_task(task): # Don't change this directly; instead, use _activate_cancel_status(). _cancel_status: CancelStatus = attr.ib(default=None, repr=False) - def _activate_cancel_status(self, cancel_status: CancelStatus): + def _activate_cancel_status(self, cancel_status: CancelStatus) -> None: if self._cancel_status is not None: self._cancel_status._tasks.remove(self) self._cancel_status = cancel_status @@ -1312,8 +1310,11 @@ def _attempt_abort(self, raise_cancel: Callable[[], NoReturn]) -> None: # rescheduling itself (hopefully eventually calling reraise to raise # the given exception, but not necessarily). - # TODO: what happens if _abort_func is None? - success = self._abort_func(raise_cancel) # type: ignore + # This is only called by the functions immediately below, which both check + # `self.abort_func is not None`. + assert self._abort_func is not None, "FATAL INTERNAL ERROR" + + success = self._abort_func(raise_cancel) if type(success) is not Abort: raise TrioInternalError("abort function must return Abort enum") # We only attempt to abort once per blocking call, regardless of diff --git a/trio/_highlevel_generic.py b/trio/_highlevel_generic.py index 6f98d6e18a..c126d4c8aa 100644 --- a/trio/_highlevel_generic.py +++ b/trio/_highlevel_generic.py @@ -102,7 +102,8 @@ async def send_eof(self) -> None: else: return await self.send_stream.aclose() - async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray: + # we intentionally accept more types from the caller than we support returning + async def receive_some(self, max_bytes: int | None = None) -> bytes: """Calls ``self.receive_stream.receive_some``.""" return await self.receive_stream.receive_some(max_bytes) diff --git a/trio/_socket.py b/trio/_socket.py index 68dbb87c48..eaf0e04d15 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -487,8 +487,7 @@ def __getattr__(self, name): raise AttributeError(name) def __dir__(self) -> Iterable[str]: - yield from super().__dir__() - yield from self._forward + return [*super().__dir__(), *self._forward] def __enter__(self) -> Self: return self @@ -538,6 +537,7 @@ async def bind(self, address: tuple[object, ...] | str | bytes) -> None: ): # Use a thread for the filesystem traversal (unless it's an # abstract domain socket) + # remove the `type: ignore` when run.sync is typed. return await trio.to_thread.run_sync(self._sock.bind, address) # type: ignore[no-any-return] else: # POSIX actually says that bind can return EWOULDBLOCK and From 89f308f2dade05ec268dd4687f5e12033e0049a4 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Tue, 11 Jul 2023 13:21:21 +0200 Subject: [PATCH 20/21] small fixes after review from A5Rocks --- trio/_abc.py | 2 +- trio/_core/_parking_lot.py | 4 ++-- trio/_core/_run.py | 7 ++----- trio/_dtls.py | 2 ++ trio/_highlevel_generic.py | 4 +++- trio/_sync.py | 30 +++++++++++++++++++++++++++++- trio/_tests/verify_types.json | 11 +++++------ trio/lowlevel.py | 1 - 8 files changed, 44 insertions(+), 17 deletions(-) diff --git a/trio/_abc.py b/trio/_abc.py index c88b5f559e..2a1721db13 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -260,7 +260,7 @@ async def aclose(self) -> None: """ - async def __aenter__(self) -> AsyncResource: + async def __aenter__(self) -> Self: return self async def __aexit__( diff --git a/trio/_core/_parking_lot.py b/trio/_core/_parking_lot.py index da10e7cf91..74708433da 100644 --- a/trio/_core/_parking_lot.py +++ b/trio/_core/_parking_lot.py @@ -179,7 +179,7 @@ def unpark_all(self) -> list[Task]: return self.unpark(count=len(self)) @_core.enable_ki_protection - def repark(self, new_lot: ParkingLot, *, count: int = 1) -> None: + def repark(self, new_lot: ParkingLot, *, count: int | float = 1) -> None: """Move parked tasks from one :class:`ParkingLot` object to another. This dequeues ``count`` tasks from one lot, and requeues them on @@ -209,7 +209,7 @@ async def main(): Args: new_lot (ParkingLot): the parking lot to move tasks to. - count (int): the number of tasks to move. + count (int|math.inf): the number of tasks to move. """ if not isinstance(new_lot, ParkingLot): diff --git a/trio/_core/_run.py b/trio/_core/_run.py index fa7d651da1..585dc4aa41 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -48,9 +48,9 @@ from types import FrameType if TYPE_CHECKING: - # An unfortunate name collision here with trio._util.Final import contextvars + # An unfortunate name collision here with trio._util.Final from typing_extensions import Final as FinalT DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: FinalT = 1000 @@ -1189,9 +1189,6 @@ class Task(metaclass=NoPublicConstructor): _abort_func: Callable[[Callable[[], NoReturn]], Abort] | None = attr.ib( default=None ) - # Typed as `object`, forcing users to do an isinstance check each time. Since - # anything touching the task could have set this, it's not really going to be - # safe to assume that this had the value you saw it with last. custom_sleep_data: Any = attr.ib(default=None) # For introspection and nursery.start() @@ -2468,7 +2465,7 @@ class _TaskStatusIgnored: def __repr__(self) -> str: return "TASK_STATUS_IGNORED" - def started(self, value: Any = None): + def started(self, value: object = None) -> None: pass diff --git a/trio/_dtls.py b/trio/_dtls.py index da99481e06..722a9499f8 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -1132,6 +1132,8 @@ def __init__(self, socket, *, incoming_packets_buffer=10): global SSL from OpenSSL import SSL + # TODO: create a `self._initialized` for `__del__`, so self.socket can be typed + # as trio.socket.SocketType and `is not None` checks can be removed. self.socket = None # for __del__, in case the next line raises if socket.type != trio.socket.SOCK_DGRAM: raise ValueError("DTLS requires a SOCK_DGRAM socket") diff --git a/trio/_highlevel_generic.py b/trio/_highlevel_generic.py index c126d4c8aa..e1ac378c6a 100644 --- a/trio/_highlevel_generic.py +++ b/trio/_highlevel_generic.py @@ -98,7 +98,9 @@ async def send_eof(self) -> None: """ if hasattr(self.send_stream, "send_eof"): - return await self.send_stream.send_eof() # type: ignore + # send_stream.send_eof() is not defined in Trio, this should maybe be + # redesigned so it's possible to type it. + return await self.send_stream.send_eof() # type: ignore[no-any-return] else: return await self.send_stream.aclose() diff --git a/trio/_sync.py b/trio/_sync.py index feb5825e47..1ccff947a0 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -20,6 +20,15 @@ @attr.s(frozen=True, slots=True) class EventStatistics: + """An object containing debugging information. + + Currently the following fields are defined: + + * ``tasks_waiting``: The number of tasks blocked on this event's + :meth:`wait` method. + + """ + tasks_waiting: int = attr.ib() @@ -96,6 +105,8 @@ def statistics(self) -> EventStatistics: return EventStatistics(tasks_waiting=len(self._tasks)) +# TODO: type this with a Protocol to get rid of type: ignore, see +# https://github.com/python-trio/trio/pull/2682#discussion_r1259097422 class AsyncContextManagerMixin: @enable_ki_protection async def __aenter__(self) -> None: @@ -113,6 +124,23 @@ async def __aexit__( @attr.s(frozen=True, slots=True) class CapacityLimiterStatistics: + """An object containing debugging information. + + Currently the following fields are defined: + + * ``borrowed_tokens``: The number of tokens currently borrowed from + the sack. + * ``total_tokens``: The total number of tokens in the sack. Usually + this will be larger than ``borrowed_tokens``, but it's possibly for + it to be smaller if :attr:`total_tokens` was recently decreased. + * ``borrowers``: A list of all tasks or other entities that currently + hold a token. + * ``tasks_waiting``: The number of tasks blocked on this + :class:`CapacityLimiter`\'s :meth:`acquire` or + :meth:`acquire_on_behalf_of` methods. + + """ + borrowed_tokens: int = attr.ib() total_tokens: int | float = attr.ib() borrowers: list[Task] = attr.ib() @@ -671,7 +699,7 @@ class StrictFIFOLock(_LockImpl, metaclass=Final): @attr.s(frozen=True, slots=True) class ConditionStatistics: - r"""Return an object containing debugging information for a Condition. + r"""An object containing debugging information for a Condition. Currently the following fields are defined: diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index cb2253ec27..9d7d7aa912 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -7,11 +7,11 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.8747993579454254, + "completenessScore": 0.8764044943820225, "exportedSymbolCounts": { "withAmbiguousType": 1, - "withKnownType": 545, - "withUnknownType": 77 + "withKnownType": 546, + "withUnknownType": 76 }, "ignoreUnknownTypesFromImports": true, "missingClassDocStringCount": 1, @@ -46,8 +46,8 @@ ], "otherSymbolCounts": { "withAmbiguousType": 8, - "withKnownType": 431, - "withUnknownType": 137 + "withKnownType": 433, + "withUnknownType": 135 }, "packageName": "trio", "symbols": [ @@ -73,7 +73,6 @@ "trio._core._mock_clock.MockClock.jump", "trio._core._run.Nursery.start", "trio._core._run.Nursery.start_soon", - "trio._core._run._TaskStatusIgnored.started", "trio._core._unbounded_queue.UnboundedQueue.__aiter__", "trio._core._unbounded_queue.UnboundedQueue.__anext__", "trio._core._unbounded_queue.UnboundedQueue.__repr__", diff --git a/trio/lowlevel.py b/trio/lowlevel.py index 4b0f114b2d..54f4ef3141 100644 --- a/trio/lowlevel.py +++ b/trio/lowlevel.py @@ -7,7 +7,6 @@ import sys import typing as _t -# Generally available symbols # Generally available symbols from ._core import ( Abort as Abort, From 253684326ca071ea8d8190d2b32d2d5fb006bb2f Mon Sep 17 00:00:00 2001 From: jakkdl Date: Tue, 11 Jul 2023 13:26:25 +0200 Subject: [PATCH 21/21] fix read the docs references --- trio/_sync.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/trio/_sync.py b/trio/_sync.py index 1ccff947a0..5a7f240d5e 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -25,7 +25,7 @@ class EventStatistics: Currently the following fields are defined: * ``tasks_waiting``: The number of tasks blocked on this event's - :meth:`wait` method. + :meth:`trio.Event.wait` method. """ @@ -132,12 +132,12 @@ class CapacityLimiterStatistics: the sack. * ``total_tokens``: The total number of tokens in the sack. Usually this will be larger than ``borrowed_tokens``, but it's possibly for - it to be smaller if :attr:`total_tokens` was recently decreased. + it to be smaller if :attr:`trio.CapacityLimiter.total_tokens` was recently decreased. * ``borrowers``: A list of all tasks or other entities that currently hold a token. * ``tasks_waiting``: The number of tasks blocked on this - :class:`CapacityLimiter`\'s :meth:`acquire` or - :meth:`acquire_on_behalf_of` methods. + :class:`CapacityLimiter`\'s :meth:`trio.CapacityLimiter.acquire` or + :meth:`trio.CapacityLimiter.acquire_on_behalf_of` methods. """