Skip to content

Commit

Permalink
if there's an exception in the Client async context manager body then…
Browse files Browse the repository at this point in the history
… close fast (#6920)

Co-authored-by: fjetter <fjetter@users.noreply.github.com>
  • Loading branch information
graingert and fjetter authored Nov 2, 2022
1 parent 5a14053 commit c137ac0
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 2 deletions.
17 changes: 15 additions & 2 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1438,7 +1438,12 @@ async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_value, traceback):
await self._close()
await self._close(
# if we're handling an exception, we assume that it's more
# important to deliver that exception than shutdown gracefully.
fast=exc_type
is not None
)

def __exit__(self, exc_type, exc_value, traceback):
self.close()
Expand Down Expand Up @@ -1568,7 +1573,15 @@ def _handle_error(self, exception=None):
logger.exception(exception)

async def _close(self, fast=False):
"""Send close signal and wait until scheduler completes"""
"""
Send close signal and wait until scheduler completes
If fast is True, the client will close forcefully, by cancelling tasks
the background _handle_report_task.
"""
# TODO: aclose more forcefully by aborting the RPC and cancelling all
# background tasks.
# see https://trio.readthedocs.io/en/stable/reference-io.html#trio.aclose_forcefully
if self.status == "closed":
return

Expand Down
19 changes: 19 additions & 0 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from threading import Semaphore
from time import sleep
from typing import Any
from unittest import mock

import psutil
import pytest
Expand Down Expand Up @@ -7775,6 +7776,24 @@ def __init__(self, *args, **kwargs):
]


@gen_cluster(client=False, nthreads=[])
async def test_fast_close_on_aexit_failure(s):
class MyException(Exception):
pass

c = Client(s.address, asynchronous=True)
with mock.patch.object(c, "_close", wraps=c._close) as _close_proxy:
with pytest.raises(MyException):
async with c:
start = time()
raise MyException
stop = time()

assert _close_proxy.mock_calls == [mock.call(fast=True)]
assert c.status == "closed"
assert (stop - start) < 2


@gen_cluster(client=True, nthreads=[])
async def test_wait_for_workers_no_default(c, s):
with pytest.warns(
Expand Down

0 comments on commit c137ac0

Please sign in to comment.