Support Pyright detection of constructor kwargs in attrs-based data classes #2945

Merged: 11 commits (Feb 5, 2024)
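
For orientation: the diff below swaps the classic `attr.s`/`attr.ib` API for the modern `attr.define`/`attr.field`/`attr.frozen` API throughout, which (per the PR title) is what lets Pyright pick up and check the keyword arguments of the generated `__init__`. A minimal sketch of the before/after pattern, assuming attrs >= 21.2.0 is installed; `PointOld`/`PointNew` are illustrative names, not classes from this PR:

```python
import attr


# Classic API (what the codebase used before this PR).
@attr.s(slots=True)
class PointOld:
    x: int = attr.ib(default=0)
    y: int = attr.ib(default=0)


# Modern API (what this PR converts to). attr.define implies slots=True,
# auto-detects annotated fields, and produces an __init__(x=0, y=0) signature
# that Pyright can check at call sites.
@attr.define
class PointNew:
    x: int = 0
    y: int = attr.field(default=0)  # attr.field is still available when options are needed


print(PointNew(x=1, y=2))  # keyword arguments are now visible to the type checker
```
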
12 changes: 6 additions & 6 deletions notes-to-self/time-wait.py
@@ -32,13 +32,13 @@
import attr


@attr.s(repr=False)
@attr.define(repr=False, slots=False)
class Options:
listen1_early = attr.ib(default=None)
listen1_middle = attr.ib(default=None)
listen1_late = attr.ib(default=None)
server = attr.ib(default=None)
listen2 = attr.ib(default=None)
listen1_early = attr.field(default=None)
listen1_middle = attr.field(default=None)
listen1_late = attr.field(default=None)
server = attr.field(default=None)
listen2 = attr.field(default=None)

def set(self, which, sock):
value = getattr(self, which)
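
One conversion detail worth flagging: `attr.define` defaults to `slots=True`, while the old `attr.s` defaulted to `slots=False`, which is presumably why this class (and several below) pass `slots=False` explicitly to preserve the previous behavior. A small sketch of the difference, using illustrative class names:

```python
import attr


@attr.s  # classic decorator: slots=False by default
class Classic:
    x = attr.ib(default=0)


@attr.define  # modern decorator: slots=True by default
class Modern:
    x: int = 0


c = Classic()
c.extra = "fine"  # no __slots__, so instances carry a __dict__

m = Modern()
try:
    m.extra = "boom"  # __slots__ forbids attributes that were not declared
except AttributeError as exc:
    print("slotted class rejected it:", exc)
```
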
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -38,7 +38,7 @@ requires-python = ">=3.8"
dependencies = [
# attrs 19.2.0 adds `eq` option to decorators
# attrs 20.1.0 adds @frozen
"attrs >= 20.1.0",
"attrs >= 21.2.0",
"sortedcontainers",
"idna",
"outcome",
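
The minimum-version bump to `attrs >= 21.2.0` accompanies the API migration above. As a quick, hedged aside, the installed attrs version in an environment can be checked with the standard library alone:

```python
from importlib.metadata import version  # available on Python 3.8+

print(version("attrs"))  # should report 21.2.0 or newer for this branch
```
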
44 changes: 22 additions & 22 deletions src/trio/_channel.py
@@ -109,27 +109,27 @@ def __init__(self, max_buffer_size: int | float): # noqa: PYI041
open_memory_channel = generic_function(_open_memory_channel)


@attr.s(frozen=True, slots=True)
@attr.frozen
class MemoryChannelStats:
current_buffer_used: int = attr.ib()
max_buffer_size: int | float = attr.ib()
open_send_channels: int = attr.ib()
open_receive_channels: int = attr.ib()
tasks_waiting_send: int = attr.ib()
tasks_waiting_receive: int = attr.ib()
current_buffer_used: int = attr.field()
max_buffer_size: int | float = attr.field()
open_send_channels: int = attr.field()
open_receive_channels: int = attr.field()
tasks_waiting_send: int = attr.field()
tasks_waiting_receive: int = attr.field()


@attr.s(slots=True)
@attr.define
class MemoryChannelState(Generic[T]):
max_buffer_size: int | float = attr.ib()
data: deque[T] = attr.ib(factory=deque)
max_buffer_size: int | float = attr.field()
data: deque[T] = attr.field(factory=deque)
# Counts of open endpoints using this state
open_send_channels: int = attr.ib(default=0)
open_receive_channels: int = attr.ib(default=0)
open_send_channels: int = attr.field(default=0)
open_receive_channels: int = attr.field(default=0)
# {task: value}
send_tasks: OrderedDict[Task, T] = attr.ib(factory=OrderedDict)
send_tasks: OrderedDict[Task, T] = attr.field(factory=OrderedDict)
# {task: None}
receive_tasks: OrderedDict[Task, None] = attr.ib(factory=OrderedDict)
receive_tasks: OrderedDict[Task, None] = attr.field(factory=OrderedDict)

def statistics(self) -> MemoryChannelStats:
return MemoryChannelStats(
@@ -143,14 +143,14 @@ def statistics(self) -> MemoryChannelStats:


@final
@attr.s(eq=False, repr=False)
@attr.define(eq=False, repr=False, slots=False)
class MemorySendChannel(SendChannel[SendType], metaclass=NoPublicConstructor):
_state: MemoryChannelState[SendType] = attr.ib()
_closed: bool = attr.ib(default=False)
_state: MemoryChannelState[SendType] = attr.field()
_closed: bool = attr.field(default=False)
# This is just the tasks waiting on *this* object. As compared to
# self._state.send_tasks, which includes tasks from this object and
# all clones.
_tasks: set[Task] = attr.ib(factory=set)
_tasks: set[Task] = attr.field(factory=set)

def __attrs_post_init__(self) -> None:
self._state.open_send_channels += 1
@@ -286,11 +286,11 @@ async def aclose(self) -> None:


@final
@attr.s(eq=False, repr=False)
@attr.define(eq=False, repr=False, slots=False)
class MemoryReceiveChannel(ReceiveChannel[ReceiveType], metaclass=NoPublicConstructor):
_state: MemoryChannelState[ReceiveType] = attr.ib()
_closed: bool = attr.ib(default=False)
_tasks: set[trio._core._run.Task] = attr.ib(factory=set)
_state: MemoryChannelState[ReceiveType] = attr.field()
_closed: bool = attr.field(default=False)
_tasks: set[trio._core._run.Task] = attr.field(factory=set)

def __attrs_post_init__(self) -> None:
self._state.open_receive_channels += 1
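
For readers unfamiliar with the shorthand used above: as far as I know, `attr.frozen` is equivalent to `attr.define(frozen=True, on_setattr=None)`, so `MemoryChannelStats` remains an immutable, slotted record just as it was with `attr.s(frozen=True, slots=True)`. A small sketch with an illustrative stand-in class:

```python
import attr


@attr.frozen  # roughly attr.define(frozen=True, on_setattr=None)
class Stats:
    current_buffer_used: int
    max_buffer_size: float


s = Stats(current_buffer_used=0, max_buffer_size=float("inf"))
print(s)

try:
    s.current_buffer_used = 1  # frozen instances reject mutation after __init__
except attr.exceptions.FrozenInstanceError:
    print("immutable, as expected")
```
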
10 changes: 6 additions & 4 deletions src/trio/_core/_asyncgens.py
@@ -26,7 +26,7 @@
_ASYNC_GEN_SET = set


@attr.s(eq=False, slots=True)
@attr.define(eq=False)
class AsyncGenerators:
# Async generators are added to this set when first iterated. Any
# left after the main task exits will be closed before trio.run()
@@ -35,14 +35,16 @@ class AsyncGenerators:
# asyncgens after the system nursery has been closed, it's a
# regular set so we don't have to deal with GC firing at
# unexpected times.
alive: _WEAK_ASYNC_GEN_SET | _ASYNC_GEN_SET = attr.ib(factory=_WEAK_ASYNC_GEN_SET)
alive: _WEAK_ASYNC_GEN_SET | _ASYNC_GEN_SET = attr.field(
factory=_WEAK_ASYNC_GEN_SET
)

# This collects async generators that get garbage collected during
# the one-tick window between the system nursery closing and the
# init task starting end-of-run asyncgen finalization.
trailing_needs_finalize: _ASYNC_GEN_SET = attr.ib(factory=_ASYNC_GEN_SET)
trailing_needs_finalize: _ASYNC_GEN_SET = attr.field(factory=_ASYNC_GEN_SET)

prev_hooks: sys._asyncgen_hooks = attr.ib(init=False)
prev_hooks: sys._asyncgen_hooks = attr.field(init=False)

def install_hooks(self, runner: _run.Runner) -> None:
def firstiter(agen: AsyncGeneratorType[object, NoReturn]) -> None:
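
Context for the `alive` field above: CPython exposes `sys.set_asyncgen_hooks`, which lets a runner learn about async generators when they are first iterated, and a `weakref.WeakSet` drops entries automatically once a generator is garbage collected. A standalone sketch of that mechanism, not trio's actual hook implementation:

```python
import sys
import weakref

alive = weakref.WeakSet()  # entries disappear when a generator is collected
prev_hooks = sys.get_asyncgen_hooks()


def firstiter(agen):
    alive.add(agen)  # called on the first __anext__ of every async generator


def finalizer(agen):
    print("would finalize un-closed generator:", agen)


sys.set_asyncgen_hooks(firstiter=firstiter, finalizer=finalizer)


async def ticker():
    yield 1
    yield 2


async def demo():
    agen = ticker()
    print(await agen.__anext__())  # firstiter fires here
    print("tracked generators:", len(alive))
    await agen.aclose()  # closed cleanly, so the finalizer never runs


# Drive the coroutine by hand so no event loop overwrites the hooks we set.
try:
    demo().send(None)
except StopIteration:
    pass

sys.set_asyncgen_hooks(*prev_hooks)
```
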
16 changes: 8 additions & 8 deletions src/trio/_core/_entry_queue.py
@@ -19,7 +19,7 @@
Job = Tuple[Function, Tuple[object, ...]]


@attr.s(slots=True)
@attr.define
class EntryQueue:
# This used to use a queue.Queue. but that was broken, because Queues are
# implemented in Python, and not reentrant -- so it was thread-safe, but
@@ -28,11 +28,11 @@ class EntryQueue:
# atomic WRT signal delivery (signal handlers can run on either side, but
# not *during* a deque operation). dict makes similar guarantees - and
# it's even ordered!
queue: deque[Job] = attr.ib(factory=deque)
idempotent_queue: dict[Job, None] = attr.ib(factory=dict)
queue: deque[Job] = attr.field(factory=deque)
idempotent_queue: dict[Job, None] = attr.field(factory=dict)

wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair)
done: bool = attr.ib(default=False)
wakeup: WakeupSocketpair = attr.field(factory=WakeupSocketpair)
done: bool = attr.field(default=False)
# Must be a reentrant lock, because it's acquired from signal handlers.
# RLock is signal-safe as of cpython 3.2. NB that this does mean that the
# lock is effectively *disabled* when we enter from signal context. The
@@ -41,7 +41,7 @@
# main thread -- it just might happen at some inconvenient place. But if
# you look at the one place where the main thread holds the lock, it's
# just to make 1 assignment, so that's atomic WRT a signal anyway.
lock: threading.RLock = attr.ib(factory=threading.RLock)
lock: threading.RLock = attr.field(factory=threading.RLock)

async def task(self) -> None:
assert _core.currently_ki_protected()
@@ -146,7 +146,7 @@ def run_sync_soon(


@final
@attr.s(eq=False, hash=False, slots=True)
@attr.define(eq=False, hash=False)
class TrioToken(metaclass=NoPublicConstructor):
"""An opaque object representing a single call to :func:`trio.run`.

@@ -166,7 +166,7 @@ class TrioToken(metaclass=NoPublicConstructor):

"""

_reentry_queue: EntryQueue = attr.ib()
_reentry_queue: EntryQueue = attr.field()

def run_sync_soon(
self,
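
Background for the comments above: `deque.append` is implemented in C and, as the source comment notes, atomic with respect to signal delivery, which is why a plain `deque` (plus an `RLock` where needed) is safe to touch from a signal handler. A stripped-down sketch of the run-sync-soon pattern; the real `EntryQueue` additionally wakes the event loop through a `WakeupSocketpair` and supports idempotent jobs:

```python
import signal
from collections import deque

jobs: deque = deque()  # deque.append is atomic w.r.t. signals and threads


def run_sync_soon(fn, *args):
    # Safe to call from other threads or from a signal handler on the main thread.
    jobs.append((fn, args))


def on_sigint(signum, frame):
    # A signal handler may interrupt the main thread anywhere, but appending
    # to the deque cannot corrupt it or deadlock.
    run_sync_soon(print, f"got signal {signum}")


def drain():
    # The "event loop" side: pop and execute queued jobs.
    while True:
        try:
            fn, args = jobs.popleft()
        except IndexError:
            return
        fn(*args)


signal.signal(signal.SIGINT, on_sigint)
run_sync_soon(print, "hello from the queue")
drain()
```
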
26 changes: 13 additions & 13 deletions src/trio/_core/_io_epoll.py
@@ -20,11 +20,11 @@
from .._file_io import _HasFileNo


@attr.s(slots=True, eq=False)
@attr.define(eq=False)
class EpollWaiters:
read_task: Task | None = attr.ib(default=None)
write_task: Task | None = attr.ib(default=None)
current_flags: int = attr.ib(default=0)
read_task: Task | None = attr.field(default=None)
write_task: Task | None = attr.field(default=None)
current_flags: int = attr.field(default=0)


assert not TYPE_CHECKING or sys.platform == "linux"
@@ -33,11 +33,11 @@ class EpollWaiters:
EventResult: TypeAlias = "list[tuple[int, int]]"


@attr.s(slots=True, eq=False, frozen=True)
@attr.frozen(eq=False)
class _EpollStatistics:
tasks_waiting_read: int = attr.ib()
tasks_waiting_write: int = attr.ib()
backend: Literal["epoll"] = attr.ib(init=False, default="epoll")
tasks_waiting_read: int = attr.field()
tasks_waiting_write: int = attr.field()
backend: Literal["epoll"] = attr.field(init=False, default="epoll")


# Some facts about epoll
@@ -198,15 +198,15 @@ class _EpollStatistics:
# wanted to about how epoll works.


@attr.s(slots=True, eq=False, hash=False)
@attr.define(eq=False, hash=False)
class EpollIOManager:
_epoll: select.epoll = attr.ib(factory=select.epoll)
_epoll: select.epoll = attr.field(factory=select.epoll)
# {fd: EpollWaiters}
_registered: defaultdict[int, EpollWaiters] = attr.ib(
_registered: defaultdict[int, EpollWaiters] = attr.field(
factory=lambda: defaultdict(EpollWaiters)
)
_force_wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair)
_force_wakeup_fd: int | None = attr.ib(default=None)
_force_wakeup: WakeupSocketpair = attr.field(factory=WakeupSocketpair)
_force_wakeup_fd: int | None = attr.field(default=None)

def __attrs_post_init__(self) -> None:
self._epoll.register(self._force_wakeup.wakeup_sock, select.EPOLLIN)
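
A side note on the `_registered` field above: `attr.field(factory=...)` requires a zero-argument callable, so a `defaultdict(EpollWaiters)` default has to be wrapped in a `lambda`; the `defaultdict` itself then uses `EpollWaiters` as its own factory. A small sketch of the same pattern, assuming attrs is installed and using an illustrative `Waiters` class:

```python
from __future__ import annotations

from collections import defaultdict

import attr


@attr.define(eq=False)
class Waiters:  # illustrative stand-in for EpollWaiters
    read_task: object | None = None
    write_task: object | None = None
    current_flags: int = 0


@attr.define(eq=False)
class Manager:
    # The factory must take no arguments, hence the lambda around defaultdict(...).
    registered: defaultdict[int, Waiters] = attr.field(
        factory=lambda: defaultdict(Waiters)
    )


m = Manager()
m.registered[4].current_flags |= 0x1  # unknown fds get a fresh Waiters on demand
print(m.registered[4])
```
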
20 changes: 10 additions & 10 deletions src/trio/_core/_io_kqueue.py
@@ -24,22 +24,22 @@
EventResult: TypeAlias = "list[select.kevent]"


@attr.s(slots=True, eq=False, frozen=True)
@attr.frozen(eq=False)
class _KqueueStatistics:
tasks_waiting: int = attr.ib()
monitors: int = attr.ib()
backend: Literal["kqueue"] = attr.ib(init=False, default="kqueue")
tasks_waiting: int = attr.field()
monitors: int = attr.field()
backend: Literal["kqueue"] = attr.field(init=False, default="kqueue")


@attr.s(slots=True, eq=False)
@attr.define(eq=False)
class KqueueIOManager:
_kqueue: select.kqueue = attr.ib(factory=select.kqueue)
_kqueue: select.kqueue = attr.field(factory=select.kqueue)
# {(ident, filter): Task or UnboundedQueue}
_registered: dict[tuple[int, int], Task | UnboundedQueue[select.kevent]] = attr.ib(
factory=dict
_registered: dict[tuple[int, int], Task | UnboundedQueue[select.kevent]] = (
attr.field(factory=dict)
)
_force_wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair)
_force_wakeup_fd: int | None = attr.ib(default=None)
_force_wakeup: WakeupSocketpair = attr.field(factory=WakeupSocketpair)
_force_wakeup_fd: int | None = attr.field(default=None)

def __attrs_post_init__(self) -> None:
force_wakeup_event = select.kevent(
42 changes: 21 additions & 21 deletions src/trio/_core/_io_windows.py
@@ -242,22 +242,22 @@
# To avoid this, we have to coalesce all the operations on a single socket
# into one, and when the set of waiters changes we have to throw away the old
# operation and start a new one.
@attr.s(slots=True, eq=False)
@attr.define(eq=False)
class AFDWaiters:
read_task: _core.Task | None = attr.ib(default=None)
write_task: _core.Task | None = attr.ib(default=None)
current_op: AFDPollOp | None = attr.ib(default=None)
read_task: _core.Task | None = attr.field(default=None)
write_task: _core.Task | None = attr.field(default=None)
current_op: AFDPollOp | None = attr.field(default=None)

(Codecov: added lines src/trio/_core/_io_windows.py#L247-L249 were not covered by tests)

# We also need to bundle up all the info for a single op into a standalone
# object, because we need to keep all these objects alive until the operation
# finishes, even if we're throwing it away.
@attr.s(slots=True, eq=False, frozen=True)
@attr.frozen(eq=False)
class AFDPollOp:
lpOverlapped: CData = attr.ib()
poll_info: Any = attr.ib()
waiters: AFDWaiters = attr.ib()
afd_group: AFDGroup = attr.ib()
lpOverlapped: CData = attr.field()
poll_info: Any = attr.field()
waiters: AFDWaiters = attr.field()
afd_group: AFDGroup = attr.field()

(Codecov: added lines src/trio/_core/_io_windows.py#L257-L260 were not covered by tests)

# The Windows kernel has a weird issue when using AFD handles. If you have N
@@ -271,22 +271,22 @@
MAX_AFD_GROUP_SIZE = 500 # at 1000, the cubic scaling is just starting to bite


@attr.s(slots=True, eq=False)
@attr.define(eq=False)
class AFDGroup:
size: int = attr.ib()
handle: Handle = attr.ib()
size: int = attr.field()
handle: Handle = attr.field()

(Codecov: added lines src/trio/_core/_io_windows.py#L276-L277 were not covered by tests)

assert not TYPE_CHECKING or sys.platform == "win32"


@attr.s(slots=True, eq=False, frozen=True)
@attr.frozen(eq=False)
class _WindowsStatistics:
tasks_waiting_read: int = attr.ib()
tasks_waiting_write: int = attr.ib()
tasks_waiting_overlapped: int = attr.ib()
completion_key_monitors: int = attr.ib()
backend: Literal["windows"] = attr.ib(init=False, default="windows")
tasks_waiting_read: int = attr.field()
tasks_waiting_write: int = attr.field()
tasks_waiting_overlapped: int = attr.field()
completion_key_monitors: int = attr.field()
backend: Literal["windows"] = attr.field(init=False, default="windows")

(Codecov: added lines src/trio/_core/_io_windows.py#L285-L289 were not covered by tests)

# Maximum number of events to dequeue from the completion port on each pass
@@ -405,10 +405,10 @@
return handle


@attr.s(frozen=True)
@attr.frozen(slots=False)
class CompletionKeyEventInfo:
lpOverlapped: CData = attr.ib()
dwNumberOfBytesTransferred: int = attr.ib()
lpOverlapped: CData = attr.field()
dwNumberOfBytesTransferred: int = attr.field()

(Codecov: added lines src/trio/_core/_io_windows.py#L410-L411 were not covered by tests)

class WindowsIOManager:
4 changes: 2 additions & 2 deletions src/trio/_core/_ki.py
@@ -201,9 +201,9 @@ def __call__(self, f: CallableT, /) -> CallableT:
disable_ki_protection.__name__ = "disable_ki_protection"


@attr.s
@attr.define(slots=False)
class KIManager:
handler: Callable[[int, types.FrameType | None], None] | None = attr.ib(
handler: Callable[[int, types.FrameType | None], None] | None = attr.field(
default=None
)

14 changes: 7 additions & 7 deletions src/trio/_core/_local.py
@@ -16,19 +16,19 @@ class _NoValue: ...


@final
@attr.s(eq=False, hash=False, slots=True)
@attr.define(eq=False, hash=False)
class RunVarToken(Generic[T], metaclass=NoPublicConstructor):
_var: RunVar[T] = attr.ib()
previous_value: T | type[_NoValue] = attr.ib(default=_NoValue)
redeemed: bool = attr.ib(default=False, init=False)
_var: RunVar[T] = attr.field()
previous_value: T | type[_NoValue] = attr.field(default=_NoValue)
redeemed: bool = attr.field(default=False, init=False)

@classmethod
def _empty(cls, var: RunVar[T]) -> RunVarToken[T]:
return cls._create(var)


@final
@attr.s(eq=False, hash=False, slots=True, repr=False)
@attr.define(eq=False, hash=False, repr=False)
class RunVar(Generic[T]):
"""The run-local variant of a context variable.

@@ -38,8 +38,8 @@ class RunVar(Generic[T]):

"""

_name: str = attr.ib()
_default: T | type[_NoValue] = attr.ib(default=_NoValue)
_name: str = attr.field()
_default: T | type[_NoValue] = attr.field(default=_NoValue)

def get(self, default: T | type[_NoValue] = _NoValue) -> T:
"""Gets the value of this :class:`RunVar` for the current run call."""
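
To round off this last hunk, a hedged usage sketch of the `RunVar` API shown above (trio exposes it publicly as `trio.lowlevel.RunVar`, if memory serves); the variable name and values are illustrative:

```python
import trio

COUNTER = trio.lowlevel.RunVar("counter", default=0)


async def bump() -> None:
    token = COUNTER.set(COUNTER.get() + 1)  # set() hands back a RunVarToken
    print("inside the run:", COUNTER.get())
    COUNTER.reset(token)  # restore whatever value was there before


async def main() -> None:
    await bump()
    print("after reset:", COUNTER.get())  # falls back to the default again


trio.run(main)
```
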