Skip to content

Commit

Permalink
avoid calls to make_current() and make_clear() by using asyncio.run i…
Browse files Browse the repository at this point in the history
…n LoopRunner

Closes dask#6784
  • Loading branch information
graingert committed Feb 1, 2023
1 parent d74f500 commit f121d04
Show file tree
Hide file tree
Showing 10 changed files with 198 additions and 276 deletions.
39 changes: 6 additions & 33 deletions distributed/actor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

import abc
import asyncio
import functools
import sys
import threading
from collections.abc import Generator
from dataclasses import dataclass
from datetime import timedelta
from typing import Generic, Literal, NoReturn, TypeVar
Expand All @@ -13,43 +13,16 @@

from distributed.client import Future
from distributed.protocol import to_serialize
from distributed.utils import iscoroutinefunction, sync, thread_state
from distributed.utils import LateLoopEvent, iscoroutinefunction, sync, thread_state
from distributed.utils_comm import WrappedKey
from distributed.worker import get_client, get_worker

_T = TypeVar("_T")

if sys.version_info >= (3, 9):
from collections.abc import Awaitable, Generator
from collections.abc import Awaitable
else:
from typing import Awaitable, Generator

if sys.version_info >= (3, 10):
from asyncio import Event as _LateLoopEvent
else:
# In python 3.10 asyncio.Lock and other primitives no longer support
# passing a loop kwarg to bind to a loop running in another thread
# e.g. calling from Client(asynchronous=False). Instead the loop is bound
# as late as possible: when calling any methods that wait on or wake
# Future instances. See: https://bugs.python.org/issue42392
class _LateLoopEvent:
def __init__(self) -> None:
self._event: asyncio.Event | None = None

def set(self) -> None:
if self._event is None:
self._event = asyncio.Event()
from typing import Awaitable

self._event.set()

def is_set(self) -> bool:
return self._event is not None and self._event.is_set()

async def wait(self) -> bool:
if self._event is None:
self._event = asyncio.Event()

return await self._event.wait()
_T = TypeVar("_T")


class Actor(WrappedKey):
Expand Down Expand Up @@ -318,7 +291,7 @@ def unwrap(self) -> NoReturn:
class ActorFuture(BaseActorFuture[_T]):
def __init__(self, io_loop: IOLoop):
self._io_loop = io_loop
self._event = _LateLoopEvent()
self._event = LateLoopEvent()
self._out: _Error | _OK[_T] | None = None

def __await__(self) -> Generator[object, None, _T]:
Expand Down
13 changes: 8 additions & 5 deletions distributed/deploy/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import uuid
import warnings
from collections.abc import Awaitable
from contextlib import suppress
from inspect import isawaitable
from typing import Any
Expand Down Expand Up @@ -205,16 +206,17 @@ async def _close(self):

self.status = Status.closed

def close(self, timeout=None):
def close(self, timeout: float | None = None) -> Awaitable[None] | None:
# If the cluster is already closed, we're already done
if self.status == Status.closed:
if self.asynchronous:
return NoOpAwaitable()
else:
return
return None

with suppress(RuntimeError): # loop closed during process shutdown
try:
return self.sync(self._close, callback_timeout=timeout)
except RuntimeError: # loop closed during process shutdown
return None

def __del__(self, _warn=warnings.warn):
if getattr(self, "status", Status.closed) != Status.closed:
Expand Down Expand Up @@ -522,7 +524,8 @@ def __enter__(self):
return self.sync(self.__aenter__)

def __exit__(self, exc_type, exc_value, traceback):
return self.sync(self.__aexit__, exc_type, exc_value, traceback)
aw = self.close()
assert aw is None

def __await__(self):
return self
Expand Down
11 changes: 7 additions & 4 deletions distributed/deploy/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,15 @@ def __init__(
self.sync(self._correct_state)
except Exception:
self.sync(self.close)
self._loop_runner.stop()
raise

def close(self, timeout: float | None = None) -> Awaitable[None] | None:
aw = super().close(timeout)
if not self.asynchronous:
self._loop_runner.stop()
return aw

async def _start(self):
while self.status == Status.starting:
await asyncio.sleep(0.01)
Expand Down Expand Up @@ -471,10 +478,6 @@ async def __aenter__(self):
assert self.status == Status.running
return self

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._loop_runner.stop()

def _threads_per_worker(self) -> int:
"""Return the number of threads per worker for new workers"""
if not self.new_spec: # pragma: no cover
Expand Down
5 changes: 1 addition & 4 deletions distributed/deploy/tests/test_adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,6 @@ async def test_no_more_workers_than_tasks():
assert len(cluster.scheduler.workers) <= 1


@pytest.mark.filterwarnings("ignore:There is no current event loop:DeprecationWarning")
@pytest.mark.filterwarnings("ignore:make_current is deprecated:DeprecationWarning")
def test_basic_no_loop(cleanup):
loop = None
try:
Expand All @@ -294,8 +292,7 @@ def test_basic_no_loop(cleanup):
assert future.result() == 2
loop = cluster.loop
finally:
if loop is not None:
loop.add_callback(loop.stop)
assert loop is None or not loop.asyncio_loop.is_running()


@pytest.mark.flaky(condition=LINUX, reruns=10, reruns_delay=5)
Expand Down
2 changes: 0 additions & 2 deletions distributed/deploy/tests/test_spec_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ def test_spec_sync(loop):
assert result == 11


@pytest.mark.filterwarnings("ignore:There is no current event loop:DeprecationWarning")
@pytest.mark.filterwarnings("ignore:make_current is deprecated:DeprecationWarning")
def test_loop_started_in_constructor(cleanup):
# test that SpecCluster.__init__ starts a loop in another thread
cluster = SpecCluster(worker_spec, scheduler=scheduler, loop=None)
Expand Down
6 changes: 3 additions & 3 deletions distributed/tests/test_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
get_client,
wait,
)
from distributed.actor import _LateLoopEvent
from distributed.metrics import time
from distributed.utils import LateLoopEvent
from distributed.utils_test import cluster, gen_cluster


Expand Down Expand Up @@ -261,7 +261,7 @@ def test_sync(client):
def test_timeout(client):
class Waiter:
def __init__(self):
self.event = _LateLoopEvent()
self.event = LateLoopEvent()

async def set(self):
self.event.set()
Expand Down Expand Up @@ -553,7 +553,7 @@ def sleep(self, time):
async def test_waiter(c, s, a, b):
class Waiter:
def __init__(self):
self.event = _LateLoopEvent()
self.event = LateLoopEvent()

async def set(self):
self.event.set()
Expand Down
55 changes: 7 additions & 48 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,7 @@
from distributed.metrics import time
from distributed.scheduler import CollectTaskMetaDataPlugin, KilledWorker, Scheduler
from distributed.sizeof import sizeof
from distributed.utils import (
NoOpAwaitable,
get_mp_context,
is_valid_xml,
open_port,
sync,
tmp_text,
)
from distributed.utils import get_mp_context, is_valid_xml, open_port, sync, tmp_text
from distributed.utils_test import (
NO_AMM,
BlockedGatherDep,
Expand Down Expand Up @@ -2210,27 +2203,8 @@ async def test_multi_client(s, a, b):
await asyncio.sleep(0.01)


@contextmanager
def _pristine_loop():
IOLoop.clear_instance()
IOLoop.clear_current()
loop = IOLoop()
loop.make_current()
assert IOLoop.current() is loop
try:
yield loop
finally:
try:
loop.close(all_fds=True)
except (KeyError, ValueError):
pass
IOLoop.clear_instance()
IOLoop.clear_current()


def long_running_client_connection(address):
with _pristine_loop():
c = Client(address)
with Client(address, loop=None) as c:
x = c.submit(lambda x: x + 1, 10)
x.result()
sleep(100)
Expand Down Expand Up @@ -2893,8 +2867,6 @@ async def test_startup_close_startup(s, a, b):
pass


@pytest.mark.filterwarnings("ignore:There is no current event loop:DeprecationWarning")
@pytest.mark.filterwarnings("ignore:make_current is deprecated:DeprecationWarning")
def test_startup_close_startup_sync(loop):
with cluster() as (s, [a, b]):
with Client(s["address"], loop=loop) as c:
Expand Down Expand Up @@ -5622,23 +5594,12 @@ async def test_future_auto_inform(c, s, a, b):
await asyncio.sleep(0.01)


@pytest.mark.filterwarnings("ignore:There is no current event loop:DeprecationWarning")
@pytest.mark.filterwarnings("ignore:make_current is deprecated:DeprecationWarning")
@pytest.mark.filterwarnings("ignore:clear_current is deprecated:DeprecationWarning")
def test_client_async_before_loop_starts(cleanup):
async def close():
async with client:
pass

with _pristine_loop() as loop:
with pytest.warns(
DeprecationWarning,
match=r"Constructing LoopRunner\(loop=loop\) without a running loop is deprecated",
):
client = Client(asynchronous=True, loop=loop)
assert client.asynchronous
assert isinstance(client.close(), NoOpAwaitable)
loop.run_sync(close) # TODO: client.close() does not unset global client
with pytest.raises(
RuntimeError,
match=r"Constructing LoopRunner\(asynchronous=True\) without a running loop is not supported",
):
client = Client(asynchronous=True, loop=None)


@pytest.mark.slow
Expand Down Expand Up @@ -7059,8 +7020,6 @@ async def test_workers_collection_restriction(c, s, a, b):
assert a.data and not b.data


@pytest.mark.filterwarnings("ignore:There is no current event loop:DeprecationWarning")
@pytest.mark.filterwarnings("ignore:make_current is deprecated:DeprecationWarning")
@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)])
async def test_get_client_functions_spawn_clusters(c, s, a):
# see gh4565
Expand Down
6 changes: 0 additions & 6 deletions distributed/tests/test_client_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import contextlib

import pytest

from distributed import Client, LocalCluster
from distributed.utils import LoopRunner

Expand All @@ -29,16 +27,12 @@ def _check_cluster_and_client_loop(loop):


# Test if Client stops LoopRunner on close.
@pytest.mark.filterwarnings("ignore:There is no current event loop:DeprecationWarning")
@pytest.mark.filterwarnings("ignore:make_current is deprecated:DeprecationWarning")
def test_close_loop_sync_start_new_loop(cleanup):
with _check_loop_runner():
_check_cluster_and_client_loop(loop=None)


# Test if Client stops LoopRunner on close.
@pytest.mark.filterwarnings("ignore:There is no current event loop:DeprecationWarning")
@pytest.mark.filterwarnings("ignore:make_current is deprecated:DeprecationWarning")
def test_close_loop_sync_use_running_loop(cleanup):
with _check_loop_runner():
# Start own loop or use current thread's one.
Expand Down
Loading

0 comments on commit f121d04

Please sign in to comment.