Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
samnordmann committed Jul 13, 2023
1 parent b8344c6 commit 9838d65
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions src/components/tl/mlx5/alltoall/alltoall_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,17 @@ ucc_tl_mlx5_poll_free_op_slot_start(ucc_coll_task_t *coll_task)
ucc_tl_mlx5_alltoall_t *a2a = team->a2a;
int seq_index = task->alltoall.seq_index;

if (a2a->op_busy[seq_index] && !task->alltoall.started) {
tl_debug(
UCC_TL_TEAM_LIB(team),
"Operation num %d must wait for previous outstanding to complete",
if (!a2a->op_busy[seq_index]) {
tl_debug(UCC_TL_TEAM_LIB(team), "Operation num %d started",
task->alltoall.seq_num);
a2a->op_busy[seq_index] = 1;
return ucc_task_complete(coll_task);
}

tl_debug(
UCC_TL_TEAM_LIB(team),
"Operation num %d must wait for previous outstanding to complete",
task->alltoall.seq_num);
coll_task->status = UCC_INPROGRESS;
coll_task->super.status = UCC_INPROGRESS;
ucc_progress_enqueue(UCC_TL_CORE_CTX(team)->pq, coll_task);
Expand All @@ -42,15 +46,16 @@ void ucc_tl_mlx5_poll_free_op_slot_progress(ucc_coll_task_t *coll_task)
ucc_tl_mlx5_alltoall_t *a2a = team->a2a;
int seq_index = task->alltoall.seq_index;

if (a2a->op_busy[seq_index] && !task->alltoall.started) {
coll_task->status = UCC_INPROGRESS;
if (a2a->op_busy[seq_index]) {
coll_task->status = UCC_INPROGRESS;
coll_task->super.status = UCC_INPROGRESS;
return;
} //wait for slot to be open
}
a2a->op_busy[seq_index] = 1;
task->alltoall.started = 1;
coll_task->status = UCC_OK;
coll_task->super.status = UCC_OK;
tl_debug(UCC_TL_TEAM_LIB(team), "Operation num %d started",
task->alltoall.seq_num);
task->alltoall.seq_num);
}

static ucc_status_t ucc_tl_mlx5_poll_cq(struct ibv_cq *cq, ucc_base_lib_t *lib)
Expand Down Expand Up @@ -195,6 +200,7 @@ static ucc_status_t ucc_tl_mlx5_reg_fanin_start(ucc_coll_task_t *coll_task)
return ucc_task_complete(coll_task);
}

ucc_assert(team->a2a->node.sbgp->group_rank == team->a2a->node.asr_rank);
ucc_progress_enqueue(UCC_TL_CORE_CTX(team)->pq, coll_task);
return UCC_OK;
}
Expand Down Expand Up @@ -328,7 +334,7 @@ static ucc_status_t ucc_tl_mlx5_asr_barrier_start(ucc_coll_task_t *coll_task)
for (i = 0; i < a2a->net.net_size; i++) {
task->alltoall.op->blocks_sent[i] = 0;
if (i == a2a->net.sbgp->group_rank) {
tl_mlx5_barrier_flag_set(task, i);
tl_mlx5_barrier_flag_set(task, i); // per ops ?
continue;
}

Expand Down Expand Up @@ -832,7 +838,7 @@ UCC_TL_MLX5_PROFILE_FUNC(ucc_status_t, ucc_tl_mlx5_alltoall_init,
task->alltoall.msg_size = msg_size;

tl_trace(UCC_TL_TEAM_LIB(tl_team), "Seq num is %d", task->alltoall.seq_num);
a2a->sequence_number += 1;
a2a->sequence_number++;

block_size = a2a->requested_block_size ? a2a->requested_block_size
: get_block_size(task);
Expand Down

0 comments on commit 9838d65

Please sign in to comment.