remove deprecated code calls to IOLoop.make_current() #7240

Merged · 6 commits · Nov 9, 2022
Changes from 2 commits
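
For context: Tornado 6.2 deprecates IOLoop.make_current() and IOLoop.clear_current(), so the changes below drop the explicit loop management in WorkerProcess._run (IOLoop.clear_instance(), loop.make_current(), loop.run_sync(...)) in favour of asyncio.run(), and move the make_current()-based pristine_loop helper out of distributed/utils_test.py into the one test module that still needs it. A minimal sketch of the pattern change (illustrative only, not the exact code from this diff):

import asyncio

async def run() -> None:
    # start the worker and wait until it is told to stop
    ...

# Old, now-deprecated pattern removed by this PR:
#   loop = IOLoop()
#   loop.make_current()
#   loop.run_sync(run)
#
# New pattern: asyncio.run() creates, runs and closes the loop itself.
asyncio.run(run())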
227 changes: 112 additions & 115 deletions distributed/nanny.py
@@ -1,8 +1,11 @@
from __future__ import annotations

import asyncio
import contextlib
import errno
import functools
import logging
import multiprocessing
import os
import shutil
import tempfile
@@ -11,14 +14,13 @@
import warnings
import weakref
from collections.abc import Collection
from contextlib import suppress
from inspect import isawaitable
from queue import Empty
from time import sleep as sync_sleep
from typing import TYPE_CHECKING, ClassVar, Literal
from typing import TYPE_CHECKING, Callable, ClassVar, Literal

import tornado.util
from toolz import merge
from tornado import gen
from tornado.ioloop import IOLoop

import dask
@@ -45,7 +47,6 @@
from distributed.protocol import pickle
from distributed.security import Security
from distributed.utils import (
TimeoutError,
get_ip,
get_mp_context,
json_load_robust,
@@ -303,14 +304,15 @@ async def _unregister(self, timeout=10):
if worker_address is None:
return

allowed_errors = (TimeoutError, CommClosedError, EnvironmentError, RPCClosed)
with suppress(allowed_errors):
try:
await asyncio.wait_for(
self.scheduler.unregister(
address=self.worker_address, stimulus_id=f"nanny-close-{time()}"
),
timeout,
)
except (asyncio.TimeoutError, CommClosedError, OSError, RPCClosed):
pass

@property
def worker_address(self):
@@ -425,7 +427,7 @@ async def instantiate(self) -> Status:
result = await asyncio.wait_for(
self.process.start(), self.death_timeout
)
except TimeoutError:
except asyncio.TimeoutError:
logger.error(
"Timed out connecting Nanny '%s' to scheduler '%s'",
self,
@@ -496,7 +498,7 @@ async def _():

try:
await asyncio.wait_for(_(), timeout)
except TimeoutError:
except asyncio.TimeoutError:
logger.error(
f"Restart timed out after {timeout}s; returning before finished"
)
@@ -679,18 +681,18 @@ async def start(self) -> Status:
uid = uuid.uuid4().hex

self.process = AsyncProcess(
target=self._run,
name="Dask Worker process (from Nanny)",
kwargs=dict(
worker_kwargs=self.worker_kwargs,
target=functools.partial(
self._run,
silence_logs=self.silence_logs,
init_result_q=self.init_result_q,
child_stop_q=self.child_stop_q,
uid=uid,
Worker=self.Worker,
worker_factory=functools.partial(self.Worker, **self.worker_kwargs),
env=self.env,
config=self.config,
),
name="Dask Worker process (from Nanny)",
kwargs=dict(),
)
self.process.daemon = dask.config.get("distributed.worker.daemon", default=True)
self.process.set_exit_callback(self._on_exit)
@@ -860,86 +862,66 @@ async def _wait_until_connected(self, uid):
@classmethod
def _run(
cls,
worker_kwargs,
silence_logs,
init_result_q,
child_stop_q,
uid,
env,
config,
Worker,
): # pragma: no cover
try:
os.environ.update(env)
dask.config.refresh()
dask.config.set(config)

from dask.multiprocessing import default_initializer

default_initializer()

if silence_logs:
logger.setLevel(silence_logs)

IOLoop.clear_instance()
loop = IOLoop()
loop.make_current()
worker = Worker(**worker_kwargs)

async def do_stop(
timeout=5, executor_wait=True, reason="workerprocess-stop"
):
try:
await worker.close(
nanny=False,
executor_wait=executor_wait,
timeout=timeout,
reason=reason,
)
finally:
loop.stop()

def watch_stop_q():
"""
Wait for an incoming stop message and then stop the
worker cleanly.
"""
try:
msg = child_stop_q.get()
except (TypeError, OSError, EOFError):
logger.error("Worker process died unexpectedly")
msg = {"op": "stop"}
finally:
child_stop_q.close()
assert msg["op"] == "stop", msg
del msg["op"]
loop.add_callback(do_stop, **msg)

thread = threading.Thread(
target=watch_stop_q, name="Nanny stop queue watch"
silence_logs: bool,
init_result_q: multiprocessing.Queue,
child_stop_q: multiprocessing.Queue,
uid: str,
env: dict,
config: dict,
worker_factory: Callable[[], Worker],
) -> None: # pragma: no cover
async def do_stop(
*,
worker: Worker,
timeout: float = 5,
executor_wait: bool = True,
reason: str = "workerprocess-stop",
) -> None:
await worker.close(
nanny=False,
executor_wait=executor_wait,
timeout=timeout,
reason=reason,
)
thread.daemon = True
thread.start()

async def run():
"""
Try to start worker and inform parent of outcome.
"""
try:
await worker
except Exception as e:
logger.exception("Failed to start worker")
init_result_q.put({"uid": uid, "exception": e})
init_result_q.close()
# If we hit an exception here we need to wait for a least
# one interval for the outside to pick up this message.
# Otherwise we arrive in a race condition where the process
# cleanup wipes the queue before the exception can be
# properly handled. See also
# WorkerProcess._wait_until_connected (the 2 is for good
# measure)
sync_sleep(cls._init_msg_interval * 2)
else:
def watch_stop_q(loop: IOLoop, worker: Worker) -> None:
"""
Wait for an incoming stop message and then stop the
worker cleanly.
"""
try:
msg = child_stop_q.get()
except (TypeError, OSError, EOFError):
logger.error("Worker process died unexpectedly")
msg = {"op": "stop"}
finally:
child_stop_q.close()
assert msg["op"] == "stop", msg
del msg["op"]
loop.add_callback(do_stop, worker=worker, **msg)

async def run() -> None:
"""
Try to start worker and inform parent of outcome.
"""
failure_type: str | None = "initialize"
try:
worker = worker_factory()
failure_type = "start"
thread = threading.Thread(
target=functools.partial(
watch_stop_q,
worker=worker,
loop=IOLoop.current(),
),
name="Nanny stop queue watch",
daemon=True,
)
thread.start()
stack.callback(thread.join, timeout=2)
async with worker:
failure_type = None

try:
assert worker.address
except ValueError:
Member:
This is outside of the changes made in this PR, but do you have an idea why the ValueError would pop up here and why it's okay to ignore it?

Member Author:
It's this ValueError:
    raise ValueError("cannot get address of non-running Server")
Not sure it's safe to ignore, tbh.

Member:
@fjetter: Do you happen to know whether we should keep ignoring this error?
@@ -955,34 +937,49 @@ async def run():
init_result_q.close()
await worker.finished()
logger.info("Worker closed")

except Exception as e:
logger.exception("Failed to initialize Worker")
init_result_q.put({"uid": uid, "exception": e})
init_result_q.close()
# If we hit an exception here we need to wait for a least one
# interval for the outside to pick up this message. Otherwise we
# arrive in a race condition where the process cleanup wipes the
# queue before the exception can be properly handled. See also
# WorkerProcess._wait_until_connected (the 2 is for good measure)
sync_sleep(cls._init_msg_interval * 2)
else:
try:
loop.run_sync(run)
except (TimeoutError, gen.TimeoutError):
# Loop was stopped before wait_until_closed() returned, ignore
except (tornado.util.TimeoutError, asyncio.TimeoutError):
pass
except KeyboardInterrupt:
# At this point the loop is not running thus we have to run
# do_stop() explicitly.
loop.run_sync(do_stop)
finally:
with suppress(ValueError):
except Exception as e:
if failure_type is not None:
logger.exception(f"Failed to {failure_type} worker")
init_result_q.put({"uid": uid, "exception": e})
init_result_q.close()
# If we hit an exception here we need to wait for a least
# one interval for the outside to pick up this message.
# Otherwise we arrive in a race condition where the process
# cleanup wipes the queue before the exception can be
# properly handled. See also
# WorkerProcess._wait_until_connected (the 2 is for good
# measure)
sync_sleep(cls._init_msg_interval * 3)

with contextlib.ExitStack() as stack:

@stack.callback
def close_stop_q() -> None:
try:
child_stop_q.put({"op": "stop"}) # usually redundant
with suppress(ValueError):
except ValueError:
pass

try:
child_stop_q.close() # usually redundant
except ValueError:
pass
child_stop_q.join_thread()
thread.join(timeout=2)

os.environ.update(env)
dask.config.refresh()
dask.config.set(config)

from dask.multiprocessing import default_initializer

default_initializer()

if silence_logs:
logger.setLevel(silence_logs)

asyncio.run(run())


def _get_env_variables(config_key: str) -> dict[str, str]:
23 changes: 20 additions & 3 deletions distributed/tests/test_client.py
@@ -106,7 +106,6 @@
map_varying,
nodebug,
popen,
pristine_loop,
randominc,
save_sys_modules,
slowadd,
@@ -2207,8 +2206,26 @@ async def test_multi_client(s, a, b):
await asyncio.sleep(0.01)


@contextmanager
def _pristine_loop():
IOLoop.clear_instance()
IOLoop.clear_current()
loop = IOLoop()
loop.make_current()
assert IOLoop.current() is loop
try:
yield loop
finally:
try:
loop.close(all_fds=True)
except (KeyError, ValueError):
pass
IOLoop.clear_instance()
IOLoop.clear_current()


def long_running_client_connection(address):
with pristine_loop():
with _pristine_loop():
c = Client(address)
x = c.submit(lambda x: x + 1, 10)
x.result()
@@ -5602,7 +5619,7 @@ async def close():
async with client:
pass

with pristine_loop() as loop:
with _pristine_loop() as loop:
with pytest.warns(
DeprecationWarning,
match=r"Constructing LoopRunner\(loop=loop\) without a running loop is deprecated",
18 changes: 0 additions & 18 deletions distributed/utils_test.py
@@ -160,24 +160,6 @@ async def run():
return


@contextmanager
def pristine_loop():
IOLoop.clear_instance()
IOLoop.clear_current()
loop = IOLoop()
loop.make_current()
assert IOLoop.current() is loop
try:
yield loop
finally:
try:
loop.close(all_fds=True)
except (KeyError, ValueError):
pass
IOLoop.clear_instance()
IOLoop.clear_current()


original_config = copy.deepcopy(dask.config.config)

