Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CORE: fix timeout handle for triggered post #679

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/components/tl/nccl/tl_nccl.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) Facebook, Inc. and its affiliates. 2021.
*
* See file LICENSE for terms.
Expand Down Expand Up @@ -82,6 +82,7 @@ UCC_CLASS_DECLARE(ucc_tl_nccl_context_t, const ucc_base_context_params_t *,

typedef struct ucc_tl_nccl_team {
ucc_tl_team_t super;
ucc_status_t comm_state;
Sergei-Lebedev marked this conversation as resolved.
Show resolved Hide resolved
ncclUniqueId *unique_id;
void *oob_req;
ncclComm_t nccl_comm;
Expand Down
8 changes: 6 additions & 2 deletions src/components/tl/nccl/tl_nccl_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,13 @@ ucc_status_t ucc_tl_nccl_triggered_post(ucc_ee_h ee, ucc_ev_t *ev,

ucc_status_t ucc_tl_nccl_coll_finalize(ucc_coll_task_t *coll_task)
{
ucc_tl_nccl_task_t *task = ucc_derived_of(coll_task, ucc_tl_nccl_task_t);
ucc_status_t status = UCC_OK ;
ucc_tl_nccl_task_t *task = ucc_derived_of(coll_task, ucc_tl_nccl_task_t);
ucc_tl_nccl_team_t *team = TASK_TEAM(task);
ucc_status_t status = UCC_OK;

if (ucc_unlikely(task->super.super.status != UCC_OK)) {
team->comm_state = task->super.super.status;
}
tl_info(UCC_TASK_LIB(task), "finalizing coll task %p", task);
ucc_tl_nccl_free_task(task);
return status;
Expand Down
13 changes: 10 additions & 3 deletions src/components/tl/nccl/tl_nccl_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ UCC_CLASS_INIT_FUNC(ucc_tl_nccl_team_t, ucc_base_context_t *tl_context,
UCC_CLASS_CALL_SUPER_INIT(ucc_tl_team_t, &ctx->super, params);

size = UCC_TL_TEAM_SIZE(self);
self->unique_id = ucc_malloc(sizeof(ncclUniqueId) * (size + 1),
"tl_nccl_unique_id");
self->comm_state = UCC_OK;
self->unique_id = ucc_malloc(sizeof(ncclUniqueId) * (size + 1),
"tl_nccl_unique_id");
if (!self->unique_id) {
tl_error(ctx->super.super.lib,
"failed to allocate %zd bytes for unique_id array",
Expand Down Expand Up @@ -57,7 +58,13 @@ UCC_CLASS_CLEANUP_FUNC(ucc_tl_nccl_team_t)
{
tl_info(self->super.super.context->lib, "finalizing tl team: %p", self);
if (self->nccl_comm) {
ncclCommDestroy(self->nccl_comm);
if (self->comm_state != UCC_OK) {
/* if communication error was detected ncclCommAbort should be used
since ncclCommDestroy could block */
ncclCommAbort(self->nccl_comm);
} else {
ncclCommDestroy(self->nccl_comm);
}
cudaStreamDestroy(self->stream);
}
}
Expand Down
41 changes: 10 additions & 31 deletions src/core/ucc_coll.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* See file LICENSE for terms.
*/

Expand Down Expand Up @@ -339,23 +339,6 @@ static void ucc_triggered_task_cb(void *task, ucc_status_t st)
ucc_triggered_task_finalize((ucc_coll_task_t*)task);
}

static ucc_status_t ucc_triggered_coll_complete(ucc_coll_task_t *parent_task, //NOLINT
ucc_coll_task_t *task)
{
ucc_trace("triggered collective complete, task %p, seq_num %u",
task, task->seq_num);
if (!(task->flags & UCC_COLL_TASK_FLAG_EXECUTOR)) {
/* need to stop and finalize executor here in case if collective itself
* doesn't need executor and executor was created as part of
* triggered post
*/
ucc_ee_executor_stop(task->executor);
ucc_ee_executor_finalize(task->executor);
task->executor = NULL;
}
return UCC_OK;
}

static ucc_status_t ucc_trigger_complete(ucc_coll_task_t *parent_task,
ucc_coll_task_t *task)
{
Expand All @@ -366,21 +349,16 @@ static ucc_status_t ucc_trigger_complete(ucc_coll_task_t *parent_task,

if (!(task->flags & UCC_COLL_TASK_FLAG_EXECUTOR)) {
task->executor = parent_task->executor;
task->flags |= (UCC_COLL_TASK_FLAG_EXECUTOR_STOP |
Sergei-Lebedev marked this conversation as resolved.
Show resolved Hide resolved
UCC_COLL_TASK_FLAG_EXECUTOR_DESTROY);
}

status = task->post(task);
if (ucc_unlikely(status != UCC_OK)) {
ucc_error("failed to post triggered coll, task %p, seq_num %u, %s",
task, task->seq_num, ucc_status_string(status));
return status;
}

if (task->super.status == UCC_OK) {
return ucc_triggered_coll_complete(task, task);
}
ucc_assert(task->super.status == UCC_INPROGRESS);
// TODO use CB instead of EM
return ucc_event_manager_subscribe(task, UCC_EVENT_COMPLETED, task,
ucc_triggered_coll_complete);
return status;
}

static void ucc_trigger_test(ucc_coll_task_t *task)
Expand All @@ -391,15 +369,16 @@ static void ucc_trigger_test(ucc_coll_task_t *task)
ucc_ee_executor_params_t params;

if (task->ev == NULL) {
if (task->ee->ee_type == UCC_EE_CUDA_STREAM || task->ee->ee_type == UCC_EE_ROCM_STREAM) {
if ((task->ee->ee_type == UCC_EE_CUDA_STREAM) ||
(task->ee->ee_type == UCC_EE_ROCM_STREAM)) {
/* implicit event triggered */
task->ev = (ucc_ev_t *) 0xFFFF; /* dummy event */
task->ev = (ucc_ev_t *) 0xFFFF; /* dummy event */
task->executor = NULL;
} else if (UCC_OK == ucc_ee_get_event_internal(task->ee, &ev,
&task->ee->event_in_queue)) {
ucc_trace("triggered event arrived, ev_task %p", task);
task->ev = ev;
task->ee_task = NULL;
task->ev = ev;
task->executor = NULL;
} else {
task->status = UCC_OK;
return;
Expand Down
8 changes: 4 additions & 4 deletions src/schedule/ucc_schedule.c
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,8 @@ ucc_status_t ucc_coll_task_get_executor(ucc_coll_task_t *task,
return st;
}

static ucc_status_t
ucc_task_error_handler(ucc_coll_task_t *parent_task,
ucc_coll_task_t *task)
static ucc_status_t ucc_task_error_handler(ucc_coll_task_t *parent_task,
ucc_coll_task_t *task)
{
ucc_event_manager_t *em;
ucc_coll_task_t *listener;
Expand Down Expand Up @@ -197,7 +196,8 @@ ucc_status_t ucc_schedule_init(ucc_schedule_t *schedule,
return status;
}

ucc_status_t ucc_schedule_add_task(ucc_schedule_t *schedule, ucc_coll_task_t *task)
ucc_status_t ucc_schedule_add_task(ucc_schedule_t *schedule,
ucc_coll_task_t *task)
{
ucc_status_t status;

Expand Down
25 changes: 18 additions & 7 deletions src/schedule/ucc_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,15 @@ typedef struct ucc_event_manager {
} ucc_event_manager_t;

enum {
UCC_COLL_TASK_FLAG_CB = UCC_BIT(0),
UCC_COLL_TASK_FLAG_CB = UCC_BIT(0),
/* executor is required for collective*/
UCC_COLL_TASK_FLAG_EXECUTOR = UCC_BIT(1),
UCC_COLL_TASK_FLAG_EXECUTOR = UCC_BIT(1),
/* user visible task */
UCC_COLL_TASK_FLAG_TOP_LEVEL = UCC_BIT(2),
UCC_COLL_TASK_FLAG_TOP_LEVEL = UCC_BIT(2),
/* stop executor in task complete*/
UCC_COLL_TASK_FLAG_EXECUTOR_STOP = UCC_BIT(3)
UCC_COLL_TASK_FLAG_EXECUTOR_STOP = UCC_BIT(3),
/* destroy executor in task complete */
UCC_COLL_TASK_FLAG_EXECUTOR_DESTROY = UCC_BIT(4),
};

typedef struct ucc_coll_task {
Expand All @@ -80,7 +82,7 @@ typedef struct ucc_coll_task {
ucc_status_t status;
ucc_list_link_t em_list;
ucc_base_coll_args_t bargs;
ucc_base_team_t *team; //CL/TL team pointer
ucc_base_team_t *team; /* CL/TL team pointer */
ucc_schedule_t *schedule;
uint32_t flags;
ucc_coll_post_fn_t post;
Expand All @@ -91,7 +93,6 @@ typedef struct ucc_coll_task {
ucc_coll_callback_t cb;
ucc_ee_h ee;
ucc_ev_t *ev;
void *ee_task;
ucc_coll_task_t *triggered_task;
ucc_ee_executor_t *executor;
union {
Expand Down Expand Up @@ -184,7 +185,7 @@ static inline ucc_status_t ucc_task_complete(ucc_coll_task_t *task)
if (UCC_ERR_TIMED_OUT == status) {
char coll_str[256];
ucc_coll_str(task, coll_str, sizeof(coll_str));
ucc_warn("timeout %g sec has expired on %s",
ucc_warn("timeout %g sec. has expired on %s",
task->bargs.args.timeout, coll_str);
} else {
ucc_error("failure in task %p, %s", task,
Expand All @@ -200,10 +201,20 @@ static inline ucc_status_t ucc_task_complete(ucc_coll_task_t *task)
}
}

if ((task->executor) && (task->flags & UCC_COLL_TASK_FLAG_EXECUTOR_DESTROY)) {
status = ucc_ee_executor_finalize(task->executor);
if (ucc_unlikely(status != UCC_OK)) {
ucc_error("failed to finalize executor %s",
ucc_status_string(status));
}
task->executor = NULL;
}

task->super.status = status;
if (has_cb) {
cb.cb(cb.data, status);
}

if (has_sched && status == UCC_OK) {
status = ucc_event_manager_notify(task, UCC_EVENT_COMPLETED_SCHEDULE);
}
Expand Down