From 39e24c744a5e621e0b2e6c7f9941cfcd23c32e38 Mon Sep 17 00:00:00 2001 From: Sergey Lebedev Date: Wed, 9 Nov 2022 15:13:08 +0400 Subject: [PATCH] CORE: fix timeout handle for triggered post --- src/components/tl/nccl/tl_nccl.h | 3 +- src/components/tl/nccl/tl_nccl_coll.c | 8 ++++-- src/components/tl/nccl/tl_nccl_team.c | 13 +++++++-- src/core/ucc_coll.c | 41 +++++++-------------------- src/schedule/ucc_schedule.c | 8 +++--- src/schedule/ucc_schedule.h | 25 +++++++++++----- 6 files changed, 50 insertions(+), 48 deletions(-) diff --git a/src/components/tl/nccl/tl_nccl.h b/src/components/tl/nccl/tl_nccl.h index 998cffadde..6f3daee474 100644 --- a/src/components/tl/nccl/tl_nccl.h +++ b/src/components/tl/nccl/tl_nccl.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) Facebook, Inc. and its affiliates. 2021. * * See file LICENSE for terms. @@ -82,6 +82,7 @@ UCC_CLASS_DECLARE(ucc_tl_nccl_context_t, const ucc_base_context_params_t *, typedef struct ucc_tl_nccl_team { ucc_tl_team_t super; + ucc_status_t comm_state; ncclUniqueId *unique_id; void *oob_req; ncclComm_t nccl_comm; diff --git a/src/components/tl/nccl/tl_nccl_coll.c b/src/components/tl/nccl/tl_nccl_coll.c index cd69bbec2f..fbeb366262 100644 --- a/src/components/tl/nccl/tl_nccl_coll.c +++ b/src/components/tl/nccl/tl_nccl_coll.c @@ -149,9 +149,13 @@ ucc_status_t ucc_tl_nccl_triggered_post(ucc_ee_h ee, ucc_ev_t *ev, ucc_status_t ucc_tl_nccl_coll_finalize(ucc_coll_task_t *coll_task) { - ucc_tl_nccl_task_t *task = ucc_derived_of(coll_task, ucc_tl_nccl_task_t); - ucc_status_t status = UCC_OK ; + ucc_tl_nccl_task_t *task = ucc_derived_of(coll_task, ucc_tl_nccl_task_t); + ucc_tl_nccl_team_t *team = TASK_TEAM(task); + ucc_status_t status = UCC_OK; + if (ucc_unlikely(task->super.super.status != UCC_OK)) { + team->comm_state = task->super.super.status; + } tl_info(UCC_TASK_LIB(task), "finalizing coll task %p", task); ucc_tl_nccl_free_task(task); return status; diff --git a/src/components/tl/nccl/tl_nccl_team.c b/src/components/tl/nccl/tl_nccl_team.c index f19f5411c9..8613dbae11 100644 --- a/src/components/tl/nccl/tl_nccl_team.c +++ b/src/components/tl/nccl/tl_nccl_team.c @@ -22,8 +22,9 @@ UCC_CLASS_INIT_FUNC(ucc_tl_nccl_team_t, ucc_base_context_t *tl_context, UCC_CLASS_CALL_SUPER_INIT(ucc_tl_team_t, &ctx->super, params); size = UCC_TL_TEAM_SIZE(self); - self->unique_id = ucc_malloc(sizeof(ncclUniqueId) * (size + 1), - "tl_nccl_unique_id"); + self->comm_state = UCC_OK; + self->unique_id = ucc_malloc(sizeof(ncclUniqueId) * (size + 1), + "tl_nccl_unique_id"); if (!self->unique_id) { tl_error(ctx->super.super.lib, "failed to allocate %zd bytes for unique_id array", @@ -57,7 +58,13 @@ UCC_CLASS_CLEANUP_FUNC(ucc_tl_nccl_team_t) { tl_info(self->super.super.context->lib, "finalizing tl team: %p", self); if (self->nccl_comm) { - ncclCommDestroy(self->nccl_comm); + if (self->comm_state != UCC_OK) { + /* if communication error was detected ncclCommAbort should be used + since ncclCommDestroy could block */ + ncclCommAbort(self->nccl_comm); + } else { + ncclCommDestroy(self->nccl_comm); + } cudaStreamDestroy(self->stream); } } diff --git a/src/core/ucc_coll.c b/src/core/ucc_coll.c index f31d4d22aa..e51fd7536d 100644 --- a/src/core/ucc_coll.c +++ b/src/core/ucc_coll.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * See file LICENSE for terms. */ @@ -339,23 +339,6 @@ static void ucc_triggered_task_cb(void *task, ucc_status_t st) ucc_triggered_task_finalize((ucc_coll_task_t*)task); } -static ucc_status_t ucc_triggered_coll_complete(ucc_coll_task_t *parent_task, //NOLINT - ucc_coll_task_t *task) -{ - ucc_trace("triggered collective complete, task %p, seq_num %u", - task, task->seq_num); - if (!(task->flags & UCC_COLL_TASK_FLAG_EXECUTOR)) { - /* need to stop and finalize executor here in case if collective itself - * doesn't need executor and executor was created as part of - * triggered post - */ - ucc_ee_executor_stop(task->executor); - ucc_ee_executor_finalize(task->executor); - task->executor = NULL; - } - return UCC_OK; -} - static ucc_status_t ucc_trigger_complete(ucc_coll_task_t *parent_task, ucc_coll_task_t *task) { @@ -366,21 +349,16 @@ static ucc_status_t ucc_trigger_complete(ucc_coll_task_t *parent_task, if (!(task->flags & UCC_COLL_TASK_FLAG_EXECUTOR)) { task->executor = parent_task->executor; + task->flags |= (UCC_COLL_TASK_FLAG_EXECUTOR_STOP | + UCC_COLL_TASK_FLAG_EXECUTOR_DESTROY); } + status = task->post(task); if (ucc_unlikely(status != UCC_OK)) { ucc_error("failed to post triggered coll, task %p, seq_num %u, %s", task, task->seq_num, ucc_status_string(status)); - return status; - } - - if (task->super.status == UCC_OK) { - return ucc_triggered_coll_complete(task, task); } - ucc_assert(task->super.status == UCC_INPROGRESS); - // TODO use CB instead of EM - return ucc_event_manager_subscribe(task, UCC_EVENT_COMPLETED, task, - ucc_triggered_coll_complete); + return status; } static void ucc_trigger_test(ucc_coll_task_t *task) @@ -391,15 +369,16 @@ static void ucc_trigger_test(ucc_coll_task_t *task) ucc_ee_executor_params_t params; if (task->ev == NULL) { - if (task->ee->ee_type == UCC_EE_CUDA_STREAM || task->ee->ee_type == UCC_EE_ROCM_STREAM) { + if ((task->ee->ee_type == UCC_EE_CUDA_STREAM) || + (task->ee->ee_type == UCC_EE_ROCM_STREAM)) { /* implicit event triggered */ - task->ev = (ucc_ev_t *) 0xFFFF; /* dummy event */ + task->ev = (ucc_ev_t *) 0xFFFF; /* dummy event */ task->executor = NULL; } else if (UCC_OK == ucc_ee_get_event_internal(task->ee, &ev, &task->ee->event_in_queue)) { ucc_trace("triggered event arrived, ev_task %p", task); - task->ev = ev; - task->ee_task = NULL; + task->ev = ev; + task->executor = NULL; } else { task->status = UCC_OK; return; diff --git a/src/schedule/ucc_schedule.c b/src/schedule/ucc_schedule.c index d9fbc3fb20..121fcd9f9b 100644 --- a/src/schedule/ucc_schedule.c +++ b/src/schedule/ucc_schedule.c @@ -121,9 +121,8 @@ ucc_status_t ucc_coll_task_get_executor(ucc_coll_task_t *task, return st; } -static ucc_status_t -ucc_task_error_handler(ucc_coll_task_t *parent_task, - ucc_coll_task_t *task) +static ucc_status_t ucc_task_error_handler(ucc_coll_task_t *parent_task, + ucc_coll_task_t *task) { ucc_event_manager_t *em; ucc_coll_task_t *listener; @@ -197,7 +196,8 @@ ucc_status_t ucc_schedule_init(ucc_schedule_t *schedule, return status; } -ucc_status_t ucc_schedule_add_task(ucc_schedule_t *schedule, ucc_coll_task_t *task) +ucc_status_t ucc_schedule_add_task(ucc_schedule_t *schedule, + ucc_coll_task_t *task) { ucc_status_t status; diff --git a/src/schedule/ucc_schedule.h b/src/schedule/ucc_schedule.h index 8f999bfe3a..2258249c33 100644 --- a/src/schedule/ucc_schedule.h +++ b/src/schedule/ucc_schedule.h @@ -61,13 +61,15 @@ typedef struct ucc_event_manager { } ucc_event_manager_t; enum { - UCC_COLL_TASK_FLAG_CB = UCC_BIT(0), + UCC_COLL_TASK_FLAG_CB = UCC_BIT(0), /* executor is required for collective*/ - UCC_COLL_TASK_FLAG_EXECUTOR = UCC_BIT(1), + UCC_COLL_TASK_FLAG_EXECUTOR = UCC_BIT(1), /* user visible task */ - UCC_COLL_TASK_FLAG_TOP_LEVEL = UCC_BIT(2), + UCC_COLL_TASK_FLAG_TOP_LEVEL = UCC_BIT(2), /* stop executor in task complete*/ - UCC_COLL_TASK_FLAG_EXECUTOR_STOP = UCC_BIT(3) + UCC_COLL_TASK_FLAG_EXECUTOR_STOP = UCC_BIT(3), + /* destroy executor in task complete */ + UCC_COLL_TASK_FLAG_EXECUTOR_DESTROY = UCC_BIT(4), }; typedef struct ucc_coll_task { @@ -80,7 +82,7 @@ typedef struct ucc_coll_task { ucc_status_t status; ucc_list_link_t em_list; ucc_base_coll_args_t bargs; - ucc_base_team_t *team; //CL/TL team pointer + ucc_base_team_t *team; /* CL/TL team pointer */ ucc_schedule_t *schedule; uint32_t flags; ucc_coll_post_fn_t post; @@ -91,7 +93,6 @@ typedef struct ucc_coll_task { ucc_coll_callback_t cb; ucc_ee_h ee; ucc_ev_t *ev; - void *ee_task; ucc_coll_task_t *triggered_task; ucc_ee_executor_t *executor; union { @@ -184,7 +185,7 @@ static inline ucc_status_t ucc_task_complete(ucc_coll_task_t *task) if (UCC_ERR_TIMED_OUT == status) { char coll_str[256]; ucc_coll_str(task, coll_str, sizeof(coll_str)); - ucc_warn("timeout %g sec has expired on %s", + ucc_warn("timeout %g sec. has expired on %s", task->bargs.args.timeout, coll_str); } else { ucc_error("failure in task %p, %s", task, @@ -200,10 +201,20 @@ static inline ucc_status_t ucc_task_complete(ucc_coll_task_t *task) } } + if ((task->executor) && (task->flags & UCC_COLL_TASK_FLAG_EXECUTOR_DESTROY)) { + status = ucc_ee_executor_finalize(task->executor); + if (ucc_unlikely(status != UCC_OK)) { + ucc_error("failed to finalize executor %s", + ucc_status_string(status)); + } + task->executor = NULL; + } + task->super.status = status; if (has_cb) { cb.cb(cb.data, status); } + if (has_sched && status == UCC_OK) { status = ucc_event_manager_notify(task, UCC_EVENT_COMPLETED_SCHEDULE); }