From af601ac29615f1ce61872a3c76c373f2c6e10a0b Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Mon, 3 Jul 2023 16:53:58 +0200
Subject: [PATCH 01/21] lots of typing improvements
---
trio/_abc.py | 28 ++++----
trio/_core/_mock_clock.py | 26 +++----
trio/_core/_parking_lot.py | 36 ++++++----
trio/_core/_run.py | 132 +++++++++++++++++++---------------
trio/_highlevel_generic.py | 23 +++---
trio/_highlevel_socket.py | 21 +++---
trio/_socket.py | 43 ++++++-----
trio/_sync.py | 132 ++++++++++++++++++----------------
trio/_tests/verify_types.json | 126 +++++---------------------------
9 files changed, 266 insertions(+), 301 deletions(-)
diff --git a/trio/_abc.py b/trio/_abc.py
index c085c82b89..403e07a1de 100644
--- a/trio/_abc.py
+++ b/trio/_abc.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from abc import ABCMeta, abstractmethod
from typing import Generic, TypeVar
import trio
@@ -11,7 +13,7 @@ class Clock(metaclass=ABCMeta):
__slots__ = ()
@abstractmethod
- def start_clock(self):
+ def start_clock(self) -> None:
"""Do any setup this clock might need.
Called at the beginning of the run.
@@ -19,7 +21,7 @@ def start_clock(self):
"""
@abstractmethod
- def current_time(self):
+ def current_time(self) -> float:
"""Return the current time, according to this clock.
This is used to implement functions like :func:`trio.current_time` and
@@ -31,7 +33,7 @@ def current_time(self):
"""
@abstractmethod
- def deadline_to_sleep_time(self, deadline):
+ def deadline_to_sleep_time(self, deadline: float) -> float:
"""Compute the real time until the given deadline.
This is called before we enter a system-specific wait function like
@@ -224,7 +226,7 @@ class AsyncResource(metaclass=ABCMeta):
__slots__ = ()
@abstractmethod
- async def aclose(self):
+ async def aclose(self) -> None:
"""Close this resource, possibly blocking.
IMPORTANT: This method may block in order to perform a "graceful"
@@ -252,10 +254,10 @@ async def aclose(self):
"""
- async def __aenter__(self):
+ async def __aenter__(self) -> AsyncResource:
return self
- async def __aexit__(self, *args):
+ async def __aexit__(self, *args: object) -> None:
await self.aclose()
@@ -278,7 +280,7 @@ class SendStream(AsyncResource):
__slots__ = ()
@abstractmethod
- async def send_all(self, data):
+ async def send_all(self, data: bytes | bytearray | memoryview) -> None:
"""Sends the given data through the stream, blocking if necessary.
Args:
@@ -304,7 +306,7 @@ async def send_all(self, data):
"""
@abstractmethod
- async def wait_send_all_might_not_block(self):
+ async def wait_send_all_might_not_block(self) -> None:
"""Block until it's possible that :meth:`send_all` might not block.
This method may return early: it's possible that after it returns,
@@ -384,7 +386,7 @@ class ReceiveStream(AsyncResource):
__slots__ = ()
@abstractmethod
- async def receive_some(self, max_bytes=None):
+ async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray:
"""Wait until there is data available on this stream, and then return
some of it.
@@ -412,10 +414,10 @@ async def receive_some(self, max_bytes=None):
"""
- def __aiter__(self):
+ def __aiter__(self) -> ReceiveStream:
return self
- async def __anext__(self):
+ async def __anext__(self) -> bytes | bytearray:
data = await self.receive_some()
if not data:
raise StopAsyncIteration
@@ -445,7 +447,7 @@ class HalfCloseableStream(Stream):
__slots__ = ()
@abstractmethod
- async def send_eof(self):
+ async def send_eof(self) -> None:
"""Send an end-of-file indication on this stream, if possible.
The difference between :meth:`send_eof` and
@@ -631,7 +633,7 @@ async def receive(self) -> ReceiveType:
"""
- def __aiter__(self):
+ def __aiter__(self) -> ReceiveChannel[ReceiveType]:
return self
async def __anext__(self) -> ReceiveType:
diff --git a/trio/_core/_mock_clock.py b/trio/_core/_mock_clock.py
index 0e95e4e5c5..a9db49584f 100644
--- a/trio/_core/_mock_clock.py
+++ b/trio/_core/_mock_clock.py
@@ -62,7 +62,7 @@ class MockClock(Clock, metaclass=Final):
"""
- def __init__(self, rate=0.0, autojump_threshold=inf):
+ def __init__(self, rate: float = 0.0, autojump_threshold: float = inf):
# when the real clock said 'real_base', the virtual time was
# 'virtual_base', and since then it's advanced at 'rate' virtual
# seconds per real second.
@@ -77,17 +77,17 @@ def __init__(self, rate=0.0, autojump_threshold=inf):
self.rate = rate
self.autojump_threshold = autojump_threshold
- def __repr__(self):
+ def __repr__(self) -> str:
return "".format(
self.current_time(), self._rate, id(self)
)
@property
- def rate(self):
+ def rate(self) -> float:
return self._rate
@rate.setter
- def rate(self, new_rate):
+ def rate(self, new_rate: float) -> None:
if new_rate < 0:
raise ValueError("rate must be >= 0")
else:
@@ -98,11 +98,11 @@ def rate(self, new_rate):
self._rate = float(new_rate)
@property
- def autojump_threshold(self):
+ def autojump_threshold(self) -> float:
return self._autojump_threshold
@autojump_threshold.setter
- def autojump_threshold(self, new_autojump_threshold):
+ def autojump_threshold(self, new_autojump_threshold: float) -> None:
self._autojump_threshold = float(new_autojump_threshold)
self._try_resync_autojump_threshold()
@@ -112,7 +112,7 @@ def autojump_threshold(self, new_autojump_threshold):
# API. Discussion:
#
# https://github.com/python-trio/trio/issues/1587
- def _try_resync_autojump_threshold(self):
+ def _try_resync_autojump_threshold(self) -> None:
try:
runner = GLOBAL_RUN_CONTEXT.runner
if runner.is_guest:
@@ -124,24 +124,24 @@ def _try_resync_autojump_threshold(self):
# Invoked by the run loop when runner.clock_autojump_threshold is
# exceeded.
- def _autojump(self):
+ def _autojump(self) -> None:
statistics = _core.current_statistics()
jump = statistics.seconds_to_next_deadline
if 0 < jump < inf:
self.jump(jump)
- def _real_to_virtual(self, real):
+ def _real_to_virtual(self, real: float) -> float:
real_offset = real - self._real_base
virtual_offset = self._rate * real_offset
return self._virtual_base + virtual_offset
- def start_clock(self):
+ def start_clock(self) -> None:
self._try_resync_autojump_threshold()
- def current_time(self):
+ def current_time(self) -> float:
return self._real_to_virtual(self._real_clock())
- def deadline_to_sleep_time(self, deadline):
+ def deadline_to_sleep_time(self, deadline: float) -> float:
virtual_timeout = deadline - self.current_time()
if virtual_timeout <= 0:
return 0
@@ -150,7 +150,7 @@ def deadline_to_sleep_time(self, deadline):
else:
return 999999999
- def jump(self, seconds):
+ def jump(self, seconds) -> None:
"""Manually advance the clock by the given number of seconds.
Args:
diff --git a/trio/_core/_parking_lot.py b/trio/_core/_parking_lot.py
index f38123540f..42fc60614c 100644
--- a/trio/_core/_parking_lot.py
+++ b/trio/_core/_parking_lot.py
@@ -70,16 +70,24 @@
#
# See: https://github.com/python-trio/trio/issues/53
+from __future__ import annotations
+
import attr
from collections import OrderedDict
+from typing import TYPE_CHECKING, cast
+from collections.abc import Iterator
+import math
from .. import _core
from .._util import Final
+if TYPE_CHECKING:
+ from ._run import Task
+
@attr.s(frozen=True, slots=True)
class _ParkingLotStatistics:
- tasks_waiting = attr.ib()
+ tasks_waiting: int = attr.ib()
@attr.s(eq=False, hash=False, slots=True)
@@ -98,13 +106,13 @@ class ParkingLot(metaclass=Final):
# {task: None}, we just want a deque where we can quickly delete random
# items
- _parked = attr.ib(factory=OrderedDict, init=False)
+ _parked: OrderedDict[Task, None] = attr.ib(factory=OrderedDict, init=False)
- def __len__(self):
+ def __len__(self) -> int:
"""Returns the number of parked tasks."""
return len(self._parked)
- def __bool__(self):
+ def __bool__(self) -> bool:
"""True if there are parked tasks, False otherwise."""
return bool(self._parked)
@@ -113,7 +121,7 @@ def __bool__(self):
# line (for false wakeups), then we could have it return a ticket that
# abstracts the "place in line" concept.
@_core.enable_ki_protection
- async def park(self):
+ async def park(self) -> None:
"""Park the current task until woken by a call to :meth:`unpark` or
:meth:`unpark_all`.
@@ -128,13 +136,15 @@ def abort_fn(_):
await _core.wait_task_rescheduled(abort_fn)
- def _pop_several(self, count):
- for _ in range(min(count, len(self._parked))):
+ def _pop_several(self, count: int | float) -> Iterator[Task]:
+ if isinstance(count, float):
+ assert math.isinf(count)
+ for _ in range(cast(int, min(count, len(self._parked)))):
task, _ = self._parked.popitem(last=False)
yield task
@_core.enable_ki_protection
- def unpark(self, *, count=1):
+ def unpark(self, *, count: int | float) -> list[Task]:
"""Unpark one or more tasks.
This wakes up ``count`` tasks that are blocked in :meth:`park`. If
@@ -142,7 +152,7 @@ def unpark(self, *, count=1):
are available and then returns successfully.
Args:
- count (int): the number of tasks to unpark.
+ count (int | math.inf): the number of tasks to unpark.
"""
tasks = list(self._pop_several(count))
@@ -150,12 +160,12 @@ def unpark(self, *, count=1):
_core.reschedule(task)
return tasks
- def unpark_all(self):
+ def unpark_all(self) -> list[Task]:
"""Unpark all parked tasks."""
return self.unpark(count=len(self))
@_core.enable_ki_protection
- def repark(self, new_lot, *, count=1):
+ def repark(self, new_lot: ParkingLot, *, count: int = 1) -> None:
"""Move parked tasks from one :class:`ParkingLot` object to another.
This dequeues ``count`` tasks from one lot, and requeues them on
@@ -194,7 +204,7 @@ async def main():
new_lot._parked[task] = None
task.custom_sleep_data = new_lot
- def repark_all(self, new_lot):
+ def repark_all(self, new_lot: ParkingLot) -> None:
"""Move all parked tasks from one :class:`ParkingLot` object to
another.
@@ -203,7 +213,7 @@ def repark_all(self, new_lot):
"""
return self.repark(new_lot, count=len(self))
- def statistics(self):
+ def statistics(self) -> _ParkingLotStatistics:
"""Return an object containing debugging information.
Currently the following fields are defined:
diff --git a/trio/_core/_run.py b/trio/_core/_run.py
index 2727fe1e89..0664cdf577 100644
--- a/trio/_core/_run.py
+++ b/trio/_core/_run.py
@@ -50,6 +50,10 @@
if TYPE_CHECKING:
# An unfortunate name collision here with trio._util.Final
from typing_extensions import Final as FinalT
+ from collections.abc import Coroutine
+ from types import FrameType
+ import outcome
+ import contextvars
DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: FinalT = 1000
@@ -274,7 +278,7 @@ class CancelStatus:
# Our associated cancel scope. Can be any object with attributes
# `deadline`, `shield`, and `cancel_called`, but in current usage
# is always a CancelScope object. Must not be None.
- _scope = attr.ib()
+ _scope: CancelScope = attr.ib()
# True iff the tasks in self._tasks should receive cancellations
# when they checkpoint. Always True when scope.cancel_called is True;
@@ -284,31 +288,31 @@ class CancelStatus:
# effectively cancelled due to the cancel scope two levels out
# becoming cancelled, but then the cancel scope one level out
# becomes shielded so we're not effectively cancelled anymore.
- effectively_cancelled = attr.ib(default=False)
+ effectively_cancelled: bool = attr.ib(default=False)
# The CancelStatus whose cancellations can propagate to us; we
# become effectively cancelled when they do, unless scope.shield
# is True. May be None (for the outermost CancelStatus in a call
# to trio.run(), briefly during TaskStatus.started(), or during
# recovery from mis-nesting of cancel scopes).
- _parent = attr.ib(default=None, repr=False)
+ _parent: CancelStatus | None = attr.ib(default=None, repr=False)
# All of the CancelStatuses that have this CancelStatus as their parent.
- _children = attr.ib(factory=set, init=False, repr=False)
+ _children: set[CancelStatus] = attr.ib(factory=set, init=False, repr=False)
# Tasks whose cancellation state is currently tied directly to
# the cancellation state of this CancelStatus object. Don't modify
# this directly; instead, use Task._activate_cancel_status().
# Invariant: all(task._cancel_status is self for task in self._tasks)
- _tasks = attr.ib(factory=set, init=False, repr=False)
+ _tasks: set[Task] = attr.ib(factory=set, init=False, repr=False)
# Set to True on still-active cancel statuses that are children
# of a cancel status that's been closed. This is used to permit
# recovery from mis-nested cancel scopes (well, at least enough
# recovery to show a useful traceback).
- abandoned_by_misnesting = attr.ib(default=False, init=False, repr=False)
+ abandoned_by_misnesting: bool = attr.ib(default=False, init=False, repr=False)
- def __attrs_post_init__(self):
+ def __attrs_post_init__(self) -> None:
if self._parent is not None:
self._parent._children.add(self)
self.recalculate()
@@ -316,11 +320,11 @@ def __attrs_post_init__(self):
# parent/children/tasks accessors are used by TaskStatus.started()
@property
- def parent(self):
+ def parent(self) -> CancelStatus | None:
return self._parent
@parent.setter
- def parent(self, parent):
+ def parent(self, parent: CancelStatus) -> None:
if self._parent is not None:
self._parent._children.remove(self)
self._parent = parent
@@ -329,14 +333,14 @@ def parent(self, parent):
self.recalculate()
@property
- def children(self):
+ def children(self) -> frozenset[CancelStatus]:
return frozenset(self._children)
@property
- def tasks(self):
+ def tasks(self) -> frozenset[Task]:
return frozenset(self._tasks)
- def encloses(self, other):
+ def encloses(self, other: CancelStatus | None) -> bool:
"""Returns true if this cancel status is a direct or indirect
parent of cancel status *other*, or if *other* is *self*.
"""
@@ -346,7 +350,7 @@ def encloses(self, other):
other = other.parent
return False
- def close(self):
+ def close(self) -> None:
self.parent = None # now we're not a child of self.parent anymore
if self._tasks or self._children:
# Cancel scopes weren't exited in opposite order of being
@@ -406,7 +410,7 @@ def _mark_abandoned(self):
for child in self._children:
child._mark_abandoned()
- def effective_deadline(self):
+ def effective_deadline(self) -> float:
if self.effectively_cancelled:
return -inf
if self._parent is None or self._scope.shield:
@@ -854,7 +858,7 @@ class NurseryManager:
"""
- strict_exception_groups = attr.ib(default=False)
+ strict_exception_groups: bool = attr.ib(default=False)
@enable_ki_protection
async def __aenter__(self):
@@ -866,7 +870,12 @@ async def __aenter__(self):
return self._nursery
@enable_ki_protection
- async def __aexit__(self, etype, exc, tb):
+ async def __aexit__(
+ self,
+ etype: type[BaseException] | None,
+ exc: BaseException | None,
+ tb: TracebackType | None,
+ ) -> bool:
new_exc = await self._nursery._nested_child_finished(exc)
# Tracebacks show the 'raise' line below out of context, so let's give
# this variable a name that makes sense out of context.
@@ -889,18 +898,18 @@ async def __aexit__(self, etype, exc, tb):
# see test_simple_cancel_scope_usage_doesnt_create_cyclic_garbage
del _, combined_error_from_nursery, value, new_exc
- def __enter__(self):
+ def __enter__(self) -> None:
raise RuntimeError(
"use 'async with open_nursery(...)', not 'with open_nursery(...)'"
)
- def __exit__(self): # pragma: no cover
+ def __exit__(self) -> None: # pragma: no cover
assert False, """Never called, but should be defined"""
def open_nursery(
strict_exception_groups: bool | None = None,
-) -> AbstractAsyncContextManager[Nursery]:
+) -> NurseryManager:
"""Returns an async context manager which must be used to create a
new `Nursery`.
@@ -941,7 +950,12 @@ class Nursery(metaclass=NoPublicConstructor):
in response to some external event.
"""
- def __init__(self, parent_task, cancel_scope, strict_exception_groups):
+ def __init__(
+ self,
+ parent_task: Task,
+ cancel_scope: CancelScope,
+ strict_exception_groups: bool,
+ ):
self._parent_task = parent_task
self._strict_exception_groups = strict_exception_groups
parent_task._child_nurseries.append(self)
@@ -952,8 +966,8 @@ def __init__(self, parent_task, cancel_scope, strict_exception_groups):
# children.
self.cancel_scope = cancel_scope
assert self.cancel_scope._cancel_status is self._cancel_status
- self._children = set()
- self._pending_excs = []
+ self._children: set[Task] = set()
+ self._pending_excs: list[BaseException] = []
# The "nested child" is how this code refers to the contents of the
# nursery's 'async with' block, which acts like a child Task in all
# the ways we can make it.
@@ -963,17 +977,17 @@ def __init__(self, parent_task, cancel_scope, strict_exception_groups):
self._closed = False
@property
- def child_tasks(self):
+ def child_tasks(self) -> frozenset[Task]:
"""(`frozenset`): Contains all the child :class:`~trio.lowlevel.Task`
objects which are still running."""
return frozenset(self._children)
@property
- def parent_task(self):
+ def parent_task(self) -> Task:
"(`~trio.lowlevel.Task`): The Task that opened this nursery."
return self._parent_task
- def _add_exc(self, exc):
+ def _add_exc(self, exc: BaseException) -> None:
self._pending_excs.append(exc)
self.cancel_scope.cancel()
@@ -1135,7 +1149,7 @@ async def async_fn(arg1, arg2, *, task_status=trio.TASK_STATUS_IGNORED):
self._pending_starts -= 1
self._check_nursery_closed()
- def __del__(self):
+ def __del__(self) -> None:
assert not self._children
@@ -1146,12 +1160,11 @@ def __del__(self):
@attr.s(eq=False, hash=False, repr=False, slots=True)
class Task(metaclass=NoPublicConstructor):
- _parent_nursery = attr.ib()
- coro = attr.ib()
+ _parent_nursery: Nursery | None = attr.ib()
+ coro: Coroutine[Any, outcome.Outcome[object], Any] = attr.ib()
_runner = attr.ib()
- name = attr.ib()
- # PEP 567 contextvars context
- context = attr.ib()
+ name: str = attr.ib()
+ context: contextvars.Context = attr.ib()
_counter: int = attr.ib(init=False, factory=itertools.count().__next__)
# Invariant:
@@ -1167,24 +1180,27 @@ class Task(metaclass=NoPublicConstructor):
# Tasks start out unscheduled.
_next_send_fn = attr.ib(default=None)
_next_send = attr.ib(default=None)
- _abort_func = attr.ib(default=None)
- custom_sleep_data = attr.ib(default=None)
+ _abort_func: Callable[[Callable[[], NoReturn]], Abort] | None = attr.ib(
+ default=None
+ )
+ # could possible be set with a TypeVar
+ custom_sleep_data: Any = attr.ib(default=None)
# For introspection and nursery.start()
- _child_nurseries = attr.ib(factory=list)
- _eventual_parent_nursery = attr.ib(default=None)
+ _child_nurseries: list[Nursery] = attr.ib(factory=list)
+ _eventual_parent_nursery: Nursery | None = attr.ib(default=None)
# these are counts of how many cancel/schedule points this task has
# executed, for assert{_no,}_checkpoints
# XX maybe these should be exposed as part of a statistics() method?
- _cancel_points = attr.ib(default=0)
- _schedule_points = attr.ib(default=0)
+ _cancel_points: int = attr.ib(default=0)
+ _schedule_points: int = attr.ib(default=0)
- def __repr__(self):
+ def __repr__(self) -> str:
return f""
@property
- def parent_nursery(self):
+ def parent_nursery(self) -> Nursery | None:
"""The nursery this task is inside (or None if this is the "init"
task).
@@ -1195,7 +1211,7 @@ def parent_nursery(self):
return self._parent_nursery
@property
- def eventual_parent_nursery(self):
+ def eventual_parent_nursery(self) -> Nursery | None:
"""The nursery this task will be inside after it calls
``task_status.started()``.
@@ -1207,7 +1223,7 @@ def eventual_parent_nursery(self):
return self._eventual_parent_nursery
@property
- def child_nurseries(self):
+ def child_nurseries(self) -> list[Nursery]:
"""The nurseries this task contains.
This is a list, with outer nurseries before inner nurseries.
@@ -1215,7 +1231,7 @@ def child_nurseries(self):
"""
return list(self._child_nurseries)
- def iter_await_frames(self):
+ def iter_await_frames(self) -> Iterator[tuple[FrameType, int]]:
"""Iterates recursively over the coroutine-like objects this
task is waiting on, yielding the frame and line number at each
frame.
@@ -1240,11 +1256,11 @@ def print_stack_for_task(task):
if hasattr(coro, "cr_frame"):
# A real coroutine
yield coro.cr_frame, coro.cr_frame.f_lineno
- coro = coro.cr_await
+ coro = coro.cr_await # type: ignore # TODO
elif hasattr(coro, "gi_frame"):
# A generator decorated with @types.coroutine
yield coro.gi_frame, coro.gi_frame.f_lineno
- coro = coro.gi_yieldfrom
+ coro = coro.gi_yieldfrom # type: ignore # TODO
elif coro.__class__.__name__ in [
"async_generator_athrow",
"async_generator_asend",
@@ -1268,9 +1284,9 @@ def print_stack_for_task(task):
# The CancelStatus object that is currently active for this task.
# Don't change this directly; instead, use _activate_cancel_status().
- _cancel_status = attr.ib(default=None, repr=False)
+ _cancel_status: CancelStatus = attr.ib(default=None, repr=False)
- def _activate_cancel_status(self, cancel_status):
+ def _activate_cancel_status(self, cancel_status: CancelStatus):
if self._cancel_status is not None:
self._cancel_status._tasks.remove(self)
self._cancel_status = cancel_status
@@ -1279,12 +1295,14 @@ def _activate_cancel_status(self, cancel_status):
if self._cancel_status.effectively_cancelled:
self._attempt_delivery_of_any_pending_cancel()
- def _attempt_abort(self, raise_cancel):
+ def _attempt_abort(self, raise_cancel: Callable[[], NoReturn]) -> None:
# Either the abort succeeds, in which case we will reschedule the
# task, or else it fails, in which case it will worry about
# rescheduling itself (hopefully eventually calling reraise to raise
# the given exception, but not necessarily).
- success = self._abort_func(raise_cancel)
+
+ # TODO: what happens if _abort_func is None?
+ success = self._abort_func(raise_cancel) # type: ignore
if type(success) is not Abort:
raise TrioInternalError("abort function must return Abort enum")
# We only attempt to abort once per blocking call, regardless of
@@ -1293,7 +1311,7 @@ def _attempt_abort(self, raise_cancel):
if success is Abort.SUCCEEDED:
self._runner.reschedule(self, capture(raise_cancel))
- def _attempt_delivery_of_any_pending_cancel(self):
+ def _attempt_delivery_of_any_pending_cancel(self) -> None:
if self._abort_func is None:
return
if not self._cancel_status.effectively_cancelled:
@@ -1304,12 +1322,12 @@ def raise_cancel():
self._attempt_abort(raise_cancel)
- def _attempt_delivery_of_pending_ki(self):
+ def _attempt_delivery_of_pending_ki(self) -> None:
assert self._runner.ki_pending
if self._abort_func is None:
return
- def raise_cancel():
+ def raise_cancel() -> NoReturn:
self._runner.ki_pending = False
raise KeyboardInterrupt
@@ -2435,17 +2453,17 @@ def unrolled_run(
class _TaskStatusIgnored:
- def __repr__(self):
+ def __repr__(self) -> str:
return "TASK_STATUS_IGNORED"
- def started(self, value=None):
+ def started(self, value: None = None):
pass
TASK_STATUS_IGNORED: FinalT = _TaskStatusIgnored()
-def current_task():
+def current_task() -> Task:
"""Return the :class:`Task` object representing the current task.
Returns:
@@ -2459,7 +2477,7 @@ def current_task():
raise RuntimeError("must be called from async context") from None
-def current_effective_deadline():
+def current_effective_deadline() -> float:
"""Returns the current effective deadline for the current task.
This function examines all the cancellation scopes that are currently in
@@ -2486,7 +2504,7 @@ def current_effective_deadline():
return current_task()._cancel_status.effective_deadline()
-async def checkpoint():
+async def checkpoint() -> None:
"""A pure :ref:`checkpoint `.
This checks for cancellation and allows other tasks to be scheduled,
@@ -2513,7 +2531,7 @@ async def checkpoint():
await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED)
-async def checkpoint_if_cancelled():
+async def checkpoint_if_cancelled() -> None:
"""Issue a :ref:`checkpoint ` if the calling context has been
cancelled.
diff --git a/trio/_highlevel_generic.py b/trio/_highlevel_generic.py
index c31b4fdbf3..191c2a76d4 100644
--- a/trio/_highlevel_generic.py
+++ b/trio/_highlevel_generic.py
@@ -1,12 +1,17 @@
+from __future__ import annotations
import attr
import trio
from .abc import HalfCloseableStream
from trio._util import Final
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+ from .abc import SendStream, ReceiveStream, AsyncResource
-async def aclose_forcefully(resource):
+
+async def aclose_forcefully(resource: AsyncResource) -> None:
"""Close an async resource or async generator immediately, without
blocking to do any graceful cleanup.
@@ -72,18 +77,18 @@ class StapledStream(HalfCloseableStream, metaclass=Final):
"""
- send_stream = attr.ib()
- receive_stream = attr.ib()
+ send_stream: SendStream = attr.ib()
+ receive_stream: ReceiveStream = attr.ib()
- async def send_all(self, data):
+ async def send_all(self, data: bytes | bytearray | memoryview) -> None:
"""Calls ``self.send_stream.send_all``."""
return await self.send_stream.send_all(data)
- async def wait_send_all_might_not_block(self):
+ async def wait_send_all_might_not_block(self) -> None:
"""Calls ``self.send_stream.wait_send_all_might_not_block``."""
return await self.send_stream.wait_send_all_might_not_block()
- async def send_eof(self):
+ async def send_eof(self) -> None:
"""Shuts down the send side of the stream.
If ``self.send_stream.send_eof`` exists, then calls it. Otherwise,
@@ -91,15 +96,15 @@ async def send_eof(self):
"""
if hasattr(self.send_stream, "send_eof"):
- return await self.send_stream.send_eof()
+ return await self.send_stream.send_eof() # type: ignore
else:
return await self.send_stream.aclose()
- async def receive_some(self, max_bytes=None):
+ async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray:
"""Calls ``self.receive_stream.receive_some``."""
return await self.receive_stream.receive_some(max_bytes)
- async def aclose(self):
+ async def aclose(self) -> None:
"""Calls ``aclose`` on both underlying streams."""
try:
await self.send_stream.aclose()
diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py
index 1e8dc16ebc..5bd6b68f1c 100644
--- a/trio/_highlevel_socket.py
+++ b/trio/_highlevel_socket.py
@@ -1,13 +1,18 @@
# "High-level" networking interface
+from __future__ import annotations
import errno
from contextlib import contextmanager
+from typing import TYPE_CHECKING
import trio
from . import socket as tsocket
from ._util import ConflictDetector, Final
from .abc import HalfCloseableStream, Listener
+if TYPE_CHECKING:
+ from ._socket import _SocketType as SocketType
+
# XX TODO: this number was picked arbitrarily. We should do experiments to
# tune it. (Or make it dynamic -- one idea is to start small and increase it
# if we observe single reads filling up the whole buffer, at least within some
@@ -57,7 +62,7 @@ class SocketStream(HalfCloseableStream, metaclass=Final):
"""
- def __init__(self, socket):
+ def __init__(self, socket: SocketType):
if not isinstance(socket, tsocket.SocketType):
raise TypeError("SocketStream requires a Trio socket object")
if socket.type != tsocket.SOCK_STREAM:
@@ -108,14 +113,14 @@ async def send_all(self, data):
sent = await self.socket.send(remaining)
total_sent += sent
- async def wait_send_all_might_not_block(self):
+ async def wait_send_all_might_not_block(self) -> None:
with self._send_conflict_detector:
if self.socket.fileno() == -1:
raise trio.ClosedResourceError
with _translate_socket_errors_to_stream_errors():
await self.socket.wait_writable()
- async def send_eof(self):
+ async def send_eof(self) -> None:
with self._send_conflict_detector:
await trio.lowlevel.checkpoint()
# On macOS, calling shutdown a second time raises ENOTCONN, but
@@ -125,7 +130,7 @@ async def send_eof(self):
with _translate_socket_errors_to_stream_errors():
self.socket.shutdown(tsocket.SHUT_WR)
- async def receive_some(self, max_bytes=None):
+ async def receive_some(self, max_bytes: int | None = None):
if max_bytes is None:
max_bytes = DEFAULT_RECEIVE_SIZE
if max_bytes < 1:
@@ -133,7 +138,7 @@ async def receive_some(self, max_bytes=None):
with _translate_socket_errors_to_stream_errors():
return await self.socket.recv(max_bytes)
- async def aclose(self):
+ async def aclose(self) -> None:
self.socket.close()
await trio.lowlevel.checkpoint()
@@ -330,7 +335,7 @@ class SocketListener(Listener[SocketStream], metaclass=Final):
"""
- def __init__(self, socket):
+ def __init__(self, socket: SocketType):
if not isinstance(socket, tsocket.SocketType):
raise TypeError("SocketListener requires a Trio socket object")
if socket.type != tsocket.SOCK_STREAM:
@@ -346,7 +351,7 @@ def __init__(self, socket):
self.socket = socket
- async def accept(self):
+ async def accept(self) -> SocketStream:
"""Accept an incoming connection.
Returns:
@@ -374,7 +379,7 @@ async def accept(self):
else:
return SocketStream(sock)
- async def aclose(self):
+ async def aclose(self) -> None:
"""Close this listener and its underlying socket."""
self.socket.close()
await trio.lowlevel.checkpoint()
diff --git a/trio/_socket.py b/trio/_socket.py
index 2889f48113..38a6fa0019 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -1,15 +1,21 @@
+from __future__ import annotations
+
import os
import sys
import select
import socket as _stdlib_socket
from functools import wraps as _wraps
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
import idna as _idna
import trio
from . import _core
+if TYPE_CHECKING:
+ from collections.abc import Iterable
+ from typing_extensions import Self
+
# Usage:
#
@@ -429,7 +435,7 @@ def __init__(self):
class _SocketType(SocketType):
- def __init__(self, sock):
+ def __init__(self, sock: _stdlib_socket.socket):
if type(sock) is not _stdlib_socket.socket:
# For example, ssl.SSLSocket subclasses socket.socket, but we
# certainly don't want to blindly wrap one of those.
@@ -472,44 +478,45 @@ def __getattr__(self, name):
return getattr(self._sock, name)
raise AttributeError(name)
- def __dir__(self):
- return super().__dir__() + list(self._forward)
+ def __dir__(self) -> Iterable[str]:
+ yield from super().__dir__()
+ yield from self._forward
- def __enter__(self):
+ def __enter__(self) -> Self:
return self
- def __exit__(self, *exc_info):
+ def __exit__(self, *exc_info: object) -> None:
return self._sock.__exit__(*exc_info)
@property
- def family(self):
+ def family(self) -> _stdlib_socket.AddressFamily:
return self._sock.family
@property
- def type(self):
+ def type(self) -> _stdlib_socket.SocketKind:
return self._sock.type
@property
- def proto(self):
+ def proto(self) -> int:
return self._sock.proto
@property
- def did_shutdown_SHUT_WR(self):
+ def did_shutdown_SHUT_WR(self) -> bool:
return self._did_shutdown_SHUT_WR
- def __repr__(self):
+ def __repr__(self) -> str:
return repr(self._sock).replace("socket.socket", "trio.socket.socket")
- def dup(self):
+ def dup(self) -> _SocketType:
"""Same as :meth:`socket.socket.dup`."""
return _SocketType(self._sock.dup())
- def close(self):
+ def close(self) -> None:
if self._sock.fileno() != -1:
trio.lowlevel.notify_closing(self._sock)
self._sock.close()
- async def bind(self, address):
+ async def bind(self, address: tuple[Any, ...] | str | bytes) -> None:
address = await self._resolve_address_nocp(address, local=True)
if (
hasattr(_stdlib_socket, "AF_UNIX")
@@ -518,7 +525,7 @@ async def bind(self, address):
):
# Use a thread for the filesystem traversal (unless it's an
# abstract domain socket)
- return await trio.to_thread.run_sync(self._sock.bind, address)
+ return await trio.to_thread.run_sync(self._sock.bind, address) # type: ignore[no-any-return]
else:
# POSIX actually says that bind can return EWOULDBLOCK and
# complete asynchronously, like connect. But in practice AFAICT
@@ -527,14 +534,14 @@ async def bind(self, address):
await trio.lowlevel.checkpoint()
return self._sock.bind(address)
- def shutdown(self, flag):
+ def shutdown(self, flag: int) -> None:
# no need to worry about return value b/c always returns None:
self._sock.shutdown(flag)
# only do this if the call succeeded:
if flag in [_stdlib_socket.SHUT_WR, _stdlib_socket.SHUT_RDWR]:
self._did_shutdown_SHUT_WR = True
- def is_readable(self):
+ def is_readable(self) -> bool:
# use select.select on Windows, and select.poll everywhere else
if sys.platform == "win32":
rready, _, _ = select.select([self._sock], [], [], 0)
@@ -543,7 +550,7 @@ def is_readable(self):
p.register(self._sock, select.POLLIN)
return bool(p.poll(0))
- async def wait_writable(self):
+ async def wait_writable(self) -> None:
await _core.wait_writable(self._sock)
async def _resolve_address_nocp(self, address, *, local):
diff --git a/trio/_sync.py b/trio/_sync.py
index 8d2fdc0a2d..08eae512d2 100644
--- a/trio/_sync.py
+++ b/trio/_sync.py
@@ -1,3 +1,6 @@
+from __future__ import annotations
+from typing import TYPE_CHECKING
+
import math
import attr
@@ -8,10 +11,14 @@
from ._core import enable_ki_protection, ParkingLot
from ._util import Final
+if TYPE_CHECKING:
+ from ._core import Task
+ from ._core._parking_lot import _ParkingLotStatistics
+
@attr.s(frozen=True)
class _EventStatistics:
- tasks_waiting = attr.ib()
+ tasks_waiting: int = attr.ib()
@attr.s(repr=False, eq=False, hash=False, slots=True)
@@ -41,15 +48,15 @@ class Event(metaclass=Final):
"""
- _tasks = attr.ib(factory=set, init=False)
- _flag = attr.ib(default=False, init=False)
+ _tasks: set[Task] = attr.ib(factory=set, init=False)
+ _flag: bool = attr.ib(default=False, init=False)
- def is_set(self):
+ def is_set(self) -> bool:
"""Return the current value of the internal flag."""
return self._flag
@enable_ki_protection
- def set(self):
+ def set(self) -> None:
"""Set the internal flag value to True, and wake any waiting tasks."""
if not self._flag:
self._flag = True
@@ -57,7 +64,7 @@ def set(self):
_core.reschedule(task)
self._tasks.clear()
- async def wait(self):
+ async def wait(self) -> None:
"""Block until the internal flag value becomes True.
If it's already True, then this method returns immediately.
@@ -75,7 +82,7 @@ def abort_fn(_):
await _core.wait_task_rescheduled(abort_fn)
- def statistics(self):
+ def statistics(self) -> _EventStatistics:
"""Return an object containing debugging information.
Currently the following fields are defined:
@@ -89,20 +96,20 @@ def statistics(self):
class AsyncContextManagerMixin:
@enable_ki_protection
- async def __aenter__(self):
- await self.acquire()
+ async def __aenter__(self) -> None:
+ await self.acquire() # type: ignore[attr-defined]
@enable_ki_protection
- async def __aexit__(self, *args):
- self.release()
+ async def __aexit__(self, *args: object) -> None:
+ self.release() # type: ignore[attr-defined]
@attr.s(frozen=True)
class _CapacityLimiterStatistics:
- borrowed_tokens = attr.ib()
- total_tokens = attr.ib()
- borrowers = attr.ib()
- tasks_waiting = attr.ib()
+ borrowed_tokens: int = attr.ib()
+ total_tokens: int | float = attr.ib()
+ borrowers: list[Task] = attr.ib()
+ tasks_waiting: int = attr.ib()
class CapacityLimiter(AsyncContextManagerMixin, metaclass=Final):
@@ -159,22 +166,23 @@ class CapacityLimiter(AsyncContextManagerMixin, metaclass=Final):
"""
- def __init__(self, total_tokens):
+ # total_tokens would ideally be int|Literal[math.inf] - but that's not valid typing
+ def __init__(self, total_tokens: int | float):
self._lot = ParkingLot()
- self._borrowers = set()
+ self._borrowers: set[Task] = set()
# Maps tasks attempting to acquire -> borrower, to handle on-behalf-of
- self._pending_borrowers = {}
+ self._pending_borrowers: dict[Task, Task] = {}
# invoke the property setter for validation
- self.total_tokens = total_tokens
+ self.total_tokens: int | float = total_tokens
assert self._total_tokens == total_tokens
- def __repr__(self):
+ def __repr__(self) -> str:
return "".format(
id(self), len(self._borrowers), self._total_tokens, len(self._lot)
)
@property
- def total_tokens(self):
+ def total_tokens(self) -> int | float:
"""The total capacity available.
You can change :attr:`total_tokens` by assigning to this attribute. If
@@ -189,7 +197,7 @@ def total_tokens(self):
return self._total_tokens
@total_tokens.setter
- def total_tokens(self, new_total_tokens):
+ def total_tokens(self, new_total_tokens: int | float) -> None:
if not isinstance(new_total_tokens, int) and new_total_tokens != math.inf:
raise TypeError("total_tokens must be an int or math.inf")
if new_total_tokens < 1:
@@ -197,23 +205,23 @@ def total_tokens(self, new_total_tokens):
self._total_tokens = new_total_tokens
self._wake_waiters()
- def _wake_waiters(self):
+ def _wake_waiters(self) -> None:
available = self._total_tokens - len(self._borrowers)
for woken in self._lot.unpark(count=available):
self._borrowers.add(self._pending_borrowers.pop(woken))
@property
- def borrowed_tokens(self):
+ def borrowed_tokens(self) -> int:
"""The amount of capacity that's currently in use."""
return len(self._borrowers)
@property
- def available_tokens(self):
+ def available_tokens(self) -> int | float:
"""The amount of capacity that's available to use."""
return self.total_tokens - self.borrowed_tokens
@enable_ki_protection
- def acquire_nowait(self):
+ def acquire_nowait(self) -> None:
"""Borrow a token from the sack, without blocking.
Raises:
@@ -225,7 +233,7 @@ def acquire_nowait(self):
self.acquire_on_behalf_of_nowait(trio.lowlevel.current_task())
@enable_ki_protection
- def acquire_on_behalf_of_nowait(self, borrower):
+ def acquire_on_behalf_of_nowait(self, borrower: Task) -> None:
"""Borrow a token from the sack on behalf of ``borrower``, without
blocking.
@@ -253,7 +261,7 @@ def acquire_on_behalf_of_nowait(self, borrower):
raise trio.WouldBlock
@enable_ki_protection
- async def acquire(self):
+ async def acquire(self) -> None:
"""Borrow a token from the sack, blocking if necessary.
Raises:
@@ -264,7 +272,7 @@ async def acquire(self):
await self.acquire_on_behalf_of(trio.lowlevel.current_task())
@enable_ki_protection
- async def acquire_on_behalf_of(self, borrower):
+ async def acquire_on_behalf_of(self, borrower: Task) -> None:
"""Borrow a token from the sack on behalf of ``borrower``, blocking if
necessary.
@@ -293,7 +301,7 @@ async def acquire_on_behalf_of(self, borrower):
await trio.lowlevel.cancel_shielded_checkpoint()
@enable_ki_protection
- def release(self):
+ def release(self) -> None:
"""Put a token back into the sack.
Raises:
@@ -304,7 +312,7 @@ def release(self):
self.release_on_behalf_of(trio.lowlevel.current_task())
@enable_ki_protection
- def release_on_behalf_of(self, borrower):
+ def release_on_behalf_of(self, borrower: Task) -> None:
"""Put a token back into the sack on behalf of ``borrower``.
Raises:
@@ -319,7 +327,7 @@ def release_on_behalf_of(self, borrower):
self._borrowers.remove(borrower)
self._wake_waiters()
- def statistics(self):
+ def statistics(self) -> _CapacityLimiterStatistics:
"""Return an object containing debugging information.
Currently the following fields are defined:
@@ -373,7 +381,7 @@ class Semaphore(AsyncContextManagerMixin, metaclass=Final):
"""
- def __init__(self, initial_value, *, max_value=None):
+ def __init__(self, initial_value: int, *, max_value: int | None = None):
if not isinstance(initial_value, int):
raise TypeError("initial_value must be an int")
if initial_value < 0:
@@ -391,7 +399,7 @@ def __init__(self, initial_value, *, max_value=None):
self._value = initial_value
self._max_value = max_value
- def __repr__(self):
+ def __repr__(self) -> str:
if self._max_value is None:
max_value_str = ""
else:
@@ -401,17 +409,17 @@ def __repr__(self):
)
@property
- def value(self):
+ def value(self) -> int:
"""The current value of the semaphore."""
return self._value
@property
- def max_value(self):
+ def max_value(self) -> int | None:
"""The maximum allowed value. May be None to indicate no limit."""
return self._max_value
@enable_ki_protection
- def acquire_nowait(self):
+ def acquire_nowait(self) -> None:
"""Attempt to decrement the semaphore value, without blocking.
Raises:
@@ -425,7 +433,7 @@ def acquire_nowait(self):
raise trio.WouldBlock
@enable_ki_protection
- async def acquire(self):
+ async def acquire(self) -> None:
"""Decrement the semaphore value, blocking if necessary to avoid
letting it drop below zero.
@@ -439,7 +447,7 @@ async def acquire(self):
await trio.lowlevel.cancel_shielded_checkpoint()
@enable_ki_protection
- def release(self):
+ def release(self) -> None:
"""Increment the semaphore value, possibly waking a task blocked in
:meth:`acquire`.
@@ -456,7 +464,7 @@ def release(self):
raise ValueError("semaphore released too many times")
self._value += 1
- def statistics(self):
+ def statistics(self) -> _ParkingLotStatistics:
"""Return an object containing debugging information.
Currently the following fields are defined:
@@ -470,17 +478,17 @@ def statistics(self):
@attr.s(frozen=True)
class _LockStatistics:
- locked = attr.ib()
- owner = attr.ib()
- tasks_waiting = attr.ib()
+ locked: bool = attr.ib()
+ owner: Task | None = attr.ib()
+ tasks_waiting: int = attr.ib()
@attr.s(eq=False, hash=False, repr=False)
class _LockImpl(AsyncContextManagerMixin):
- _lot = attr.ib(factory=ParkingLot, init=False)
- _owner = attr.ib(default=None, init=False)
+ _lot: ParkingLot = attr.ib(factory=ParkingLot, init=False)
+ _owner: Task | None = attr.ib(default=None, init=False)
- def __repr__(self):
+ def __repr__(self) -> str:
if self.locked():
s1 = "locked"
s2 = f" with {len(self._lot)} waiters"
@@ -491,7 +499,7 @@ def __repr__(self):
s1, self.__class__.__name__, id(self), s2
)
- def locked(self):
+ def locked(self) -> bool:
"""Check whether the lock is currently held.
Returns:
@@ -501,7 +509,7 @@ def locked(self):
return self._owner is not None
@enable_ki_protection
- def acquire_nowait(self):
+ def acquire_nowait(self) -> None:
"""Attempt to acquire the lock, without blocking.
Raises:
@@ -519,7 +527,7 @@ def acquire_nowait(self):
raise trio.WouldBlock
@enable_ki_protection
- async def acquire(self):
+ async def acquire(self) -> None:
"""Acquire the lock, blocking if necessary."""
await trio.lowlevel.checkpoint_if_cancelled()
try:
@@ -533,7 +541,7 @@ async def acquire(self):
await trio.lowlevel.cancel_shielded_checkpoint()
@enable_ki_protection
- def release(self):
+ def release(self) -> None:
"""Release the lock.
Raises:
@@ -548,7 +556,7 @@ def release(self):
else:
self._owner = None
- def statistics(self):
+ def statistics(self) -> _LockStatistics:
"""Return an object containing debugging information.
Currently the following fields are defined:
@@ -644,8 +652,8 @@ class StrictFIFOLock(_LockImpl, metaclass=Final):
@attr.s(frozen=True)
class _ConditionStatistics:
- tasks_waiting = attr.ib()
- lock_statistics = attr.ib()
+ tasks_waiting: int = attr.ib()
+ lock_statistics: _LockStatistics = attr.ib()
class Condition(AsyncContextManagerMixin, metaclass=Final):
@@ -663,7 +671,7 @@ class Condition(AsyncContextManagerMixin, metaclass=Final):
"""
- def __init__(self, lock=None):
+ def __init__(self, lock: Lock | None = None):
if lock is None:
lock = Lock()
if not type(lock) is Lock:
@@ -671,7 +679,7 @@ def __init__(self, lock=None):
self._lock = lock
self._lot = trio.lowlevel.ParkingLot()
- def locked(self):
+ def locked(self) -> bool:
"""Check whether the underlying lock is currently held.
Returns:
@@ -680,7 +688,7 @@ def locked(self):
"""
return self._lock.locked()
- def acquire_nowait(self):
+ def acquire_nowait(self) -> None:
"""Attempt to acquire the underlying lock, without blocking.
Raises:
@@ -689,16 +697,16 @@ def acquire_nowait(self):
"""
return self._lock.acquire_nowait()
- async def acquire(self):
+ async def acquire(self) -> None:
"""Acquire the underlying lock, blocking if necessary."""
await self._lock.acquire()
- def release(self):
+ def release(self) -> None:
"""Release the underlying lock."""
self._lock.release()
@enable_ki_protection
- async def wait(self):
+ async def wait(self) -> None:
"""Wait for another task to call :meth:`notify` or
:meth:`notify_all`.
@@ -733,7 +741,7 @@ async def wait(self):
await self.acquire()
raise
- def notify(self, n=1):
+ def notify(self, n: int = 1) -> None:
"""Wake one or more tasks that are blocked in :meth:`wait`.
Args:
@@ -747,7 +755,7 @@ def notify(self, n=1):
raise RuntimeError("must hold the lock to notify")
self._lot.repark(self._lock._lot, count=n)
- def notify_all(self):
+ def notify_all(self) -> None:
"""Wake all tasks that are currently blocked in :meth:`wait`.
Raises:
@@ -758,7 +766,7 @@ def notify_all(self):
raise RuntimeError("must hold the lock to notify")
self._lot.repark_all(self._lock._lot)
- def statistics(self):
+ def statistics(self) -> _ConditionStatistics:
r"""Return an object containing debugging information.
Currently the following fields are defined:
diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json
index e54af12444..53dc90d5d9 100644
--- a/trio/_tests/verify_types.json
+++ b/trio/_tests/verify_types.json
@@ -7,11 +7,11 @@
"warningCount": 0
},
"typeCompleteness": {
- "completenessScore": 0.8317152103559871,
+ "completenessScore": 0.8737864077669902,
"exportedSymbolCounts": {
"withAmbiguousType": 1,
- "withKnownType": 514,
- "withUnknownType": 103
+ "withKnownType": 540,
+ "withUnknownType": 77
},
"ignoreUnknownTypesFromImports": true,
"missingClassDocStringCount": 1,
@@ -45,91 +45,36 @@
}
],
"otherSymbolCounts": {
- "withAmbiguousType": 14,
- "withKnownType": 244,
- "withUnknownType": 224
+ "withAmbiguousType": 8,
+ "withKnownType": 441,
+ "withUnknownType": 143
},
"packageName": "trio",
"symbols": [
"trio.run",
- "trio.current_effective_deadline",
- "trio._core._run._TaskStatusIgnored.__repr__",
"trio._core._run._TaskStatusIgnored.started",
"trio.current_time",
- "trio._core._run.Nursery.__init__",
- "trio._core._run.Nursery.child_tasks",
- "trio._core._run.Nursery.parent_task",
"trio._core._run.Nursery.start_soon",
"trio._core._run.Nursery.start",
- "trio._core._run.Nursery.__del__",
- "trio._sync.Event.is_set",
- "trio._sync.Event.wait",
- "trio._sync.Event.statistics",
- "trio._sync.CapacityLimiter.__init__",
- "trio._sync.CapacityLimiter.__repr__",
- "trio._sync.CapacityLimiter.total_tokens",
- "trio._sync.CapacityLimiter.borrowed_tokens",
- "trio._sync.CapacityLimiter.available_tokens",
- "trio._sync.CapacityLimiter.statistics",
- "trio._sync.Semaphore.__init__",
- "trio._sync.Semaphore.__repr__",
- "trio._sync.Semaphore.value",
- "trio._sync.Semaphore.max_value",
- "trio._sync.Semaphore.statistics",
- "trio._sync.Lock",
- "trio._sync._LockImpl.__repr__",
- "trio._sync._LockImpl.locked",
- "trio._sync._LockImpl.statistics",
- "trio._sync.StrictFIFOLock",
- "trio._sync.Condition.__init__",
- "trio._sync.Condition.locked",
- "trio._sync.Condition.acquire_nowait",
- "trio._sync.Condition.acquire",
- "trio._sync.Condition.release",
- "trio._sync.Condition.notify",
- "trio._sync.Condition.notify_all",
- "trio._sync.Condition.statistics",
- "trio.aclose_forcefully",
- "trio._highlevel_generic.StapledStream",
- "trio._highlevel_generic.StapledStream.send_stream",
- "trio._highlevel_generic.StapledStream.receive_stream",
- "trio._highlevel_generic.StapledStream.send_all",
- "trio._highlevel_generic.StapledStream.wait_send_all_might_not_block",
- "trio._highlevel_generic.StapledStream.send_eof",
- "trio._highlevel_generic.StapledStream.receive_some",
- "trio._highlevel_generic.StapledStream.aclose",
- "trio._abc.HalfCloseableStream",
- "trio._abc.HalfCloseableStream.send_eof",
- "trio._abc.Stream",
- "trio._abc.SendStream",
- "trio._abc.SendStream.send_all",
- "trio._abc.SendStream.wait_send_all_might_not_block",
- "trio._abc.AsyncResource.aclose",
- "trio._abc.AsyncResource.__aenter__",
- "trio._abc.AsyncResource.__aexit__",
- "trio._abc.ReceiveStream",
- "trio._abc.ReceiveStream.receive_some",
- "trio._abc.ReceiveStream.__aiter__",
- "trio._abc.ReceiveStream.__anext__",
- "trio._channel.MemorySendChannel",
- "trio._abc.SendChannel",
- "trio._channel.MemoryReceiveChannel",
- "trio._abc.ReceiveChannel",
- "trio._abc.ReceiveChannel.__aiter__",
- "trio._highlevel_socket.SocketStream",
"trio._highlevel_socket.SocketStream.__init__",
+ "trio._socket._SocketType.__getattr__",
+ "trio._socket._SocketType.accept",
+ "trio._socket._SocketType.connect",
+ "trio._socket._SocketType.recv",
+ "trio._socket._SocketType.recv_into",
+ "trio._socket._SocketType.recvfrom",
+ "trio._socket._SocketType.recvfrom_into",
+ "trio._socket._SocketType.recvmsg",
+ "trio._socket._SocketType.recvmsg_into",
+ "trio._socket._SocketType.send",
+ "trio._socket._SocketType.sendto",
+ "trio._socket._SocketType.sendmsg",
"trio._highlevel_socket.SocketStream.send_all",
- "trio._highlevel_socket.SocketStream.wait_send_all_might_not_block",
- "trio._highlevel_socket.SocketStream.send_eof",
"trio._highlevel_socket.SocketStream.receive_some",
- "trio._highlevel_socket.SocketStream.aclose",
"trio._highlevel_socket.SocketStream.setsockopt",
"trio._highlevel_socket.SocketStream.getsockopt",
"trio._highlevel_socket.SocketListener",
"trio._highlevel_socket.SocketListener.__init__",
- "trio._highlevel_socket.SocketListener.accept",
- "trio._highlevel_socket.SocketListener.aclose",
- "trio._abc.Listener",
"trio._abc.Listener.accept",
"trio.open_file",
"trio.wrap_file",
@@ -147,7 +92,6 @@
"trio._path.AsyncAutoWrapperType.generate_wraps",
"trio._path.AsyncAutoWrapperType.generate_magic",
"trio._path.AsyncAutoWrapperType.generate_iter",
- "trio._subprocess.Process",
"trio._subprocess.Process.encoding",
"trio._subprocess.Process.errors",
"trio._subprocess.Process.__init__",
@@ -163,7 +107,6 @@
"trio._subprocess.Process.args",
"trio._subprocess.Process.pid",
"trio.run_process",
- "trio._ssl.SSLStream",
"trio._ssl.SSLStream.__init__",
"trio._ssl.SSLStream.__getattr__",
"trio._ssl.SSLStream.__setattr__",
@@ -188,7 +131,6 @@
"trio._dtls.DTLSEndpoint.connect",
"trio._dtls.DTLSEndpoint.socket",
"trio._dtls.DTLSEndpoint.incoming_packets_buffer",
- "trio._dtls.DTLSChannel",
"trio._dtls.DTLSChannel.__init__",
"trio._dtls.DTLSChannel.close",
"trio._dtls.DTLSChannel.__enter__",
@@ -200,7 +142,6 @@
"trio._dtls.DTLSChannel.set_ciphertext_mtu",
"trio._dtls.DTLSChannel.get_cleartext_mtu",
"trio._dtls.DTLSChannel.statistics",
- "trio._abc.Channel",
"trio.serve_listeners",
"trio.open_tcp_stream",
"trio.open_tcp_listeners",
@@ -210,9 +151,6 @@
"trio.open_ssl_over_tcp_listeners",
"trio.serve_ssl_over_tcp",
"trio.__deprecated_attributes__",
- "trio._abc.Clock.start_clock",
- "trio._abc.Clock.current_time",
- "trio._abc.Clock.deadline_to_sleep_time",
"trio._abc.Instrument.before_run",
"trio._abc.Instrument.after_run",
"trio._abc.Instrument.task_spawned",
@@ -229,22 +167,6 @@
"trio.from_thread.run_sync",
"trio.lowlevel.cancel_shielded_checkpoint",
"trio.lowlevel.currently_ki_protected",
- "trio._core._run.Task.coro",
- "trio._core._run.Task.name",
- "trio._core._run.Task.context",
- "trio._core._run.Task.custom_sleep_data",
- "trio._core._run.Task.__repr__",
- "trio._core._run.Task.parent_nursery",
- "trio._core._run.Task.eventual_parent_nursery",
- "trio._core._run.Task.child_nurseries",
- "trio._core._run.Task.iter_await_frames",
- "trio.lowlevel.checkpoint",
- "trio.lowlevel.current_task",
- "trio._core._parking_lot.ParkingLot.__len__",
- "trio._core._parking_lot.ParkingLot.__bool__",
- "trio._core._parking_lot.ParkingLot.unpark_all",
- "trio._core._parking_lot.ParkingLot.repark_all",
- "trio._core._parking_lot.ParkingLot.statistics",
"trio._core._unbounded_queue.UnboundedQueue.__repr__",
"trio._core._unbounded_queue.UnboundedQueue.qsize",
"trio._core._unbounded_queue.UnboundedQueue.empty",
@@ -268,7 +190,6 @@
"trio.lowlevel.add_instrument",
"trio.lowlevel.current_clock",
"trio.lowlevel.current_root_task",
- "trio.lowlevel.checkpoint_if_cancelled",
"trio.lowlevel.spawn_system_task",
"trio.lowlevel.wait_readable",
"trio.lowlevel.wait_writable",
@@ -276,7 +197,6 @@
"trio.lowlevel.start_thread_soon",
"trio.lowlevel.start_guest_run",
"trio.lowlevel.open_process",
- "trio._unix_pipes.FdStream",
"trio.socket.fromfd",
"trio.socket.from_stdlib_socket",
"trio.socket.getprotobyname",
@@ -287,14 +207,6 @@
"trio.socket.set_custom_hostname_resolver",
"trio.socket.set_custom_socket_factory",
"trio.testing.wait_all_tasks_blocked",
- "trio._core._mock_clock.MockClock",
- "trio._core._mock_clock.MockClock.__init__",
- "trio._core._mock_clock.MockClock.__repr__",
- "trio._core._mock_clock.MockClock.rate",
- "trio._core._mock_clock.MockClock.autojump_threshold",
- "trio._core._mock_clock.MockClock.start_clock",
- "trio._core._mock_clock.MockClock.current_time",
- "trio._core._mock_clock.MockClock.deadline_to_sleep_time",
"trio._core._mock_clock.MockClock.jump",
"trio.testing.trio_test",
"trio.testing.assert_checkpoints",
@@ -302,7 +214,6 @@
"trio.testing.check_one_way_stream",
"trio.testing.check_two_way_stream",
"trio.testing.check_half_closeable_stream",
- "trio.testing._memory_streams.MemorySendStream",
"trio.testing._memory_streams.MemorySendStream.__init__",
"trio.testing._memory_streams.MemorySendStream.send_all",
"trio.testing._memory_streams.MemorySendStream.wait_send_all_might_not_block",
@@ -313,7 +224,6 @@
"trio.testing._memory_streams.MemorySendStream.send_all_hook",
"trio.testing._memory_streams.MemorySendStream.wait_send_all_might_not_block_hook",
"trio.testing._memory_streams.MemorySendStream.close_hook",
- "trio.testing._memory_streams.MemoryReceiveStream",
"trio.testing._memory_streams.MemoryReceiveStream.__init__",
"trio.testing._memory_streams.MemoryReceiveStream.receive_some",
"trio.testing._memory_streams.MemoryReceiveStream.close",
From 0050ca2a658b934dfdabec4668318dbf283de48f Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 4 Jul 2023 15:29:07 +0200
Subject: [PATCH 02/21] changes after review
---
trio/_abc.py | 9 ++++++---
trio/_core/_multierror.py | 18 ++++++++++++++----
trio/_core/_parking_lot.py | 6 +++---
trio/_core/_run.py | 11 ++++++++---
trio/_dtls.py | 31 +++++++++++++++++++++++++------
trio/_highlevel_socket.py | 4 ++--
trio/_socket.py | 10 ++++++++--
trio/_sync.py | 30 +++++++++++++++---------------
trio/_tests/verify_types.json | 9 ++-------
trio/_util.py | 10 +++++++++-
trio/testing/_fake_net.py | 13 +++++++++++--
11 files changed, 103 insertions(+), 48 deletions(-)
diff --git a/trio/_abc.py b/trio/_abc.py
index 403e07a1de..036a20b567 100644
--- a/trio/_abc.py
+++ b/trio/_abc.py
@@ -1,9 +1,12 @@
from __future__ import annotations
from abc import ABCMeta, abstractmethod
-from typing import Generic, TypeVar
+from typing import Generic, TypeVar, TYPE_CHECKING
import trio
+if TYPE_CHECKING:
+ from typing_extensions import Self
+
# We use ABCMeta instead of ABC, plus set __slots__=(), so as not to force a
# __dict__ onto subclasses.
@@ -414,7 +417,7 @@ async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray:
"""
- def __aiter__(self) -> ReceiveStream:
+ def __aiter__(self) -> Self:
return self
async def __anext__(self) -> bytes | bytearray:
@@ -633,7 +636,7 @@ async def receive(self) -> ReceiveType:
"""
- def __aiter__(self) -> ReceiveChannel[ReceiveType]:
+ def __aiter__(self) -> Self:
return self
async def __anext__(self) -> ReceiveType:
diff --git a/trio/_core/_multierror.py b/trio/_core/_multierror.py
index 9e69928162..6e964ba038 100644
--- a/trio/_core/_multierror.py
+++ b/trio/_core/_multierror.py
@@ -1,15 +1,19 @@
+from __future__ import annotations
import sys
import warnings
import attr
from trio._deprecate import warn_deprecated
+from typing import TYPE_CHECKING
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup, ExceptionGroup, print_exception
else:
from traceback import print_exception
+if TYPE_CHECKING:
+ from types import TracebackType
################################################################
# MultiError
################################################################
@@ -130,11 +134,16 @@ class MultiErrorCatcher:
def __enter__(self):
pass
- def __exit__(self, etype, exc, tb):
- if exc is not None:
- filtered_exc = _filter_impl(self._handler, exc)
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None,
+ ) -> bool | None:
+ if exc_val is not None:
+ filtered_exc = _filter_impl(self._handler, exc_val)
- if filtered_exc is exc:
+ if filtered_exc is exc_val:
# Let the interpreter re-raise it
return False
if filtered_exc is None:
@@ -154,6 +163,7 @@ def __exit__(self, etype, exc, tb):
# delete references from locals to avoid creating cycles
# see test_MultiError_catch_doesnt_create_cyclic_garbage
del _, filtered_exc, value
+ return False
class MultiError(BaseExceptionGroup):
diff --git a/trio/_core/_parking_lot.py b/trio/_core/_parking_lot.py
index 42fc60614c..5ef22ca247 100644
--- a/trio/_core/_parking_lot.py
+++ b/trio/_core/_parking_lot.py
@@ -86,7 +86,7 @@
@attr.s(frozen=True, slots=True)
-class _ParkingLotStatistics:
+class ParkingLotStatistics:
tasks_waiting: int = attr.ib()
@@ -213,7 +213,7 @@ def repark_all(self, new_lot: ParkingLot) -> None:
"""
return self.repark(new_lot, count=len(self))
- def statistics(self) -> _ParkingLotStatistics:
+ def statistics(self) -> ParkingLotStatistics:
"""Return an object containing debugging information.
Currently the following fields are defined:
@@ -222,4 +222,4 @@ def statistics(self) -> _ParkingLotStatistics:
:meth:`park` method.
"""
- return _ParkingLotStatistics(tasks_waiting=len(self._parked))
+ return ParkingLotStatistics(tasks_waiting=len(self._parked))
diff --git a/trio/_core/_run.py b/trio/_core/_run.py
index 0664cdf577..b7a8359698 100644
--- a/trio/_core/_run.py
+++ b/trio/_core/_run.py
@@ -903,13 +903,18 @@ def __enter__(self) -> None:
"use 'async with open_nursery(...)', not 'with open_nursery(...)'"
)
- def __exit__(self) -> None: # pragma: no cover
+ def __exit__(
+ self, # pragma: no cover
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None,
+ ) -> bool:
assert False, """Never called, but should be defined"""
def open_nursery(
strict_exception_groups: bool | None = None,
-) -> NurseryManager:
+) -> AbstractAsyncContextManager[Nursery]:
"""Returns an async context manager which must be used to create a
new `Nursery`.
@@ -2456,7 +2461,7 @@ class _TaskStatusIgnored:
def __repr__(self) -> str:
return "TASK_STATUS_IGNORED"
- def started(self, value: None = None):
+ def started(self, value: Any = None):
pass
diff --git a/trio/_dtls.py b/trio/_dtls.py
index 910637455a..9cb8a5d03c 100644
--- a/trio/_dtls.py
+++ b/trio/_dtls.py
@@ -6,6 +6,7 @@
# Hopefully they fix this before implementing DTLS 1.3, because it's a very different
# protocol, and it's probably impossible to pull tricks like we do here.
+from __future__ import annotations
import struct
import hmac
import os
@@ -14,12 +15,16 @@
import weakref
import errno
import warnings
+from typing import TYPE_CHECKING
import attr
import trio
from trio._util import NoPublicConstructor, Final
+if TYPE_CHECKING:
+ from types import TracebackType
+
MAX_UDP_PACKET_SIZE = 65527
@@ -809,7 +814,7 @@ def _check_replaced(self):
# DTLS where packets are all independent and can be lost anyway. We do at least need
# to handle receiving it properly though, which might be easier if we send it...
- def close(self):
+ def close(self) -> None:
"""Close this connection.
`DTLSChannel`\\s don't actually own any OS-level resources – the
@@ -833,8 +838,13 @@ def close(self):
def __enter__(self):
return self
- def __exit__(self, *args):
- self.close()
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None,
+ ) -> None:
+ return self.close()
async def aclose(self):
"""Close this connection, but asynchronously.
@@ -1167,12 +1177,16 @@ def __del__(self):
f"unclosed DTLS endpoint {self!r}", ResourceWarning, source=self
)
- def close(self):
+ def close(self) -> None:
"""Close this socket, and all associated DTLS connections.
This object can also be used as a context manager.
"""
+ # Do nothing if this object was never fully constructed
+ if self.socket is None:
+ return
+
self._closed = True
self.socket.close()
for stream in list(self._streams.values()):
@@ -1182,8 +1196,13 @@ def close(self):
def __enter__(self):
return self
- def __exit__(self, *args):
- self.close()
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None,
+ ) -> None:
+ return self.close()
def _check_closed(self):
if self._closed:
diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py
index 5bd6b68f1c..b3641eb683 100644
--- a/trio/_highlevel_socket.py
+++ b/trio/_highlevel_socket.py
@@ -3,7 +3,7 @@
import errno
from contextlib import contextmanager
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
import trio
from . import socket as tsocket
@@ -130,7 +130,7 @@ async def send_eof(self) -> None:
with _translate_socket_errors_to_stream_errors():
self.socket.shutdown(tsocket.SHUT_WR)
- async def receive_some(self, max_bytes: int | None = None):
+ async def receive_some(self, max_bytes: int | None = None) -> Any:
if max_bytes is None:
max_bytes = DEFAULT_RECEIVE_SIZE
if max_bytes < 1:
diff --git a/trio/_socket.py b/trio/_socket.py
index 38a6fa0019..e25e035f06 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -15,6 +15,7 @@
if TYPE_CHECKING:
from collections.abc import Iterable
from typing_extensions import Self
+ from types import TracebackType
# Usage:
@@ -485,8 +486,13 @@ def __dir__(self) -> Iterable[str]:
def __enter__(self) -> Self:
return self
- def __exit__(self, *exc_info: object) -> None:
- return self._sock.__exit__(*exc_info)
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None,
+ ) -> None:
+ return self._sock.__exit__(exc_type, exc_val, exc_tb)
@property
def family(self) -> _stdlib_socket.AddressFamily:
diff --git a/trio/_sync.py b/trio/_sync.py
index 08eae512d2..54b4c64c47 100644
--- a/trio/_sync.py
+++ b/trio/_sync.py
@@ -13,11 +13,11 @@
if TYPE_CHECKING:
from ._core import Task
- from ._core._parking_lot import _ParkingLotStatistics
+ from ._core._parking_lot import ParkingLotStatistics
@attr.s(frozen=True)
-class _EventStatistics:
+class EventStatistics:
tasks_waiting: int = attr.ib()
@@ -82,7 +82,7 @@ def abort_fn(_):
await _core.wait_task_rescheduled(abort_fn)
- def statistics(self) -> _EventStatistics:
+ def statistics(self) -> EventStatistics:
"""Return an object containing debugging information.
Currently the following fields are defined:
@@ -91,7 +91,7 @@ def statistics(self) -> _EventStatistics:
:meth:`wait` method.
"""
- return _EventStatistics(tasks_waiting=len(self._tasks))
+ return EventStatistics(tasks_waiting=len(self._tasks))
class AsyncContextManagerMixin:
@@ -105,7 +105,7 @@ async def __aexit__(self, *args: object) -> None:
@attr.s(frozen=True)
-class _CapacityLimiterStatistics:
+class CapacityLimiterStatistics:
borrowed_tokens: int = attr.ib()
total_tokens: int | float = attr.ib()
borrowers: list[Task] = attr.ib()
@@ -327,7 +327,7 @@ def release_on_behalf_of(self, borrower: Task) -> None:
self._borrowers.remove(borrower)
self._wake_waiters()
- def statistics(self) -> _CapacityLimiterStatistics:
+ def statistics(self) -> CapacityLimiterStatistics:
"""Return an object containing debugging information.
Currently the following fields are defined:
@@ -344,7 +344,7 @@ def statistics(self) -> _CapacityLimiterStatistics:
:meth:`acquire_on_behalf_of` methods.
"""
- return _CapacityLimiterStatistics(
+ return CapacityLimiterStatistics(
borrowed_tokens=len(self._borrowers),
total_tokens=self._total_tokens,
# Use a list instead of a frozenset just in case we start to allow
@@ -464,7 +464,7 @@ def release(self) -> None:
raise ValueError("semaphore released too many times")
self._value += 1
- def statistics(self) -> _ParkingLotStatistics:
+ def statistics(self) -> ParkingLotStatistics:
"""Return an object containing debugging information.
Currently the following fields are defined:
@@ -477,7 +477,7 @@ def statistics(self) -> _ParkingLotStatistics:
@attr.s(frozen=True)
-class _LockStatistics:
+class LockStatistics:
locked: bool = attr.ib()
owner: Task | None = attr.ib()
tasks_waiting: int = attr.ib()
@@ -556,7 +556,7 @@ def release(self) -> None:
else:
self._owner = None
- def statistics(self) -> _LockStatistics:
+ def statistics(self) -> LockStatistics:
"""Return an object containing debugging information.
Currently the following fields are defined:
@@ -568,7 +568,7 @@ def statistics(self) -> _LockStatistics:
:meth:`acquire` method.
"""
- return _LockStatistics(
+ return LockStatistics(
locked=self.locked(), owner=self._owner, tasks_waiting=len(self._lot)
)
@@ -651,9 +651,9 @@ class StrictFIFOLock(_LockImpl, metaclass=Final):
@attr.s(frozen=True)
-class _ConditionStatistics:
+class ConditionStatistics:
tasks_waiting: int = attr.ib()
- lock_statistics: _LockStatistics = attr.ib()
+ lock_statistics: LockStatistics = attr.ib()
class Condition(AsyncContextManagerMixin, metaclass=Final):
@@ -766,7 +766,7 @@ def notify_all(self) -> None:
raise RuntimeError("must hold the lock to notify")
self._lot.repark_all(self._lock._lot)
- def statistics(self) -> _ConditionStatistics:
+ def statistics(self) -> ConditionStatistics:
r"""Return an object containing debugging information.
Currently the following fields are defined:
@@ -777,6 +777,6 @@ def statistics(self) -> _ConditionStatistics:
:class:`Lock`\s :meth:`~Lock.statistics` method.
"""
- return _ConditionStatistics(
+ return ConditionStatistics(
tasks_waiting=len(self._lot), lock_statistics=self._lock.statistics()
)
diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json
index 53dc90d5d9..791c416e49 100644
--- a/trio/_tests/verify_types.json
+++ b/trio/_tests/verify_types.json
@@ -46,8 +46,8 @@
],
"otherSymbolCounts": {
"withAmbiguousType": 8,
- "withKnownType": 441,
- "withUnknownType": 143
+ "withKnownType": 430,
+ "withUnknownType": 138
},
"packageName": "trio",
"symbols": [
@@ -70,7 +70,6 @@
"trio._socket._SocketType.sendto",
"trio._socket._SocketType.sendmsg",
"trio._highlevel_socket.SocketStream.send_all",
- "trio._highlevel_socket.SocketStream.receive_some",
"trio._highlevel_socket.SocketStream.setsockopt",
"trio._highlevel_socket.SocketStream.getsockopt",
"trio._highlevel_socket.SocketListener",
@@ -124,17 +123,13 @@
"trio._ssl.SSLListener.aclose",
"trio._dtls.DTLSEndpoint.__init__",
"trio._dtls.DTLSEndpoint.__del__",
- "trio._dtls.DTLSEndpoint.close",
"trio._dtls.DTLSEndpoint.__enter__",
- "trio._dtls.DTLSEndpoint.__exit__",
"trio._dtls.DTLSEndpoint.serve",
"trio._dtls.DTLSEndpoint.connect",
"trio._dtls.DTLSEndpoint.socket",
"trio._dtls.DTLSEndpoint.incoming_packets_buffer",
"trio._dtls.DTLSChannel.__init__",
- "trio._dtls.DTLSChannel.close",
"trio._dtls.DTLSChannel.__enter__",
- "trio._dtls.DTLSChannel.__exit__",
"trio._dtls.DTLSChannel.aclose",
"trio._dtls.DTLSChannel.do_handshake",
"trio._dtls.DTLSChannel.send",
diff --git a/trio/_util.py b/trio/_util.py
index 89a2dea7de..090204f6e3 100644
--- a/trio/_util.py
+++ b/trio/_util.py
@@ -12,6 +12,9 @@
import trio
+if t.TYPE_CHECKING:
+ from types import TracebackType
+
# Equivalent to the C function raise(), which Python doesn't wrap
if os.name == "nt":
# On Windows, os.kill exists but is really weird.
@@ -188,7 +191,12 @@ def __enter__(self):
else:
self._held = True
- def __exit__(self, *args):
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None,
+ ) -> None:
self._held = False
diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py
index 9df5ab5b6c..26348f2de2 100644
--- a/trio/testing/_fake_net.py
+++ b/trio/testing/_fake_net.py
@@ -6,15 +6,19 @@
# - TCP
# - UDP broadcast
+from __future__ import annotations
import trio
import attr
import ipaddress
import errno
import os
-from typing import Union, Optional
+from typing import Union, Optional, TYPE_CHECKING
from trio._util import Final, NoPublicConstructor
+if TYPE_CHECKING:
+ from types import TracebackType
+
IPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
@@ -337,7 +341,12 @@ def setsockopt(self, level, item, value):
def __enter__(self):
return self
- def __exit__(self, *exc_info):
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None,
+ ) -> None:
self.close()
async def send(self, data, flags=0):
From 8b8383ba39b553f479e8ea213cc983b0dd3a2856 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 4 Jul 2023 15:47:06 +0200
Subject: [PATCH 03/21] fix tests, reverting accidentally removing a default
param
---
trio/_core/_parking_lot.py | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/trio/_core/_parking_lot.py b/trio/_core/_parking_lot.py
index 5ef22ca247..ecf22652d8 100644
--- a/trio/_core/_parking_lot.py
+++ b/trio/_core/_parking_lot.py
@@ -69,7 +69,6 @@
# unpark is called.
#
# See: https://github.com/python-trio/trio/issues/53
-
from __future__ import annotations
import attr
@@ -138,13 +137,13 @@ def abort_fn(_):
def _pop_several(self, count: int | float) -> Iterator[Task]:
if isinstance(count, float):
- assert math.isinf(count)
+ assert math.isinf(count), "Cannot pop a non-integer number of tasks."
for _ in range(cast(int, min(count, len(self._parked)))):
task, _ = self._parked.popitem(last=False)
yield task
@_core.enable_ki_protection
- def unpark(self, *, count: int | float) -> list[Task]:
+ def unpark(self, *, count: int | float = 1) -> list[Task]:
"""Unpark one or more tasks.
This wakes up ``count`` tasks that are blocked in :meth:`park`. If
From b1f9f6f6440a09cb9c9c231a18ba2bdab75bb93a Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 4 Jul 2023 15:55:17 +0200
Subject: [PATCH 04/21] fix all __aexit
---
trio/_abc.py | 8 +++++++-
trio/_socket.py | 11 ++++++++---
trio/_sync.py | 8 +++++++-
trio/testing/_check_streams.py | 12 +++++++++++-
4 files changed, 33 insertions(+), 6 deletions(-)
diff --git a/trio/_abc.py b/trio/_abc.py
index 036a20b567..21577310d6 100644
--- a/trio/_abc.py
+++ b/trio/_abc.py
@@ -6,6 +6,7 @@
if TYPE_CHECKING:
from typing_extensions import Self
+ from types import TracebackType
# We use ABCMeta instead of ABC, plus set __slots__=(), so as not to force a
@@ -260,7 +261,12 @@ async def aclose(self) -> None:
async def __aenter__(self) -> AsyncResource:
return self
- async def __aexit__(self, *args: object) -> None:
+ async def __aexit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None,
+ ) -> None:
await self.aclose()
diff --git a/trio/_socket.py b/trio/_socket.py
index e25e035f06..8e98ce8fa3 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -39,8 +39,13 @@ def _is_blocking_io_error(self, exc):
async def __aenter__(self):
await trio.lowlevel.checkpoint_if_cancelled()
- async def __aexit__(self, etype, value, tb):
- if value is not None and self._is_blocking_io_error(value):
+ async def __aexit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None,
+ ) -> bool:
+ if exc_val is not None and self._is_blocking_io_error(exc_val):
# Discard the exception and fall through to the code below the
# block
return True
@@ -522,7 +527,7 @@ def close(self) -> None:
trio.lowlevel.notify_closing(self._sock)
self._sock.close()
- async def bind(self, address: tuple[Any, ...] | str | bytes) -> None:
+ async def bind(self, address: tuple[object, ...] | str | bytes) -> None:
address = await self._resolve_address_nocp(address, local=True)
if (
hasattr(_stdlib_socket, "AF_UNIX")
diff --git a/trio/_sync.py b/trio/_sync.py
index 54b4c64c47..164444771f 100644
--- a/trio/_sync.py
+++ b/trio/_sync.py
@@ -14,6 +14,7 @@
if TYPE_CHECKING:
from ._core import Task
from ._core._parking_lot import ParkingLotStatistics
+ from types import TracebackType
@attr.s(frozen=True)
@@ -100,7 +101,12 @@ async def __aenter__(self) -> None:
await self.acquire() # type: ignore[attr-defined]
@enable_ki_protection
- async def __aexit__(self, *args: object) -> None:
+ async def __aexit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None,
+ ) -> None:
self.release() # type: ignore[attr-defined]
diff --git a/trio/testing/_check_streams.py b/trio/testing/_check_streams.py
index 0206f1f737..6d8da15940 100644
--- a/trio/testing/_check_streams.py
+++ b/trio/testing/_check_streams.py
@@ -1,4 +1,5 @@
# Generic stream tests
+from __future__ import annotations
from contextlib import contextmanager
import random
@@ -7,6 +8,10 @@
from .._highlevel_generic import aclose_forcefully
from .._abc import SendStream, ReceiveStream, Stream, HalfCloseableStream
from ._checkpoints import assert_checkpoints
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from types import TracebackType
class _ForceCloseBoth:
@@ -16,7 +21,12 @@ def __init__(self, both):
async def __aenter__(self):
return self._both
- async def __aexit__(self, *args):
+ async def __aexit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None,
+ ) -> None:
try:
await aclose_forcefully(self._both[0])
finally:
From ba6456557e5b3c88fd79d082d1508dacc15b1fb1 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 4 Jul 2023 16:03:10 +0200
Subject: [PATCH 05/21] rename exc_val to exc_value to match official data
model
---
trio/_abc.py | 2 +-
trio/_channel.py | 4 ++--
trio/_core/_multierror.py | 8 ++++----
trio/_core/_run.py | 2 +-
trio/_dtls.py | 4 ++--
trio/_socket.py | 8 ++++----
trio/_sync.py | 2 +-
trio/_util.py | 2 +-
trio/testing/_check_streams.py | 2 +-
trio/testing/_fake_net.py | 2 +-
10 files changed, 18 insertions(+), 18 deletions(-)
diff --git a/trio/_abc.py b/trio/_abc.py
index 21577310d6..c20774dcad 100644
--- a/trio/_abc.py
+++ b/trio/_abc.py
@@ -264,7 +264,7 @@ async def __aenter__(self) -> AsyncResource:
async def __aexit__(
self,
exc_type: type[BaseException] | None,
- exc_val: BaseException | None,
+ exc_value: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.aclose()
diff --git a/trio/_channel.py b/trio/_channel.py
index 3ad08b7109..f2e63363b4 100644
--- a/trio/_channel.py
+++ b/trio/_channel.py
@@ -249,7 +249,7 @@ def __enter__(self: SelfT) -> SelfT:
def __exit__(
self,
exc_type: type[BaseException] | None,
- exc_val: BaseException | None,
+ exc_value: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.close()
@@ -395,7 +395,7 @@ def __enter__(self: SelfT) -> SelfT:
def __exit__(
self,
exc_type: type[BaseException] | None,
- exc_val: BaseException | None,
+ exc_value: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.close()
diff --git a/trio/_core/_multierror.py b/trio/_core/_multierror.py
index 6e964ba038..8abd9730fb 100644
--- a/trio/_core/_multierror.py
+++ b/trio/_core/_multierror.py
@@ -137,13 +137,13 @@ def __enter__(self):
def __exit__(
self,
exc_type: type[BaseException] | None,
- exc_val: BaseException | None,
+ exc_value: BaseException | None,
exc_tb: TracebackType | None,
) -> bool | None:
- if exc_val is not None:
- filtered_exc = _filter_impl(self._handler, exc_val)
+ if exc_value is not None:
+ filtered_exc = _filter_impl(self._handler, exc_value)
- if filtered_exc is exc_val:
+ if filtered_exc is exc_value:
# Let the interpreter re-raise it
return False
if filtered_exc is None:
diff --git a/trio/_core/_run.py b/trio/_core/_run.py
index b7a8359698..03799c1d2c 100644
--- a/trio/_core/_run.py
+++ b/trio/_core/_run.py
@@ -906,7 +906,7 @@ def __enter__(self) -> None:
def __exit__(
self, # pragma: no cover
exc_type: type[BaseException] | None,
- exc_val: BaseException | None,
+ exc_value: BaseException | None,
exc_tb: TracebackType | None,
) -> bool:
assert False, """Never called, but should be defined"""
diff --git a/trio/_dtls.py b/trio/_dtls.py
index 9cb8a5d03c..693759ef7a 100644
--- a/trio/_dtls.py
+++ b/trio/_dtls.py
@@ -841,7 +841,7 @@ def __enter__(self):
def __exit__(
self,
exc_type: type[BaseException] | None,
- exc_val: BaseException | None,
+ exc_value: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
return self.close()
@@ -1199,7 +1199,7 @@ def __enter__(self):
def __exit__(
self,
exc_type: type[BaseException] | None,
- exc_val: BaseException | None,
+ exc_value: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
return self.close()
diff --git a/trio/_socket.py b/trio/_socket.py
index 8e98ce8fa3..ff57a447cd 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -42,10 +42,10 @@ async def __aenter__(self):
async def __aexit__(
self,
exc_type: type[BaseException] | None,
- exc_val: BaseException | None,
+ exc_value: BaseException | None,
exc_tb: TracebackType | None,
) -> bool:
- if exc_val is not None and self._is_blocking_io_error(exc_val):
+ if exc_value is not None and self._is_blocking_io_error(exc_value):
# Discard the exception and fall through to the code below the
# block
return True
@@ -494,10 +494,10 @@ def __enter__(self) -> Self:
def __exit__(
self,
exc_type: type[BaseException] | None,
- exc_val: BaseException | None,
+ exc_value: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
- return self._sock.__exit__(exc_type, exc_val, exc_tb)
+ return self._sock.__exit__(exc_type, exc_value, exc_tb)
@property
def family(self) -> _stdlib_socket.AddressFamily:
diff --git a/trio/_sync.py b/trio/_sync.py
index 164444771f..75898b5ef2 100644
--- a/trio/_sync.py
+++ b/trio/_sync.py
@@ -104,7 +104,7 @@ async def __aenter__(self) -> None:
async def __aexit__(
self,
exc_type: type[BaseException] | None,
- exc_val: BaseException | None,
+ exc_value: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.release() # type: ignore[attr-defined]
diff --git a/trio/_util.py b/trio/_util.py
index 090204f6e3..93dd1e38d0 100644
--- a/trio/_util.py
+++ b/trio/_util.py
@@ -194,7 +194,7 @@ def __enter__(self):
def __exit__(
self,
exc_type: type[BaseException] | None,
- exc_val: BaseException | None,
+ exc_value: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self._held = False
diff --git a/trio/testing/_check_streams.py b/trio/testing/_check_streams.py
index 6d8da15940..a9a7125a48 100644
--- a/trio/testing/_check_streams.py
+++ b/trio/testing/_check_streams.py
@@ -24,7 +24,7 @@ async def __aenter__(self):
async def __aexit__(
self,
exc_type: type[BaseException] | None,
- exc_val: BaseException | None,
+ exc_value: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
try:
diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py
index 26348f2de2..3b340d2a93 100644
--- a/trio/testing/_fake_net.py
+++ b/trio/testing/_fake_net.py
@@ -344,7 +344,7 @@ def __enter__(self):
def __exit__(
self,
exc_type: type[BaseException] | None,
- exc_val: BaseException | None,
+ exc_value: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.close()
From adb87822380812675b3ba77888e93867e1bb7807 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 4 Jul 2023 16:04:23 +0200
Subject: [PATCH 06/21] rename exc_tb to traceback to match python data model
---
trio/_abc.py | 2 +-
trio/_channel.py | 4 ++--
trio/_core/_multierror.py | 2 +-
trio/_core/_run.py | 2 +-
trio/_dtls.py | 4 ++--
trio/_socket.py | 6 +++---
trio/_sync.py | 2 +-
trio/_util.py | 2 +-
trio/testing/_check_streams.py | 2 +-
trio/testing/_fake_net.py | 2 +-
10 files changed, 14 insertions(+), 14 deletions(-)
diff --git a/trio/_abc.py b/trio/_abc.py
index c20774dcad..3a0ece749a 100644
--- a/trio/_abc.py
+++ b/trio/_abc.py
@@ -265,7 +265,7 @@ async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
- exc_tb: TracebackType | None,
+ traceback: TracebackType | None,
) -> None:
await self.aclose()
diff --git a/trio/_channel.py b/trio/_channel.py
index f2e63363b4..a7c13dc9bd 100644
--- a/trio/_channel.py
+++ b/trio/_channel.py
@@ -250,7 +250,7 @@ def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
- exc_tb: TracebackType | None,
+ traceback: TracebackType | None,
) -> None:
self.close()
@@ -396,7 +396,7 @@ def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
- exc_tb: TracebackType | None,
+ traceback: TracebackType | None,
) -> None:
self.close()
diff --git a/trio/_core/_multierror.py b/trio/_core/_multierror.py
index 8abd9730fb..584179f595 100644
--- a/trio/_core/_multierror.py
+++ b/trio/_core/_multierror.py
@@ -138,7 +138,7 @@ def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
- exc_tb: TracebackType | None,
+ traceback: TracebackType | None,
) -> bool | None:
if exc_value is not None:
filtered_exc = _filter_impl(self._handler, exc_value)
diff --git a/trio/_core/_run.py b/trio/_core/_run.py
index 03799c1d2c..8f2264f52a 100644
--- a/trio/_core/_run.py
+++ b/trio/_core/_run.py
@@ -907,7 +907,7 @@ def __exit__(
self, # pragma: no cover
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
- exc_tb: TracebackType | None,
+ traceback: TracebackType | None,
) -> bool:
assert False, """Never called, but should be defined"""
diff --git a/trio/_dtls.py b/trio/_dtls.py
index 693759ef7a..ad6e492b97 100644
--- a/trio/_dtls.py
+++ b/trio/_dtls.py
@@ -842,7 +842,7 @@ def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
- exc_tb: TracebackType | None,
+ traceback: TracebackType | None,
) -> None:
return self.close()
@@ -1200,7 +1200,7 @@ def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
- exc_tb: TracebackType | None,
+ traceback: TracebackType | None,
) -> None:
return self.close()
diff --git a/trio/_socket.py b/trio/_socket.py
index ff57a447cd..6e9abe0cac 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -43,7 +43,7 @@ async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
- exc_tb: TracebackType | None,
+ traceback: TracebackType | None,
) -> bool:
if exc_value is not None and self._is_blocking_io_error(exc_value):
# Discard the exception and fall through to the code below the
@@ -495,9 +495,9 @@ def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
- exc_tb: TracebackType | None,
+ traceback: TracebackType | None,
) -> None:
- return self._sock.__exit__(exc_type, exc_value, exc_tb)
+ return self._sock.__exit__(exc_type, exc_value, traceback)
@property
def family(self) -> _stdlib_socket.AddressFamily:
diff --git a/trio/_sync.py b/trio/_sync.py
index 75898b5ef2..3fdc38b302 100644
--- a/trio/_sync.py
+++ b/trio/_sync.py
@@ -105,7 +105,7 @@ async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
- exc_tb: TracebackType | None,
+ traceback: TracebackType | None,
) -> None:
self.release() # type: ignore[attr-defined]
diff --git a/trio/_util.py b/trio/_util.py
index 93dd1e38d0..9631226e88 100644
--- a/trio/_util.py
+++ b/trio/_util.py
@@ -195,7 +195,7 @@ def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
- exc_tb: TracebackType | None,
+ traceback: TracebackType | None,
) -> None:
self._held = False
diff --git a/trio/testing/_check_streams.py b/trio/testing/_check_streams.py
index a9a7125a48..0e1a4f8f14 100644
--- a/trio/testing/_check_streams.py
+++ b/trio/testing/_check_streams.py
@@ -25,7 +25,7 @@ async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
- exc_tb: TracebackType | None,
+ traceback: TracebackType | None,
) -> None:
try:
await aclose_forcefully(self._both[0])
diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py
index 3b340d2a93..646b7af7a6 100644
--- a/trio/testing/_fake_net.py
+++ b/trio/testing/_fake_net.py
@@ -345,7 +345,7 @@ def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
- exc_tb: TracebackType | None,
+ traceback: TracebackType | None,
) -> None:
self.close()
From 3c9583c82905d52763fd9a01548a8d204ee7d89e Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 4 Jul 2023 16:11:17 +0200
Subject: [PATCH 07/21] fix format check CI
---
trio/_socket.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/trio/_socket.py b/trio/_socket.py
index 6e9abe0cac..0446e998ae 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -5,7 +5,7 @@
import select
import socket as _stdlib_socket
from functools import wraps as _wraps
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING
import idna as _idna
From 0f29a1a34364e2588227e4128f9153efb0d07bd5 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Wed, 5 Jul 2023 17:58:04 +0200
Subject: [PATCH 08/21] changes after review from TeamSpen210
---
trio/_core/_run.py | 25 ++++++++++++++-----------
trio/_highlevel_socket.py | 4 ++--
trio/_socket.py | 8 +++++++-
trio/_tests/verify_types.json | 5 ++---
4 files changed, 25 insertions(+), 17 deletions(-)
diff --git a/trio/_core/_run.py b/trio/_core/_run.py
index 8f2264f52a..7fdd203f9c 100644
--- a/trio/_core/_run.py
+++ b/trio/_core/_run.py
@@ -898,18 +898,21 @@ async def __aexit__(
# see test_simple_cancel_scope_usage_doesnt_create_cyclic_garbage
del _, combined_error_from_nursery, value, new_exc
- def __enter__(self) -> None:
- raise RuntimeError(
- "use 'async with open_nursery(...)', not 'with open_nursery(...)'"
- )
+ # make sure these raise errors in static analysis if called
+ if not TYPE_CHECKING:
- def __exit__(
- self, # pragma: no cover
- exc_type: type[BaseException] | None,
- exc_value: BaseException | None,
- traceback: TracebackType | None,
- ) -> bool:
- assert False, """Never called, but should be defined"""
+ def __enter__(self) -> NoReturn:
+ raise RuntimeError(
+ "use 'async with open_nursery(...)', not 'with open_nursery(...)'"
+ )
+
+ def __exit__(
+ self, # pragma: no cover
+ exc_type: type[BaseException] | None,
+ exc_value: BaseException | None,
+ traceback: TracebackType | None,
+ ) -> NoReturn:
+ raise AssertionError("Never called, but should be defined")
def open_nursery(
diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py
index b3641eb683..b63902d887 100644
--- a/trio/_highlevel_socket.py
+++ b/trio/_highlevel_socket.py
@@ -3,7 +3,7 @@
import errno
from contextlib import contextmanager
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING
import trio
from . import socket as tsocket
@@ -130,7 +130,7 @@ async def send_eof(self) -> None:
with _translate_socket_errors_to_stream_errors():
self.socket.shutdown(tsocket.SHUT_WR)
- async def receive_some(self, max_bytes: int | None = None) -> Any:
+ async def receive_some(self, max_bytes: int | None = None) -> bytes:
if max_bytes is None:
max_bytes = DEFAULT_RECEIVE_SIZE
if max_bytes < 1:
diff --git a/trio/_socket.py b/trio/_socket.py
index 0446e998ae..8c70a0a3a1 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -701,7 +701,13 @@ async def connect(self, address):
# recv
################################################################
- recv = _make_simple_sock_method_wrapper("recv", _core.wait_readable)
+ if TYPE_CHECKING:
+
+ async def recv(self, buffersize: int, flags: int = 0) -> bytes:
+ ...
+
+ else:
+ recv = _make_simple_sock_method_wrapper("recv", _core.wait_readable)
################################################################
# recv_into
diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json
index 791c416e49..a9bfc663c1 100644
--- a/trio/_tests/verify_types.json
+++ b/trio/_tests/verify_types.json
@@ -46,8 +46,8 @@
],
"otherSymbolCounts": {
"withAmbiguousType": 8,
- "withKnownType": 430,
- "withUnknownType": 138
+ "withKnownType": 431,
+ "withUnknownType": 137
},
"packageName": "trio",
"symbols": [
@@ -60,7 +60,6 @@
"trio._socket._SocketType.__getattr__",
"trio._socket._SocketType.accept",
"trio._socket._SocketType.connect",
- "trio._socket._SocketType.recv",
"trio._socket._SocketType.recv_into",
"trio._socket._SocketType.recvfrom",
"trio._socket._SocketType.recvfrom_into",
From c2012a79215c8f7fc8d1a9b61a5508b36ee67844 Mon Sep 17 00:00:00 2001
From: John Litborn <11260241+jakkdl@users.noreply.github.com>
Date: Thu, 6 Jul 2023 15:16:13 +0200
Subject: [PATCH 09/21] Apply suggestions from code review
Co-authored-by: Spencer Brown
---
trio/_core/_parking_lot.py | 9 +++++++--
trio/_core/_run.py | 2 +-
2 files changed, 8 insertions(+), 3 deletions(-)
diff --git a/trio/_core/_parking_lot.py b/trio/_core/_parking_lot.py
index ecf22652d8..bfcf08a69a 100644
--- a/trio/_core/_parking_lot.py
+++ b/trio/_core/_parking_lot.py
@@ -137,8 +137,13 @@ def abort_fn(_):
def _pop_several(self, count: int | float) -> Iterator[Task]:
if isinstance(count, float):
- assert math.isinf(count), "Cannot pop a non-integer number of tasks."
- for _ in range(cast(int, min(count, len(self._parked)))):
+ if math.isinf(count):
+ count = len(self._parked)
+ else:
+ raise ValueError("Cannot pop a non-integer number of tasks.")
+ else:
+ count = min(count, len(self._parked))
+ for _ in range(count):
task, _ = self._parked.popitem(last=False)
yield task
diff --git a/trio/_core/_run.py b/trio/_core/_run.py
index 7fdd203f9c..a3fec21c6f 100644
--- a/trio/_core/_run.py
+++ b/trio/_core/_run.py
@@ -861,7 +861,7 @@ class NurseryManager:
strict_exception_groups: bool = attr.ib(default=False)
@enable_ki_protection
- async def __aenter__(self):
+ async def __aenter__(self) -> Nursery:
self._scope = CancelScope()
self._scope.__enter__()
self._nursery = Nursery._create(
From 9a0c3e28281dd21485d3459ee47b2365ab1a1eb8 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Thu, 6 Jul 2023 15:20:00 +0200
Subject: [PATCH 10/21] fixes after review
---
trio/_core/_run.py | 17 ++++++++++-------
1 file changed, 10 insertions(+), 7 deletions(-)
diff --git a/trio/_core/_run.py b/trio/_core/_run.py
index a3fec21c6f..ddf7221a6f 100644
--- a/trio/_core/_run.py
+++ b/trio/_core/_run.py
@@ -50,9 +50,7 @@
if TYPE_CHECKING:
# An unfortunate name collision here with trio._util.Final
from typing_extensions import Final as FinalT
- from collections.abc import Coroutine
from types import FrameType
- import outcome
import contextvars
DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: FinalT = 1000
@@ -1169,7 +1167,10 @@ def __del__(self) -> None:
@attr.s(eq=False, hash=False, repr=False, slots=True)
class Task(metaclass=NoPublicConstructor):
_parent_nursery: Nursery | None = attr.ib()
- coro: Coroutine[Any, outcome.Outcome[object], Any] = attr.ib()
+ # could be typed as `Coroutine[Any, outcome.Outcome[object], Any]` but
+ # the code does a lot of dynamic introspection on it and as long as it passes
+ # those it's fine.
+ coro: Any = attr.ib()
_runner = attr.ib()
name: str = attr.ib()
context: contextvars.Context = attr.ib()
@@ -1191,8 +1192,10 @@ class Task(metaclass=NoPublicConstructor):
_abort_func: Callable[[Callable[[], NoReturn]], Abort] | None = attr.ib(
default=None
)
- # could possible be set with a TypeVar
- custom_sleep_data: Any = attr.ib(default=None)
+ # Typed as `object`, forcing users to do an isinstance check each time. Since
+ # anything touching the task could have set this, it's not really going to be
+ # safe to assume that this had the value you saw it with last.
+ custom_sleep_data: object = attr.ib(default=None)
# For introspection and nursery.start()
_child_nurseries: list[Nursery] = attr.ib(factory=list)
@@ -1264,11 +1267,11 @@ def print_stack_for_task(task):
if hasattr(coro, "cr_frame"):
# A real coroutine
yield coro.cr_frame, coro.cr_frame.f_lineno
- coro = coro.cr_await # type: ignore # TODO
+ coro = coro.cr_await
elif hasattr(coro, "gi_frame"):
# A generator decorated with @types.coroutine
yield coro.gi_frame, coro.gi_frame.f_lineno
- coro = coro.gi_yieldfrom # type: ignore # TODO
+ coro = coro.gi_yieldfrom
elif coro.__class__.__name__ in [
"async_generator_athrow",
"async_generator_asend",
From 3094d066e25ad60b690c7611c1cf1395ecfce168 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Thu, 6 Jul 2023 16:14:24 +0200
Subject: [PATCH 11/21] attempt to fix readthedocs build, (temp) ignore errors
on custom_sleep_data
---
trio/__init__.py | 4 ++++
trio/_channel.py | 8 ++++----
trio/_core/__init__.py | 2 +-
trio/_core/_run.py | 3 ++-
trio/_tests/verify_types.json | 4 ++--
trio/lowlevel.py | 1 +
6 files changed, 14 insertions(+), 8 deletions(-)
diff --git a/trio/__init__.py b/trio/__init__.py
index 40aa3c430d..728fe4e0d7 100644
--- a/trio/__init__.py
+++ b/trio/__init__.py
@@ -50,11 +50,15 @@
from ._sync import (
Event as Event,
+ EventStatistics as EventStatistics,
CapacityLimiter as CapacityLimiter,
Semaphore as Semaphore,
Lock as Lock,
StrictFIFOLock as StrictFIFOLock,
Condition as Condition,
+ ConditionStatistics as ConditionStatistics,
+ CapacityLimiterStatistics as CapacityLimiterStatistics,
+ LockStatistics as LockStatistics,
)
from ._highlevel_generic import (
diff --git a/trio/_channel.py b/trio/_channel.py
index a7c13dc9bd..05a685c3c9 100644
--- a/trio/_channel.py
+++ b/trio/_channel.py
@@ -178,7 +178,7 @@ def send_nowait(self, value: SendType) -> None:
if self._state.receive_tasks:
assert not self._state.data
task, _ = self._state.receive_tasks.popitem(last=False)
- task.custom_sleep_data._tasks.remove(task)
+ task.custom_sleep_data._tasks.remove(task) # type: ignore[attr-defined]
trio.lowlevel.reschedule(task, Value(value))
elif len(self._state.data) < self._state.max_buffer_size:
self._state.data.append(value)
@@ -278,7 +278,7 @@ def close(self) -> None:
if self._state.open_send_channels == 0:
assert not self._state.send_tasks
for task in self._state.receive_tasks:
- task.custom_sleep_data._tasks.remove(task)
+ task.custom_sleep_data._tasks.remove(task) # type: ignore[attr-defined]
trio.lowlevel.reschedule(task, Error(trio.EndOfChannel()))
self._state.receive_tasks.clear()
@@ -315,7 +315,7 @@ def receive_nowait(self) -> ReceiveType:
raise trio.ClosedResourceError
if self._state.send_tasks:
task, value = self._state.send_tasks.popitem(last=False)
- task.custom_sleep_data._tasks.remove(task)
+ task.custom_sleep_data._tasks.remove(task) # type: ignore[attr-defined]
trio.lowlevel.reschedule(task)
self._state.data.append(value)
# Fall through
@@ -424,7 +424,7 @@ def close(self) -> None:
if self._state.open_receive_channels == 0:
assert not self._state.receive_tasks
for task in self._state.send_tasks:
- task.custom_sleep_data._tasks.remove(task)
+ task.custom_sleep_data._tasks.remove(task) # type: ignore[attr-defined]
trio.lowlevel.reschedule(task, Error(trio.BrokenResourceError()))
self._state.send_tasks.clear()
self._state.data.clear()
diff --git a/trio/_core/__init__.py b/trio/_core/__init__.py
index f9919b8323..a57f2e7021 100644
--- a/trio/_core/__init__.py
+++ b/trio/_core/__init__.py
@@ -64,7 +64,7 @@
from ._entry_queue import TrioToken
-from ._parking_lot import ParkingLot
+from ._parking_lot import ParkingLot, ParkingLotStatistics
from ._unbounded_queue import UnboundedQueue
diff --git a/trio/_core/_run.py b/trio/_core/_run.py
index ddf7221a6f..2e1905029f 100644
--- a/trio/_core/_run.py
+++ b/trio/_core/_run.py
@@ -47,10 +47,11 @@
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup
+from types import FrameType
+
if TYPE_CHECKING:
# An unfortunate name collision here with trio._util.Final
from typing_extensions import Final as FinalT
- from types import FrameType
import contextvars
DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: FinalT = 1000
diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json
index a9bfc663c1..c1d4aab66a 100644
--- a/trio/_tests/verify_types.json
+++ b/trio/_tests/verify_types.json
@@ -7,10 +7,10 @@
"warningCount": 0
},
"typeCompleteness": {
- "completenessScore": 0.8737864077669902,
+ "completenessScore": 0.8747993579454254,
"exportedSymbolCounts": {
"withAmbiguousType": 1,
- "withKnownType": 540,
+ "withKnownType": 545,
"withUnknownType": 77
},
"ignoreUnknownTypesFromImports": true,
diff --git a/trio/lowlevel.py b/trio/lowlevel.py
index b7f4f3a725..55838fb8f6 100644
--- a/trio/lowlevel.py
+++ b/trio/lowlevel.py
@@ -25,6 +25,7 @@
checkpoint as checkpoint,
current_task as current_task,
ParkingLot as ParkingLot,
+ ParkingLotStatistics as ParkingLotStatistics,
UnboundedQueue as UnboundedQueue,
RunVar as RunVar,
TrioToken as TrioToken,
From f83373b9d2d984496ee2d61ec7c6cfd69386ae01 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Fri, 7 Jul 2023 12:29:54 +0200
Subject: [PATCH 12/21] revent custom_sleep_data to Any, add docstring to some
Statistics to see if that makes readthedocs happy
---
trio/_channel.py | 8 ++++----
trio/_core/_parking_lot.py | 11 ++++++++++-
trio/_core/_run.py | 2 +-
trio/_sync.py | 22 ++++++++++++++++++++++
4 files changed, 37 insertions(+), 6 deletions(-)
diff --git a/trio/_channel.py b/trio/_channel.py
index 05a685c3c9..a7c13dc9bd 100644
--- a/trio/_channel.py
+++ b/trio/_channel.py
@@ -178,7 +178,7 @@ def send_nowait(self, value: SendType) -> None:
if self._state.receive_tasks:
assert not self._state.data
task, _ = self._state.receive_tasks.popitem(last=False)
- task.custom_sleep_data._tasks.remove(task) # type: ignore[attr-defined]
+ task.custom_sleep_data._tasks.remove(task)
trio.lowlevel.reschedule(task, Value(value))
elif len(self._state.data) < self._state.max_buffer_size:
self._state.data.append(value)
@@ -278,7 +278,7 @@ def close(self) -> None:
if self._state.open_send_channels == 0:
assert not self._state.send_tasks
for task in self._state.receive_tasks:
- task.custom_sleep_data._tasks.remove(task) # type: ignore[attr-defined]
+ task.custom_sleep_data._tasks.remove(task)
trio.lowlevel.reschedule(task, Error(trio.EndOfChannel()))
self._state.receive_tasks.clear()
@@ -315,7 +315,7 @@ def receive_nowait(self) -> ReceiveType:
raise trio.ClosedResourceError
if self._state.send_tasks:
task, value = self._state.send_tasks.popitem(last=False)
- task.custom_sleep_data._tasks.remove(task) # type: ignore[attr-defined]
+ task.custom_sleep_data._tasks.remove(task)
trio.lowlevel.reschedule(task)
self._state.data.append(value)
# Fall through
@@ -424,7 +424,7 @@ def close(self) -> None:
if self._state.open_receive_channels == 0:
assert not self._state.receive_tasks
for task in self._state.send_tasks:
- task.custom_sleep_data._tasks.remove(task) # type: ignore[attr-defined]
+ task.custom_sleep_data._tasks.remove(task)
trio.lowlevel.reschedule(task, Error(trio.BrokenResourceError()))
self._state.send_tasks.clear()
self._state.data.clear()
diff --git a/trio/_core/_parking_lot.py b/trio/_core/_parking_lot.py
index bfcf08a69a..272dfa2e1c 100644
--- a/trio/_core/_parking_lot.py
+++ b/trio/_core/_parking_lot.py
@@ -73,7 +73,7 @@
import attr
from collections import OrderedDict
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING
from collections.abc import Iterator
import math
@@ -86,6 +86,15 @@
@attr.s(frozen=True, slots=True)
class ParkingLotStatistics:
+ """An object containing debugging information for a ParkingLot.
+
+ Currently the following fields are defined:
+
+ * ``tasks_waiting`` (int): The number of tasks blocked on this lot's
+ :meth:`park` method.
+
+ """
+
tasks_waiting: int = attr.ib()
diff --git a/trio/_core/_run.py b/trio/_core/_run.py
index 2e1905029f..12c1f8581e 100644
--- a/trio/_core/_run.py
+++ b/trio/_core/_run.py
@@ -1196,7 +1196,7 @@ class Task(metaclass=NoPublicConstructor):
# Typed as `object`, forcing users to do an isinstance check each time. Since
# anything touching the task could have set this, it's not really going to be
# safe to assume that this had the value you saw it with last.
- custom_sleep_data: object = attr.ib(default=None)
+ custom_sleep_data: Any = attr.ib(default=None)
# For introspection and nursery.start()
_child_nurseries: list[Nursery] = attr.ib(factory=list)
diff --git a/trio/_sync.py b/trio/_sync.py
index 3fdc38b302..e43e646065 100644
--- a/trio/_sync.py
+++ b/trio/_sync.py
@@ -484,6 +484,18 @@ def statistics(self) -> ParkingLotStatistics:
@attr.s(frozen=True)
class LockStatistics:
+ """An object containing debugging information for a Lock.
+
+ Currently the following fields are defined:
+
+ * ``locked`` (boolean): indicating whether the lock is held.
+ * ``owner``: the :class:`trio.lowlevel.Task` currently holding the lock,
+ or None if the lock is not held.
+ * ``tasks_waiting`` (int): The number of tasks blocked on this lock's
+ :meth:`acquire` method.
+
+ """
+
locked: bool = attr.ib()
owner: Task | None = attr.ib()
tasks_waiting: int = attr.ib()
@@ -658,6 +670,16 @@ class StrictFIFOLock(_LockImpl, metaclass=Final):
@attr.s(frozen=True)
class ConditionStatistics:
+ r"""Return an object containing debugging information for a Condition.
+
+ Currently the following fields are defined:
+
+ * ``tasks_waiting`` (int): The number of tasks blocked on this condition's
+ :meth:`wait` method.
+ * ``lock_statistics``: The result of calling the underlying
+ :class:`Lock`\s :meth:`~Lock.statistics` method.
+
+ """
tasks_waiting: int = attr.ib()
lock_statistics: LockStatistics = attr.ib()
From ac4b357a86c86cd8f129fe2cb3fbedf53a2b7249 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Fri, 7 Jul 2023 14:45:49 +0200
Subject: [PATCH 13/21] fix CI errors
---
docs/source/conf.py | 12 ++++++++++++
docs/source/reference-core.rst | 12 ++++++++++++
docs/source/reference-lowlevel.rst | 2 ++
trio/_core/_parking_lot.py | 9 +++++----
trio/_sync.py | 19 ++++++++++---------
5 files changed, 41 insertions(+), 13 deletions(-)
diff --git a/docs/source/conf.py b/docs/source/conf.py
index cfac66576b..2d8e2e8111 100755
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -62,10 +62,22 @@
("py:obj", "trio._abc.SendType"),
("py:obj", "trio._abc.T"),
("py:obj", "trio._abc.T_resource"),
+ ("py:class", "types.FrameType"),
]
autodoc_inherit_docstrings = False
default_role = "obj"
+# These have incorrect __module__ set in stdlib and give the error
+# `py:class reference target not found`
+# Some of the nitpick_ignore's above can probably be fixed with this.
+# See https://github.com/sphinx-doc/sphinx/issues/8315#issuecomment-751335798
+autodoc_type_aliases = {
+ # aliasing doesn't actually fix the warning for types.FrameType, but displaying
+ # "types.FrameType" is more helpfun than just "frame"
+ "FrameType": "types.FrameType",
+}
+
+
# XX hack the RTD theme until
# https://github.com/rtfd/sphinx_rtd_theme/pull/382
# is shipped (should be in the release after 0.2.4)
diff --git a/docs/source/reference-core.rst b/docs/source/reference-core.rst
index 922ae4680e..4f4f4d62b9 100644
--- a/docs/source/reference-core.rst
+++ b/docs/source/reference-core.rst
@@ -1096,6 +1096,8 @@ Broadcasting an event with :class:`Event`
.. autoclass:: Event
:members:
+.. autoclass:: EventStatistics
+ :members:
.. _channels:
@@ -1456,6 +1458,16 @@ don't have any special access to Trio's internals.)
.. autoclass:: Condition
:members:
+These primitives return statistics objects that can be inspected.
+
+.. autoclass:: CapacityLimiterStatistics
+ :members:
+
+.. autoclass:: LockStatistics
+ :members:
+
+.. autoclass:: ConditionStatistics
+ :members:
.. _async-generators:
diff --git a/docs/source/reference-lowlevel.rst b/docs/source/reference-lowlevel.rst
index 815cff2ddf..bacebff5ad 100644
--- a/docs/source/reference-lowlevel.rst
+++ b/docs/source/reference-lowlevel.rst
@@ -378,6 +378,8 @@ Wait queue abstraction
:members:
:undoc-members:
+.. autoclass:: ParkingLotStatistics
+ :members:
Low-level checkpoint functions
------------------------------
diff --git a/trio/_core/_parking_lot.py b/trio/_core/_parking_lot.py
index 272dfa2e1c..da10e7cf91 100644
--- a/trio/_core/_parking_lot.py
+++ b/trio/_core/_parking_lot.py
@@ -71,11 +71,12 @@
# See: https://github.com/python-trio/trio/issues/53
from __future__ import annotations
-import attr
+import math
from collections import OrderedDict
-from typing import TYPE_CHECKING
from collections.abc import Iterator
-import math
+from typing import TYPE_CHECKING
+
+import attr
from .. import _core
from .._util import Final
@@ -91,7 +92,7 @@ class ParkingLotStatistics:
Currently the following fields are defined:
* ``tasks_waiting`` (int): The number of tasks blocked on this lot's
- :meth:`park` method.
+ :meth:`trio.lowlevel.ParkingLot.park` method.
"""
diff --git a/trio/_sync.py b/trio/_sync.py
index e43e646065..feb5825e47 100644
--- a/trio/_sync.py
+++ b/trio/_sync.py
@@ -1,23 +1,24 @@
from __future__ import annotations
-from typing import TYPE_CHECKING
import math
+from typing import TYPE_CHECKING
import attr
import trio
from . import _core
-from ._core import enable_ki_protection, ParkingLot
+from ._core import ParkingLot, enable_ki_protection
from ._util import Final
if TYPE_CHECKING:
+ from types import TracebackType
+
from ._core import Task
from ._core._parking_lot import ParkingLotStatistics
- from types import TracebackType
-@attr.s(frozen=True)
+@attr.s(frozen=True, slots=True)
class EventStatistics:
tasks_waiting: int = attr.ib()
@@ -110,7 +111,7 @@ async def __aexit__(
self.release() # type: ignore[attr-defined]
-@attr.s(frozen=True)
+@attr.s(frozen=True, slots=True)
class CapacityLimiterStatistics:
borrowed_tokens: int = attr.ib()
total_tokens: int | float = attr.ib()
@@ -482,7 +483,7 @@ def statistics(self) -> ParkingLotStatistics:
return self._lot.statistics()
-@attr.s(frozen=True)
+@attr.s(frozen=True, slots=True)
class LockStatistics:
"""An object containing debugging information for a Lock.
@@ -492,7 +493,7 @@ class LockStatistics:
* ``owner``: the :class:`trio.lowlevel.Task` currently holding the lock,
or None if the lock is not held.
* ``tasks_waiting`` (int): The number of tasks blocked on this lock's
- :meth:`acquire` method.
+ :meth:`trio.Lock.acquire` method.
"""
@@ -668,14 +669,14 @@ class StrictFIFOLock(_LockImpl, metaclass=Final):
"""
-@attr.s(frozen=True)
+@attr.s(frozen=True, slots=True)
class ConditionStatistics:
r"""Return an object containing debugging information for a Condition.
Currently the following fields are defined:
* ``tasks_waiting`` (int): The number of tasks blocked on this condition's
- :meth:`wait` method.
+ :meth:`trio.Condition.wait` method.
* ``lock_statistics``: The result of calling the underlying
:class:`Lock`\s :meth:`~Lock.statistics` method.
From 77a370636dde766ec4f6d742c67ca5ef565d8db6 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Fri, 7 Jul 2023 14:51:27 +0200
Subject: [PATCH 14/21] don't require *Statistics to be Final
---
trio/_tests/test_exports.py | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/trio/_tests/test_exports.py b/trio/_tests/test_exports.py
index 3ab0016386..e51bbe31f5 100644
--- a/trio/_tests/test_exports.py
+++ b/trio/_tests/test_exports.py
@@ -492,4 +492,8 @@ def test_classes_are_final():
continue
# ... insert other special cases here ...
+ # don't care about the *Statistics classes
+ if name.endswith("Statistics"):
+ continue
+
assert isinstance(class_, _util.Final)
From 50ce72fb0b9b4536659eda2a23cb986ab25adeb7 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Fri, 7 Jul 2023 15:00:33 +0200
Subject: [PATCH 15/21] fix CI
---
trio/_core/_multierror.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/trio/_core/_multierror.py b/trio/_core/_multierror.py
index 584179f595..3c6ebb789f 100644
--- a/trio/_core/_multierror.py
+++ b/trio/_core/_multierror.py
@@ -1,11 +1,12 @@
from __future__ import annotations
+
import sys
import warnings
+from typing import TYPE_CHECKING
import attr
from trio._deprecate import warn_deprecated
-from typing import TYPE_CHECKING
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup, ExceptionGroup, print_exception
From f34787be6e286c2bcae3c70ac31b792eb2173eba Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Fri, 7 Jul 2023 15:25:47 +0200
Subject: [PATCH 16/21] fix codecov and formatting
---
trio/_core/_run.py | 4 ++--
trio/_core/_tests/test_parking_lot.py | 3 +++
trio/_dtls.py | 2 +-
trio/_util.py | 4 +---
4 files changed, 7 insertions(+), 6 deletions(-)
diff --git a/trio/_core/_run.py b/trio/_core/_run.py
index 1e9d309003..4ac3f93339 100644
--- a/trio/_core/_run.py
+++ b/trio/_core/_run.py
@@ -904,8 +904,8 @@ def __enter__(self) -> NoReturn:
"use 'async with open_nursery(...)', not 'with open_nursery(...)'"
)
- def __exit__(
- self, # pragma: no cover
+ def __exit__( # pragma: no cover
+ self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
diff --git a/trio/_core/_tests/test_parking_lot.py b/trio/_core/_tests/test_parking_lot.py
index db3fc76709..32442e5a0f 100644
--- a/trio/_core/_tests/test_parking_lot.py
+++ b/trio/_core/_tests/test_parking_lot.py
@@ -72,6 +72,9 @@ async def waiter(i, lot):
)
lot.unpark_all()
+ with pytest.raises(TypeError):
+ lot.unpark(count=1.5)
+
async def cancellable_waiter(name, lot, scopes, record):
with _core.CancelScope() as scope:
diff --git a/trio/_dtls.py b/trio/_dtls.py
index 57ce638a5d..da99481e06 100644
--- a/trio/_dtls.py
+++ b/trio/_dtls.py
@@ -1185,7 +1185,7 @@ def close(self) -> None:
"""
# Do nothing if this object was never fully constructed
- if self.socket is None:
+ if self.socket is None: # pragma: no cover
return
self._closed = True
diff --git a/trio/_util.py b/trio/_util.py
index cbe892f41d..0a0795fc15 100644
--- a/trio/_util.py
+++ b/trio/_util.py
@@ -9,12 +9,10 @@
import typing as t
from abc import ABCMeta
from functools import update_wrapper
+from types import TracebackType
import trio
-if t.TYPE_CHECKING:
- from types import TracebackType
-
# Equivalent to the C function raise(), which Python doesn't wrap
if os.name == "nt":
# On Windows, os.kill exists but is really weird.
From a8751bbc5f5dd345dfb8c88e50b0feff4c1b3cd4 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Fri, 7 Jul 2023 15:38:23 +0200
Subject: [PATCH 17/21] ValueError, not TypeError
---
trio/_core/_tests/test_parking_lot.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/trio/_core/_tests/test_parking_lot.py b/trio/_core/_tests/test_parking_lot.py
index 32442e5a0f..3f03fdbade 100644
--- a/trio/_core/_tests/test_parking_lot.py
+++ b/trio/_core/_tests/test_parking_lot.py
@@ -72,7 +72,7 @@ async def waiter(i, lot):
)
lot.unpark_all()
- with pytest.raises(TypeError):
+ with pytest.raises(ValueError):
lot.unpark(count=1.5)
From ed05e130bc4abdb6e10c0a7605b9041219628538 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Fri, 7 Jul 2023 16:06:55 +0200
Subject: [PATCH 18/21] fix pragma: no cover
---
trio/_core/_run.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/trio/_core/_run.py b/trio/_core/_run.py
index 4ac3f93339..49f55dee3e 100644
--- a/trio/_core/_run.py
+++ b/trio/_core/_run.py
@@ -904,12 +904,12 @@ def __enter__(self) -> NoReturn:
"use 'async with open_nursery(...)', not 'with open_nursery(...)'"
)
- def __exit__( # pragma: no cover
+ def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
- ) -> NoReturn:
+ ) -> NoReturn: # pragma: no cover
raise AssertionError("Never called, but should be defined")
From 67166a5cc287fe9d32ce80f2d4340bacbfb64993 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Mon, 10 Jul 2023 12:31:19 +0200
Subject: [PATCH 19/21] fixes from CI review, mainly from Zac-HD
---
docs/source/conf.py | 2 +-
trio/_core/_run.py | 19 ++++++++++---------
trio/_highlevel_generic.py | 3 ++-
trio/_socket.py | 4 ++--
4 files changed, 15 insertions(+), 13 deletions(-)
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 2d8e2e8111..68a5a22a81 100755
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -73,7 +73,7 @@
# See https://github.com/sphinx-doc/sphinx/issues/8315#issuecomment-751335798
autodoc_type_aliases = {
# aliasing doesn't actually fix the warning for types.FrameType, but displaying
- # "types.FrameType" is more helpfun than just "frame"
+ # "types.FrameType" is more helpful than just "frame"
"FrameType": "types.FrameType",
}
diff --git a/trio/_core/_run.py b/trio/_core/_run.py
index 49f55dee3e..fa7d651da1 100644
--- a/trio/_core/_run.py
+++ b/trio/_core/_run.py
@@ -10,7 +10,7 @@
import threading
import warnings
from collections import deque
-from collections.abc import Callable, Iterator
+from collections.abc import Callable, Coroutine, Iterator
from contextlib import AbstractAsyncContextManager, contextmanager
from contextvars import copy_context
from heapq import heapify, heappop, heappush
@@ -1167,10 +1167,7 @@ def __del__(self) -> None:
@attr.s(eq=False, hash=False, repr=False, slots=True)
class Task(metaclass=NoPublicConstructor):
_parent_nursery: Nursery | None = attr.ib()
- # could be typed as `Coroutine[Any, outcome.Outcome[object], Any]` but
- # the code does a lot of dynamic introspection on it and as long as it passes
- # those it's fine.
- coro: Any = attr.ib()
+ coro: Coroutine[Any, Outcome[object], Any] = attr.ib()
_runner = attr.ib()
name: str = attr.ib()
context: contextvars.Context = attr.ib()
@@ -1262,7 +1259,8 @@ def print_stack_for_task(task):
print("".join(ss.format()))
"""
- coro = self.coro
+ # ignore static typing as we're doing lots of dynamic introspection
+ coro: Any = self.coro
while coro is not None:
if hasattr(coro, "cr_frame"):
# A real coroutine
@@ -1297,7 +1295,7 @@ def print_stack_for_task(task):
# Don't change this directly; instead, use _activate_cancel_status().
_cancel_status: CancelStatus = attr.ib(default=None, repr=False)
- def _activate_cancel_status(self, cancel_status: CancelStatus):
+ def _activate_cancel_status(self, cancel_status: CancelStatus) -> None:
if self._cancel_status is not None:
self._cancel_status._tasks.remove(self)
self._cancel_status = cancel_status
@@ -1312,8 +1310,11 @@ def _attempt_abort(self, raise_cancel: Callable[[], NoReturn]) -> None:
# rescheduling itself (hopefully eventually calling reraise to raise
# the given exception, but not necessarily).
- # TODO: what happens if _abort_func is None?
- success = self._abort_func(raise_cancel) # type: ignore
+ # This is only called by the functions immediately below, which both check
+ # `self.abort_func is not None`.
+ assert self._abort_func is not None, "FATAL INTERNAL ERROR"
+
+ success = self._abort_func(raise_cancel)
if type(success) is not Abort:
raise TrioInternalError("abort function must return Abort enum")
# We only attempt to abort once per blocking call, regardless of
diff --git a/trio/_highlevel_generic.py b/trio/_highlevel_generic.py
index 6f98d6e18a..c126d4c8aa 100644
--- a/trio/_highlevel_generic.py
+++ b/trio/_highlevel_generic.py
@@ -102,7 +102,8 @@ async def send_eof(self) -> None:
else:
return await self.send_stream.aclose()
- async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray:
+ # we intentionally accept more types from the caller than we support returning
+ async def receive_some(self, max_bytes: int | None = None) -> bytes:
"""Calls ``self.receive_stream.receive_some``."""
return await self.receive_stream.receive_some(max_bytes)
diff --git a/trio/_socket.py b/trio/_socket.py
index 68dbb87c48..eaf0e04d15 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -487,8 +487,7 @@ def __getattr__(self, name):
raise AttributeError(name)
def __dir__(self) -> Iterable[str]:
- yield from super().__dir__()
- yield from self._forward
+ return [*super().__dir__(), *self._forward]
def __enter__(self) -> Self:
return self
@@ -538,6 +537,7 @@ async def bind(self, address: tuple[object, ...] | str | bytes) -> None:
):
# Use a thread for the filesystem traversal (unless it's an
# abstract domain socket)
+ # remove the `type: ignore` when run.sync is typed.
return await trio.to_thread.run_sync(self._sock.bind, address) # type: ignore[no-any-return]
else:
# POSIX actually says that bind can return EWOULDBLOCK and
From 89f308f2dade05ec268dd4687f5e12033e0049a4 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 11 Jul 2023 13:21:21 +0200
Subject: [PATCH 20/21] small fixes after review from A5Rocks
---
trio/_abc.py | 2 +-
trio/_core/_parking_lot.py | 4 ++--
trio/_core/_run.py | 7 ++-----
trio/_dtls.py | 2 ++
trio/_highlevel_generic.py | 4 +++-
trio/_sync.py | 30 +++++++++++++++++++++++++++++-
trio/_tests/verify_types.json | 11 +++++------
trio/lowlevel.py | 1 -
8 files changed, 44 insertions(+), 17 deletions(-)
diff --git a/trio/_abc.py b/trio/_abc.py
index c88b5f559e..2a1721db13 100644
--- a/trio/_abc.py
+++ b/trio/_abc.py
@@ -260,7 +260,7 @@ async def aclose(self) -> None:
"""
- async def __aenter__(self) -> AsyncResource:
+ async def __aenter__(self) -> Self:
return self
async def __aexit__(
diff --git a/trio/_core/_parking_lot.py b/trio/_core/_parking_lot.py
index da10e7cf91..74708433da 100644
--- a/trio/_core/_parking_lot.py
+++ b/trio/_core/_parking_lot.py
@@ -179,7 +179,7 @@ def unpark_all(self) -> list[Task]:
return self.unpark(count=len(self))
@_core.enable_ki_protection
- def repark(self, new_lot: ParkingLot, *, count: int = 1) -> None:
+ def repark(self, new_lot: ParkingLot, *, count: int | float = 1) -> None:
"""Move parked tasks from one :class:`ParkingLot` object to another.
This dequeues ``count`` tasks from one lot, and requeues them on
@@ -209,7 +209,7 @@ async def main():
Args:
new_lot (ParkingLot): the parking lot to move tasks to.
- count (int): the number of tasks to move.
+ count (int|math.inf): the number of tasks to move.
"""
if not isinstance(new_lot, ParkingLot):
diff --git a/trio/_core/_run.py b/trio/_core/_run.py
index fa7d651da1..585dc4aa41 100644
--- a/trio/_core/_run.py
+++ b/trio/_core/_run.py
@@ -48,9 +48,9 @@
from types import FrameType
if TYPE_CHECKING:
- # An unfortunate name collision here with trio._util.Final
import contextvars
+ # An unfortunate name collision here with trio._util.Final
from typing_extensions import Final as FinalT
DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: FinalT = 1000
@@ -1189,9 +1189,6 @@ class Task(metaclass=NoPublicConstructor):
_abort_func: Callable[[Callable[[], NoReturn]], Abort] | None = attr.ib(
default=None
)
- # Typed as `object`, forcing users to do an isinstance check each time. Since
- # anything touching the task could have set this, it's not really going to be
- # safe to assume that this had the value you saw it with last.
custom_sleep_data: Any = attr.ib(default=None)
# For introspection and nursery.start()
@@ -2468,7 +2465,7 @@ class _TaskStatusIgnored:
def __repr__(self) -> str:
return "TASK_STATUS_IGNORED"
- def started(self, value: Any = None):
+ def started(self, value: object = None) -> None:
pass
diff --git a/trio/_dtls.py b/trio/_dtls.py
index da99481e06..722a9499f8 100644
--- a/trio/_dtls.py
+++ b/trio/_dtls.py
@@ -1132,6 +1132,8 @@ def __init__(self, socket, *, incoming_packets_buffer=10):
global SSL
from OpenSSL import SSL
+ # TODO: create a `self._initialized` for `__del__`, so self.socket can be typed
+ # as trio.socket.SocketType and `is not None` checks can be removed.
self.socket = None # for __del__, in case the next line raises
if socket.type != trio.socket.SOCK_DGRAM:
raise ValueError("DTLS requires a SOCK_DGRAM socket")
diff --git a/trio/_highlevel_generic.py b/trio/_highlevel_generic.py
index c126d4c8aa..e1ac378c6a 100644
--- a/trio/_highlevel_generic.py
+++ b/trio/_highlevel_generic.py
@@ -98,7 +98,9 @@ async def send_eof(self) -> None:
"""
if hasattr(self.send_stream, "send_eof"):
- return await self.send_stream.send_eof() # type: ignore
+ # send_stream.send_eof() is not defined in Trio, this should maybe be
+ # redesigned so it's possible to type it.
+ return await self.send_stream.send_eof() # type: ignore[no-any-return]
else:
return await self.send_stream.aclose()
diff --git a/trio/_sync.py b/trio/_sync.py
index feb5825e47..1ccff947a0 100644
--- a/trio/_sync.py
+++ b/trio/_sync.py
@@ -20,6 +20,15 @@
@attr.s(frozen=True, slots=True)
class EventStatistics:
+ """An object containing debugging information.
+
+ Currently the following fields are defined:
+
+ * ``tasks_waiting``: The number of tasks blocked on this event's
+ :meth:`wait` method.
+
+ """
+
tasks_waiting: int = attr.ib()
@@ -96,6 +105,8 @@ def statistics(self) -> EventStatistics:
return EventStatistics(tasks_waiting=len(self._tasks))
+# TODO: type this with a Protocol to get rid of type: ignore, see
+# https://github.com/python-trio/trio/pull/2682#discussion_r1259097422
class AsyncContextManagerMixin:
@enable_ki_protection
async def __aenter__(self) -> None:
@@ -113,6 +124,23 @@ async def __aexit__(
@attr.s(frozen=True, slots=True)
class CapacityLimiterStatistics:
+ """An object containing debugging information.
+
+ Currently the following fields are defined:
+
+ * ``borrowed_tokens``: The number of tokens currently borrowed from
+ the sack.
+ * ``total_tokens``: The total number of tokens in the sack. Usually
+ this will be larger than ``borrowed_tokens``, but it's possibly for
+ it to be smaller if :attr:`total_tokens` was recently decreased.
+ * ``borrowers``: A list of all tasks or other entities that currently
+ hold a token.
+ * ``tasks_waiting``: The number of tasks blocked on this
+ :class:`CapacityLimiter`\'s :meth:`acquire` or
+ :meth:`acquire_on_behalf_of` methods.
+
+ """
+
borrowed_tokens: int = attr.ib()
total_tokens: int | float = attr.ib()
borrowers: list[Task] = attr.ib()
@@ -671,7 +699,7 @@ class StrictFIFOLock(_LockImpl, metaclass=Final):
@attr.s(frozen=True, slots=True)
class ConditionStatistics:
- r"""Return an object containing debugging information for a Condition.
+ r"""An object containing debugging information for a Condition.
Currently the following fields are defined:
diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json
index cb2253ec27..9d7d7aa912 100644
--- a/trio/_tests/verify_types.json
+++ b/trio/_tests/verify_types.json
@@ -7,11 +7,11 @@
"warningCount": 0
},
"typeCompleteness": {
- "completenessScore": 0.8747993579454254,
+ "completenessScore": 0.8764044943820225,
"exportedSymbolCounts": {
"withAmbiguousType": 1,
- "withKnownType": 545,
- "withUnknownType": 77
+ "withKnownType": 546,
+ "withUnknownType": 76
},
"ignoreUnknownTypesFromImports": true,
"missingClassDocStringCount": 1,
@@ -46,8 +46,8 @@
],
"otherSymbolCounts": {
"withAmbiguousType": 8,
- "withKnownType": 431,
- "withUnknownType": 137
+ "withKnownType": 433,
+ "withUnknownType": 135
},
"packageName": "trio",
"symbols": [
@@ -73,7 +73,6 @@
"trio._core._mock_clock.MockClock.jump",
"trio._core._run.Nursery.start",
"trio._core._run.Nursery.start_soon",
- "trio._core._run._TaskStatusIgnored.started",
"trio._core._unbounded_queue.UnboundedQueue.__aiter__",
"trio._core._unbounded_queue.UnboundedQueue.__anext__",
"trio._core._unbounded_queue.UnboundedQueue.__repr__",
diff --git a/trio/lowlevel.py b/trio/lowlevel.py
index 4b0f114b2d..54f4ef3141 100644
--- a/trio/lowlevel.py
+++ b/trio/lowlevel.py
@@ -7,7 +7,6 @@
import sys
import typing as _t
-# Generally available symbols
# Generally available symbols
from ._core import (
Abort as Abort,
From 253684326ca071ea8d8190d2b32d2d5fb006bb2f Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 11 Jul 2023 13:26:25 +0200
Subject: [PATCH 21/21] fix read the docs references
---
trio/_sync.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/trio/_sync.py b/trio/_sync.py
index 1ccff947a0..5a7f240d5e 100644
--- a/trio/_sync.py
+++ b/trio/_sync.py
@@ -25,7 +25,7 @@ class EventStatistics:
Currently the following fields are defined:
* ``tasks_waiting``: The number of tasks blocked on this event's
- :meth:`wait` method.
+ :meth:`trio.Event.wait` method.
"""
@@ -132,12 +132,12 @@ class CapacityLimiterStatistics:
the sack.
* ``total_tokens``: The total number of tokens in the sack. Usually
this will be larger than ``borrowed_tokens``, but it's possibly for
- it to be smaller if :attr:`total_tokens` was recently decreased.
+ it to be smaller if :attr:`trio.CapacityLimiter.total_tokens` was recently decreased.
* ``borrowers``: A list of all tasks or other entities that currently
hold a token.
* ``tasks_waiting``: The number of tasks blocked on this
- :class:`CapacityLimiter`\'s :meth:`acquire` or
- :meth:`acquire_on_behalf_of` methods.
+ :class:`CapacityLimiter`\'s :meth:`trio.CapacityLimiter.acquire` or
+ :meth:`trio.CapacityLimiter.acquire_on_behalf_of` methods.
"""