Skip to content

Commit

Permalink
Cancel inflight messages on endpoint error callback (#792)
Browse files Browse the repository at this point in the history
* Cancel inflight messages on endpoint error callback

* Allow ep.recv to raise UCXCanceled

* Add test to ensure ep.recv raises UCXCanceled

* Prevent canceling messages that are already in transit

* Update _inflight_msgs_to_cancel instead of overwriting

* Rename UCXWorker function to query_total_inflight_messages_to_cancel

* Remove duplicate parametrize decorator and commented code
  • Loading branch information
pentschev authored Oct 28, 2021
1 parent addee5e commit d701a3f
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 37 deletions.
68 changes: 50 additions & 18 deletions tests/test_from_worker_address_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ async def run():
# wait for client to connect, then shutdown worker.
q1.put(address)

ep_connected = q2.get()
assert ep_connected == "connected"
ep_ready = q2.get()
assert ep_ready == "ready"

ucp.reset()

Expand Down Expand Up @@ -61,20 +61,45 @@ async def run():
if time.monotonic() - start >= 1.0:
return
else:
# Create endpoint to remote worker and inform it that connection was
# established, wait for it to shutdown and confirm, then attempt to
# send message.
# Create endpoint to remote worker, and:
#
# 1. For timeout_send:
# - inform remote worker that local endpoint is ready for remote
# shutdown;
# - wait for remote worker to shutdown and confirm;
# - attempt to send message.
#
# 2. For timeout_recv:
# - schedule ep.recv;
# - inform remote worker that local endpoint is ready for remote
# shutdown;
# - wait for it to shutdown and confirm
# - wait for recv message.
ep = await ucp.create_endpoint_from_worker_address(remote_address)

q2.put("connected")
if error_type == "timeout_send":
q2.put("ready")

remote_disconnected = q1.get()
assert remote_disconnected == "disconnected"
remote_disconnected = q1.get()
assert remote_disconnected == "disconnected"

with pytest.raises(ucp.exceptions.UCXError, match="Endpoint timeout"):
await asyncio.wait_for(
ep.send(np.zeros(10), tag=0, force_tag=True), timeout=1.0
)
with pytest.raises(ucp.exceptions.UCXError, match="Endpoint timeout"):
await asyncio.wait_for(
ep.send(np.zeros(10), tag=0, force_tag=True), timeout=1.0
)
else:
with pytest.raises(ucp.exceptions.UCXCanceled):
msg = np.empty(10)
task = asyncio.wait_for(
ep.recv(msg, tag=0, force_tag=True), timeout=3.0
)

q2.put("ready")

remote_disconnected = q1.get()
assert remote_disconnected == "disconnected"

await task

asyncio.get_event_loop().run_until_complete(run())

Expand All @@ -83,11 +108,12 @@ async def run():
ucp.get_ucx_version() < (1, 11, 0),
reason="Endpoint error handling is unreliable in UCX releases prior to 1.11.0",
)
@pytest.mark.parametrize("error_type", ["unreachable", "timeout"])
@pytest.mark.parametrize("error_type", ["unreachable", "timeout_send", "timeout_recv"])
def test_from_worker_address_error(error_type):
os.environ["UCX_WARN_UNUSED_ENV_VARS"] = "n"
# Set low UD timeout to ensure it raises as expected
os.environ["UCX_UD_TIMEOUT"] = "0.1s"
# Set low timeouts to ensure tests quickly raise as expected
os.environ["UCX_KEEPALIVE_INTERVAL"] = "100ms"
os.environ["UCX_UD_TIMEOUT"] = "100ms"

q1 = mp.Queue()
q2 = mp.Queue()
Expand All @@ -108,6 +134,12 @@ def test_from_worker_address_error(error_type):
assert not server.exitcode

if ucp.get_ucx_version() < (1, 12, 0) and client.exitcode == 1:
pytest.xfail("Requires https://github.com/openucx/ucx/pull/7527 with rc/ud.")
else:
assert not client.exitcode
if error_type == "timeout_send":
pytest.xfail(
"Requires https://github.com/openucx/ucx/pull/7527 with rc/ud."
)
elif error_type == "timeout_recv":
pytest.xfail(
"Requires https://github.com/openucx/ucx/pull/7531 with rc/ud."
)
assert not client.exitcode
33 changes: 22 additions & 11 deletions ucp/_libs/ucx_endpoint.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,19 @@ from ..exceptions import UCXCanceled, UCXConnectionReset, UCXError
logger = logging.getLogger("ucx")


cdef void _cancel_inflight_msgs(UCXWorker worker, set inflight_msgs):
cdef UCXRequest req
cdef dict req_info
cdef str name
for req in list(inflight_msgs):
assert not req.closed()
req_info = <dict>req._handle.info
name = req_info["name"]
logger.debug("Future cancelling: %s" % name)
# Notice, `request_cancel()` evoke the send/recv callback functions
worker.request_cancel(req)


class UCXEndpointCloseCallback():
def __init__(self):
self._cb_func = None
Expand All @@ -34,7 +47,9 @@ class UCXEndpointCloseCallback():

cdef void _err_cb(void *arg, ucp_ep_h ep, ucs_status_t status) with gil:
cdef UCXEndpoint ucx_ep = <UCXEndpoint> arg
assert ucx_ep.worker.initialized
cdef UCXWorker ucx_worker = ucx_ep.worker
cdef set inflight_msgs = ucx_ep._inflight_msgs
assert ucx_worker.initialized

cdef ucs_status_t *ep_status = <ucs_status_t *> <uintptr_t>ucx_ep._status
ep_status[0] = status
Expand All @@ -48,6 +63,11 @@ cdef void _err_cb(void *arg, ucp_ep_h ep, ucs_status_t status) with gil:
ucx_ep._endpoint_close_callback.run()
logger.debug(msg)

# Schedule inflight messages to be canceled after all UCP progress
# is complete. This may happen if the user called ep.recv() but
# the remote worker errored before sending the message.
ucx_worker._inflight_msgs_to_cancel.update(inflight_msgs)


cdef (ucp_err_handler_cb_t, uintptr_t) _get_error_callback(
str tls, bint endpoint_error_handling
Expand Down Expand Up @@ -92,16 +112,7 @@ def _ucx_endpoint_finalizer(
free(<void *>status_handle_as_int)

# Cancel all inflight messages
cdef UCXRequest req
cdef dict req_info
cdef str name
for req in list(inflight_msgs):
assert not req.closed()
req_info = <dict>req._handle.info
name = req_info["name"]
logger.debug("Future cancelling: %s" % name)
# Notice, `request_cancel()` evoke the send/recv callback functions
worker.request_cancel(req)
_cancel_inflight_msgs(worker, inflight_msgs)

# Cancel waiting `am_recv` calls
cdef dict recv_wait
Expand Down
21 changes: 20 additions & 1 deletion ucp/_libs/ucx_worker.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ logger = logging.getLogger("ucx")


cdef _drain_worker_tag_recv(ucp_worker_h handle):
cdef ucp_tag_message_h
cdef ucp_tag_message_h message
cdef ucp_tag_recv_info_t info
cdef ucs_status_ptr_t status
cdef void *buf
Expand Down Expand Up @@ -96,6 +96,7 @@ cdef class UCXWorker(UCXObject):
ucp_worker_h _handle
UCXContext _context
set _inflight_msgs
set _inflight_msgs_to_cancel
IF CY_UCP_AM_SUPPORTED:
dict _am_recv_pool
dict _am_recv_wait
Expand All @@ -119,6 +120,7 @@ cdef class UCXWorker(UCXObject):
status = ucp_worker_create(context._handle, &worker_params, &self._handle)
assert_ucs_status(status)
self._inflight_msgs = set()
self._inflight_msgs_to_cancel = set()

IF CY_UCP_AM_SUPPORTED:
cdef int AM_MSG_ID = 0
Expand Down Expand Up @@ -274,3 +276,20 @@ cdef class UCXWorker(UCXObject):
cdef FILE *text_fd = create_text_fd()
ucp_worker_print_info(self._handle, text_fd)
return decode_text_fd(text_fd)

def query_total_inflight_messages_to_cancel(self):
"""Query the total of inflight messages scheduled to cancel
While there are messages scheduled for canceling, we need to progress
the worker. Therefore, this can be used to query if there are still any
such messages and progress while the result is larger than 0.
Returns
-------
total: The total number of inflight messages scheduled to cancel.
"""
len_inflight_msgs_to_cancel = len(self._inflight_msgs_to_cancel)
if len_inflight_msgs_to_cancel > 0:
_cancel_inflight_msgs(self, self._inflight_msgs_to_cancel)
self._inflight_msgs_to_cancel = set()
return len_inflight_msgs_to_cancel
7 changes: 7 additions & 0 deletions ucp/continuous_ucx_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@ async def _arm_worker(self):
if worker is None or not worker.initialized:
return
worker.progress()

# Cancel inflight messages that couldn't be completed. This may
# happen if the user called ep.recv() but the remote worker
# errored before sending the message.
if worker.query_total_inflight_messages_to_cancel() > 0:
worker.progress()

del worker

# This IO task returns when all non-IO tasks are finished.
Expand Down
8 changes: 1 addition & 7 deletions ucp/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,13 +722,7 @@ async def recv(self, buffer, tag=None, force_tag=False):
logger.debug(log)
self._recv_count += 1

try:
ret = await comm.tag_recv(self._ep, buffer, nbytes, tag, name=log)
except UCXCanceled as e:
# If self._ep has already been closed and destroyed, we reraise the
# UCXCanceled exception.
if self._ep is None:
raise e
ret = await comm.tag_recv(self._ep, buffer, nbytes, tag, name=log)

self._finished_recv_count += 1
if (
Expand Down

0 comments on commit d701a3f

Please sign in to comment.