Skip to content

Commit

Permalink
CORE: fix timeout handle for triggered post
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergei-Lebedev committed Nov 18, 2022
1 parent 387c264 commit 39e24c7
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 48 deletions.
3 changes: 2 additions & 1 deletion src/components/tl/nccl/tl_nccl.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) Facebook, Inc. and its affiliates. 2021.
*
* See file LICENSE for terms.
Expand Down Expand Up @@ -82,6 +82,7 @@ UCC_CLASS_DECLARE(ucc_tl_nccl_context_t, const ucc_base_context_params_t *,

typedef struct ucc_tl_nccl_team {
ucc_tl_team_t super;
ucc_status_t comm_state;
ncclUniqueId *unique_id;
void *oob_req;
ncclComm_t nccl_comm;
Expand Down
8 changes: 6 additions & 2 deletions src/components/tl/nccl/tl_nccl_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,13 @@ ucc_status_t ucc_tl_nccl_triggered_post(ucc_ee_h ee, ucc_ev_t *ev,

ucc_status_t ucc_tl_nccl_coll_finalize(ucc_coll_task_t *coll_task)
{
ucc_tl_nccl_task_t *task = ucc_derived_of(coll_task, ucc_tl_nccl_task_t);
ucc_status_t status = UCC_OK ;
ucc_tl_nccl_task_t *task = ucc_derived_of(coll_task, ucc_tl_nccl_task_t);
ucc_tl_nccl_team_t *team = TASK_TEAM(task);
ucc_status_t status = UCC_OK;

if (ucc_unlikely(task->super.super.status != UCC_OK)) {
team->comm_state = task->super.super.status;
}
tl_info(UCC_TASK_LIB(task), "finalizing coll task %p", task);
ucc_tl_nccl_free_task(task);
return status;
Expand Down
13 changes: 10 additions & 3 deletions src/components/tl/nccl/tl_nccl_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ UCC_CLASS_INIT_FUNC(ucc_tl_nccl_team_t, ucc_base_context_t *tl_context,
UCC_CLASS_CALL_SUPER_INIT(ucc_tl_team_t, &ctx->super, params);

size = UCC_TL_TEAM_SIZE(self);
self->unique_id = ucc_malloc(sizeof(ncclUniqueId) * (size + 1),
"tl_nccl_unique_id");
self->comm_state = UCC_OK;
self->unique_id = ucc_malloc(sizeof(ncclUniqueId) * (size + 1),
"tl_nccl_unique_id");
if (!self->unique_id) {
tl_error(ctx->super.super.lib,
"failed to allocate %zd bytes for unique_id array",
Expand Down Expand Up @@ -57,7 +58,13 @@ UCC_CLASS_CLEANUP_FUNC(ucc_tl_nccl_team_t)
{
tl_info(self->super.super.context->lib, "finalizing tl team: %p", self);
if (self->nccl_comm) {
ncclCommDestroy(self->nccl_comm);
if (self->comm_state != UCC_OK) {
/* if communication error was detected ncclCommAbort should be used
since ncclCommDestroy could block */
ncclCommAbort(self->nccl_comm);
} else {
ncclCommDestroy(self->nccl_comm);
}
cudaStreamDestroy(self->stream);
}
}
Expand Down
41 changes: 10 additions & 31 deletions src/core/ucc_coll.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* See file LICENSE for terms.
*/

Expand Down Expand Up @@ -339,23 +339,6 @@ static void ucc_triggered_task_cb(void *task, ucc_status_t st)
ucc_triggered_task_finalize((ucc_coll_task_t*)task);
}

/**
 * Completion handler for a triggered collective.
 *
 * If the collective did not request an executor itself
 * (UCC_COLL_TASK_FLAG_EXECUTOR is not set in task->flags), the executor
 * attached to the task is stopped, finalized and detached from the task.
 * Tasks that do carry the flag are left untouched.
 *
 * @param parent_task  not used by this handler
 * @param task         triggered collective task that has completed
 *
 * @return UCC_OK in all cases
 */
static ucc_status_t ucc_triggered_coll_complete(ucc_coll_task_t *parent_task, //NOLINT
                                                ucc_coll_task_t *task)
{
    ucc_trace("triggered collective complete, task %p, seq_num %u",
              task, task->seq_num);

    if (task->flags & UCC_COLL_TASK_FLAG_EXECUTOR) {
        /* the collective requires the executor itself: nothing to release */
        return UCC_OK;
    }

    /* The executor was created as part of the triggered post only, so it
     * has to be stopped and finalized here; stop precedes finalize. */
    ucc_ee_executor_stop(task->executor);
    ucc_ee_executor_finalize(task->executor);
    task->executor = NULL;

    return UCC_OK;
}

static ucc_status_t ucc_trigger_complete(ucc_coll_task_t *parent_task,
ucc_coll_task_t *task)
{
Expand All @@ -366,21 +349,16 @@ static ucc_status_t ucc_trigger_complete(ucc_coll_task_t *parent_task,

if (!(task->flags & UCC_COLL_TASK_FLAG_EXECUTOR)) {
task->executor = parent_task->executor;
task->flags |= (UCC_COLL_TASK_FLAG_EXECUTOR_STOP |
UCC_COLL_TASK_FLAG_EXECUTOR_DESTROY);
}

status = task->post(task);
if (ucc_unlikely(status != UCC_OK)) {
ucc_error("failed to post triggered coll, task %p, seq_num %u, %s",
task, task->seq_num, ucc_status_string(status));
return status;
}

if (task->super.status == UCC_OK) {
return ucc_triggered_coll_complete(task, task);
}
ucc_assert(task->super.status == UCC_INPROGRESS);
// TODO use CB instead of EM
return ucc_event_manager_subscribe(task, UCC_EVENT_COMPLETED, task,
ucc_triggered_coll_complete);
return status;
}

static void ucc_trigger_test(ucc_coll_task_t *task)
Expand All @@ -391,15 +369,16 @@ static void ucc_trigger_test(ucc_coll_task_t *task)
ucc_ee_executor_params_t params;

if (task->ev == NULL) {
if (task->ee->ee_type == UCC_EE_CUDA_STREAM || task->ee->ee_type == UCC_EE_ROCM_STREAM) {
if ((task->ee->ee_type == UCC_EE_CUDA_STREAM) ||
(task->ee->ee_type == UCC_EE_ROCM_STREAM)) {
/* implicit event triggered */
task->ev = (ucc_ev_t *) 0xFFFF; /* dummy event */
task->ev = (ucc_ev_t *) 0xFFFF; /* dummy event */
task->executor = NULL;
} else if (UCC_OK == ucc_ee_get_event_internal(task->ee, &ev,
&task->ee->event_in_queue)) {
ucc_trace("triggered event arrived, ev_task %p", task);
task->ev = ev;
task->ee_task = NULL;
task->ev = ev;
task->executor = NULL;
} else {
task->status = UCC_OK;
return;
Expand Down
8 changes: 4 additions & 4 deletions src/schedule/ucc_schedule.c
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,8 @@ ucc_status_t ucc_coll_task_get_executor(ucc_coll_task_t *task,
return st;
}

static ucc_status_t
ucc_task_error_handler(ucc_coll_task_t *parent_task,
ucc_coll_task_t *task)
static ucc_status_t ucc_task_error_handler(ucc_coll_task_t *parent_task,
ucc_coll_task_t *task)
{
ucc_event_manager_t *em;
ucc_coll_task_t *listener;
Expand Down Expand Up @@ -197,7 +196,8 @@ ucc_status_t ucc_schedule_init(ucc_schedule_t *schedule,
return status;
}

ucc_status_t ucc_schedule_add_task(ucc_schedule_t *schedule, ucc_coll_task_t *task)
ucc_status_t ucc_schedule_add_task(ucc_schedule_t *schedule,
ucc_coll_task_t *task)
{
ucc_status_t status;

Expand Down
25 changes: 18 additions & 7 deletions src/schedule/ucc_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,15 @@ typedef struct ucc_event_manager {
} ucc_event_manager_t;

enum {
UCC_COLL_TASK_FLAG_CB = UCC_BIT(0),
UCC_COLL_TASK_FLAG_CB = UCC_BIT(0),
/* executor is required for collective*/
UCC_COLL_TASK_FLAG_EXECUTOR = UCC_BIT(1),
UCC_COLL_TASK_FLAG_EXECUTOR = UCC_BIT(1),
/* user visible task */
UCC_COLL_TASK_FLAG_TOP_LEVEL = UCC_BIT(2),
UCC_COLL_TASK_FLAG_TOP_LEVEL = UCC_BIT(2),
/* stop executor in task complete*/
UCC_COLL_TASK_FLAG_EXECUTOR_STOP = UCC_BIT(3)
UCC_COLL_TASK_FLAG_EXECUTOR_STOP = UCC_BIT(3),
/* destroy executor in task complete */
UCC_COLL_TASK_FLAG_EXECUTOR_DESTROY = UCC_BIT(4),
};

typedef struct ucc_coll_task {
Expand All @@ -80,7 +82,7 @@ typedef struct ucc_coll_task {
ucc_status_t status;
ucc_list_link_t em_list;
ucc_base_coll_args_t bargs;
ucc_base_team_t *team; //CL/TL team pointer
ucc_base_team_t *team; /* CL/TL team pointer */
ucc_schedule_t *schedule;
uint32_t flags;
ucc_coll_post_fn_t post;
Expand All @@ -91,7 +93,6 @@ typedef struct ucc_coll_task {
ucc_coll_callback_t cb;
ucc_ee_h ee;
ucc_ev_t *ev;
void *ee_task;
ucc_coll_task_t *triggered_task;
ucc_ee_executor_t *executor;
union {
Expand Down Expand Up @@ -184,7 +185,7 @@ static inline ucc_status_t ucc_task_complete(ucc_coll_task_t *task)
if (UCC_ERR_TIMED_OUT == status) {
char coll_str[256];
ucc_coll_str(task, coll_str, sizeof(coll_str));
ucc_warn("timeout %g sec has expired on %s",
ucc_warn("timeout %g sec. has expired on %s",
task->bargs.args.timeout, coll_str);
} else {
ucc_error("failure in task %p, %s", task,
Expand All @@ -200,10 +201,20 @@ static inline ucc_status_t ucc_task_complete(ucc_coll_task_t *task)
}
}

if ((task->executor) && (task->flags & UCC_COLL_TASK_FLAG_EXECUTOR_DESTROY)) {
status = ucc_ee_executor_finalize(task->executor);
if (ucc_unlikely(status != UCC_OK)) {
ucc_error("failed to finalize executor %s",
ucc_status_string(status));
}
task->executor = NULL;
}

task->super.status = status;
if (has_cb) {
cb.cb(cb.data, status);
}

if (has_sched && status == UCC_OK) {
status = ucc_event_manager_notify(task, UCC_EVENT_COMPLETED_SCHEDULE);
}
Expand Down

0 comments on commit 39e24c7

Please sign in to comment.