Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TL/MLX5: revise team and ctx init #815

Merged
merged 14 commits into from
Aug 22, 2023
207 changes: 85 additions & 122 deletions src/components/tl/mlx5/alltoall/alltoall.c
Original file line number Diff line number Diff line change
Expand Up @@ -65,20 +65,17 @@ static ucc_status_t build_rank_map(ucc_tl_mlx5_alltoall_t *a2a,
return UCC_OK;
}

ucc_status_t ucc_tl_mlx5_team_alltoall_init_start(ucc_tl_mlx5_team_t *team)
ucc_status_t ucc_tl_mlx5_team_init_alltoall(ucc_tl_mlx5_team_t *team)
{
ucc_tl_mlx5_context_t *ctx = UCC_TL_MLX5_TEAM_CTX(team);
ucc_tl_mlx5_alltoall_t *a2a = NULL;
ucc_tl_mlx5_alltoall_t *a2a;
ucc_sbgp_t *node, *net;
size_t storage_size;
int i, j, node_size, ppn, team_size, nnodes;
ucc_topo_t *topo;
ucc_status_t status;

a2a = ucc_calloc(1, sizeof(*a2a), "mlx5_a2a");
if (!a2a) {
return UCC_ERR_NO_MEMORY;
}
team->a2a = NULL;
team->dm_ptr = NULL;
team->a2a_status.local = UCC_OK;

topo = team->topo;
node = ucc_topo_get_sbgp(topo, UCC_SBGP_NODE);
Expand All @@ -92,27 +89,22 @@ ucc_status_t ucc_tl_mlx5_team_alltoall_init_start(ucc_tl_mlx5_team_t *team)
"disabling mlx5 a2a for team with non-uniform ppn, "
"min_ppn %d, max_ppn %d",
ucc_topo_min_ppn(topo), ucc_topo_max_ppn(topo));
status = UCC_ERR_NOT_SUPPORTED;
goto err;
goto non_fatal_error;
}
ppn = ucc_topo_max_ppn(topo);

if (net->status == UCC_SBGP_NOT_EXISTS) {
tl_debug(ctx->super.super.lib,
"disabling mlx5 a2a for single node team");
status = UCC_ERR_NOT_SUPPORTED;
goto err;
goto non_fatal_error;
}

if (nnodes == team_size) {
tl_debug(ctx->super.super.lib,
"disabling mlx5 a2a for ppn=1 case, not supported so far");
status = UCC_ERR_NOT_SUPPORTED;
goto err;
goto non_fatal_error;
}

a2a->node_size = node_size;
ucc_assert(team_size == ppn * nnodes);

for (i = 0; i < nnodes; i++) {
for (j = 1; j < ppn; j++) {
Expand All @@ -121,38 +113,75 @@ ucc_status_t ucc_tl_mlx5_team_alltoall_init_start(ucc_tl_mlx5_team_t *team)
tl_debug(ctx->super.super.lib,
"disabling mlx5 a2a for team with non contiguous "
"ranks-per-node placement");
status = UCC_ERR_NOT_SUPPORTED;
goto err;
goto non_fatal_error;
}
}
}

a2a->pd = ctx->shared_pd;
a2a->ctx = ctx->shared_ctx;
a2a->ib_port = ctx->ib_port;
a2a->node.sbgp = node;
a2a->net.sbgp = net;
a2a->node.asr_rank = MLX5_ASR_RANK;
a2a->num_dci_qps = UCC_TL_MLX5_TEAM_LIB(team)->cfg.num_dci_qps;
a2a->sequence_number = 1;
a2a->net.ctrl_mr = NULL;
a2a->net.remote_ctrl = NULL;
a2a->net.rank_map = NULL;
a2a->max_msg_size = MAX_MSG_SIZE;
a2a->max_num_of_columns =
team->a2a = ucc_calloc(1, sizeof(*team->a2a), "mlx5_a2a");
if (!team->a2a) {
return UCC_ERR_NO_MEMORY;
}

a2a = team->a2a;
a2a->node_size = node_size;
a2a->pd = ctx->shared_pd;
a2a->ctx = ctx->shared_ctx;
a2a->ib_port = ctx->ib_port;
a2a->node.sbgp = node;
a2a->net.sbgp = net;
a2a->node.asr_rank = MLX5_ASR_RANK;
a2a->num_dci_qps = UCC_TL_MLX5_TEAM_LIB(team)->cfg.num_dci_qps;
a2a->sequence_number = 1;
a2a->net.atomic.counters = NULL;
a2a->net.ctrl_mr = NULL;
a2a->net.remote_ctrl = NULL;
a2a->net.rank_map = NULL;
a2a->max_msg_size = MAX_MSG_SIZE;
a2a->max_num_of_columns =
ucc_div_round_up(node->group_size, 2 /* todo: there can be an estimation of
minimal possible block size */);

ucc_assert(a2a->net.sbgp->status == UCC_SBGP_ENABLED ||
node->group_rank != 0);

if (a2a->node.asr_rank == node->group_rank) {
team->a2a_status.local = ucc_tl_mlx5_dm_init(team);
if (UCC_OK != team->a2a_status.local) {
tl_debug(UCC_TL_TEAM_LIB(team), "failed to init device memory");
}
}

return UCC_OK;

non_fatal_error:
team->a2a_status.local = UCC_ERR_NOT_SUPPORTED;
return UCC_OK;
}

ucc_status_t ucc_tl_mlx5_team_test_alltoall_start(ucc_tl_mlx5_team_t *team)
{
ucc_tl_mlx5_context_t *ctx = UCC_TL_MLX5_TEAM_CTX(team);
ucc_tl_mlx5_alltoall_t *a2a = team->a2a;
size_t storage_size;

if (team->a2a_status.global != UCC_OK) {
tl_debug(ctx->super.super.lib, "global status in error state: %s",
ucc_status_string(team->a2a_status.global));

ucc_tl_mlx5_dm_cleanup(team);
if (a2a) {
ucc_free(a2a);
team->a2a = NULL;
}
ucc_tl_mlx5_topo_cleanup(team);
return team->a2a_status.global;
}

if (a2a->node.asr_rank == a2a->node.sbgp->group_rank) {
a2a->net.net_size = a2a->net.sbgp->group_size;
storage_size = OP_SEGMENT_SIZE(a2a) * MAX_OUTSTANDING_OPS;
a2a->bcast_data.shmid =
shmget(IPC_PRIVATE, storage_size, IPC_CREAT | 0600);
if (a2a->bcast_data.shmid == -1) {
tl_error(ctx->super.super.lib,
tl_debug(ctx->super.super.lib,
"failed to allocate sysv shm segment for %zd bytes",
storage_size);
} else {
Expand All @@ -166,72 +195,10 @@ ucc_status_t ucc_tl_mlx5_team_alltoall_init_start(ucc_tl_mlx5_team_t *team)
a2a->state = TL_MLX5_ALLTOALL_STATE_SHMID;

team->a2a = a2a;
return ucc_service_bcast(UCC_TL_CORE_TEAM(team), &a2a->bcast_data,
sizeof(ucc_tl_mlx5_a2a_bcast_data_t),
a2a->node.asr_rank, ucc_sbgp_to_subset(node),
&team->scoll_req);
err:
if (a2a) {
ucc_free(a2a);
}
return status;
}

/* Releases the atomic-counters resources set up by
 * ucc_tl_mlx5_alltoall_atomic_alloc: the memory region is deregistered
 * first, then the backing storage for the counters is freed.
 * NOTE(review): assumes both mr and counters were successfully allocated
 * before this is called — confirm callers never invoke it after a
 * partial allocation failure. */
static void ucc_tl_mlx5_alltoall_atomic_free(ucc_tl_mlx5_alltoall_t *a2a)
{
ibv_dereg_mr(a2a->net.atomic.mr);
#if ATOMIC_IN_MEMIC
/* Counters were placed in device memory via ibv_alloc_dm. */
ibv_free_dm(a2a->net.atomic.counters);
#else
/* Counters were allocated on the host heap via ucc_malloc. */
ucc_free(a2a->net.atomic.counters);
#endif
}

static ucc_status_t ucc_tl_mlx5_alltoall_atomic_alloc(ucc_tl_mlx5_team_t *team)
{
ucc_tl_mlx5_context_t *ctx = UCC_TL_MLX5_TEAM_CTX(team);
ucc_tl_mlx5_alltoall_t *a2a = team->a2a;
size_t size;

size = sizeof(*a2a->net.atomic.counters) * MAX_OUTSTANDING_OPS;
#if ATOMIC_IN_MEMIC
struct ibv_alloc_dm_attr dm_attr;
memset(&dm_attr, 0, sizeof(dm_attr));
dm_attr.length = size;
a2a->net.atomic.counters = ibv_alloc_dm(ctx->shared_ctx, &dm_attr);
#else
a2a->net.atomic.counters = ucc_malloc(size, "atomic");
#endif

if (!a2a->net.atomic.counters) {
tl_debug(UCC_TL_TEAM_LIB(team),
"failed to allocate %zd bytes for atomic counters array",
size);
return UCC_ERR_NO_MEMORY;
}
#if ATOMIC_IN_MEMIC
a2a->net.atomic.mr =
ibv_reg_dm_mr(ctx->shared_pd, a2a->net.atomic.counters, 0, size,
IBV_ACCESS_REMOTE_ATOMIC | IBV_ACCESS_LOCAL_WRITE |
IBV_ACCESS_ZERO_BASED);

#else
a2a->net.atomic.mr =
ibv_reg_mr(ctx->shared_pd, a2a->net.atomic.counters, size,
IBV_ACCESS_REMOTE_ATOMIC | IBV_ACCESS_LOCAL_WRITE);
#endif

if (!a2a->net.atomic.mr) {
tl_error(UCC_TL_TEAM_LIB(team),
"failed to register atomic couters array");
#if ATOMIC_IN_MEMIC
ibv_free_dm(a2a->net.atomic.counters);
#else
ucc_free(a2a->net.atomic.counters);
#endif
return UCC_ERR_NO_MESSAGE;
}
return UCC_OK;
return ucc_service_bcast(
UCC_TL_CORE_TEAM(team), &a2a->bcast_data,
sizeof(ucc_tl_mlx5_a2a_bcast_data_t), a2a->node.asr_rank,
ucc_sbgp_to_subset(a2a->node.sbgp), &team->scoll_req);
}

static void ucc_tl_mlx5_alltoall_barrier_free(ucc_tl_mlx5_alltoall_t *a2a)
Expand Down Expand Up @@ -270,26 +237,29 @@ static ucc_status_t ucc_tl_mlx5_alltoall_barrier_alloc(ucc_tl_mlx5_team_t *team)
return UCC_OK;
}

ucc_status_t
ucc_tl_mlx5_team_alltoall_init_progress(ucc_tl_mlx5_team_t *tl_team)
ucc_status_t ucc_tl_mlx5_team_test_alltoall_progress(ucc_tl_mlx5_team_t *team)
{
ucc_tl_mlx5_team_t *team = ucc_derived_of(tl_team,
ucc_tl_mlx5_team_t);
ucc_tl_mlx5_context_t *ctx = UCC_TL_MLX5_TEAM_CTX(team);
ucc_tl_mlx5_alltoall_t *a2a = team->a2a;
ucc_rank_t node_size = a2a->node.sbgp->group_size;
ucc_rank_t node_rank = a2a->node.sbgp->group_rank;
ucc_base_lib_t *lib = UCC_TL_TEAM_LIB(team);
size_t op_seg_size = OP_SEGMENT_SIZE(a2a);
int i = 0;
net_exchange_t *local_data = NULL;
ucc_rank_t node_size, node_rank;
ucc_status_t status;
ucc_tl_mlx5_alltoall_op_t *op;
int j, asr_cq_size, net_size, ret;
struct ibv_port_attr port_attr;
size_t local_data_size, umr_buf_size;
size_t op_seg_size, local_data_size, umr_buf_size;
net_exchange_t *global_data, *remote_data;

if (team->a2a_status.local < 0) {
return team->a2a_status.local;
}

node_size = a2a->node.sbgp->group_size;
node_rank = a2a->node.sbgp->group_rank;
op_seg_size = OP_SEGMENT_SIZE(a2a);

switch (a2a->state) {
case TL_MLX5_ALLTOALL_STATE_SHMID:
status = ucc_service_coll_test(team->scoll_req);
Expand Down Expand Up @@ -335,11 +305,6 @@ ucc_tl_mlx5_team_alltoall_init_progress(ucc_tl_mlx5_team_t *tl_team)
return UCC_OK;
}

status = ucc_tl_mlx5_alltoall_atomic_alloc(team);
if (UCC_OK != status) {
goto err_atomic;
}

status = ucc_tl_mlx5_alltoall_barrier_alloc(team);
if (UCC_OK != status) {
goto err_barrier;
Expand Down Expand Up @@ -518,12 +483,12 @@ ucc_tl_mlx5_team_alltoall_init_progress(ucc_tl_mlx5_team_t *tl_team)
a2a->state = TL_MLX5_ALLTOALL_STATE_EXCHANGE_PROGRESS;

case TL_MLX5_ALLTOALL_STATE_EXCHANGE_PROGRESS:
status = ucc_service_coll_test(tl_team->scoll_req);
status = ucc_service_coll_test(team->scoll_req);
if (status < 0) {
tl_error(UCC_TL_TEAM_LIB(tl_team),
tl_error(UCC_TL_TEAM_LIB(team),
"failure during service coll exchange: %s",
ucc_status_string(status));
ucc_service_coll_finalize(tl_team->scoll_req);
ucc_service_coll_finalize(team->scoll_req);
goto err_service_allgather_progress;
}
if (UCC_INPROGRESS == status) {
Expand All @@ -534,7 +499,7 @@ ucc_tl_mlx5_team_alltoall_init_progress(ucc_tl_mlx5_team_t *tl_team)

case TL_MLX5_ALLTOALL_STATE_EXCHANGE_DONE:
local_data = team->scoll_req->data;
ucc_service_coll_finalize(tl_team->scoll_req);
ucc_service_coll_finalize(team->scoll_req);

net_size = a2a->net.net_size;
local_data_size = sizeof(net_exchange_t);
Expand Down Expand Up @@ -691,8 +656,6 @@ ucc_tl_mlx5_team_alltoall_init_progress(ucc_tl_mlx5_team_t *tl_team)
err_blocks_sent:
ucc_tl_mlx5_alltoall_barrier_free(a2a);
err_barrier:
ucc_tl_mlx5_alltoall_atomic_free(a2a);
err_atomic:
return status;
}

Expand Down Expand Up @@ -721,6 +684,7 @@ void ucc_tl_mlx5_alltoall_cleanup(ucc_tl_mlx5_team_t *team)
for (i = 0; i < a2a->num_dci_qps; i++) {
ibv_destroy_qp(a2a->net.dcis[i].dci_qp);
}
ucc_free(a2a->net.dcis);
ibv_destroy_qp(a2a->net.dct_qp);
ibv_destroy_srq(a2a->net.srq);
for (i = 0; i < a2a->net.net_size; i++) {
Expand Down Expand Up @@ -753,7 +717,6 @@ void ucc_tl_mlx5_alltoall_cleanup(ucc_tl_mlx5_team_t *team)

ucc_free(a2a->net.blocks_sent);
ucc_tl_mlx5_alltoall_barrier_free(a2a);
ucc_tl_mlx5_alltoall_atomic_free(a2a);
}
ucc_free(a2a->net.dcis);
ucc_free(a2a);
}
7 changes: 5 additions & 2 deletions src/components/tl/mlx5/alltoall/alltoall.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "tl_mlx5.h"
#include "tl_mlx5_ib.h"
#include "tl_mlx5_dm.h"

#define SEQ_INDEX(_seq_num) ((_seq_num) % MAX_OUTSTANDING_OPS)

Expand Down Expand Up @@ -136,8 +137,10 @@ typedef struct ucc_tl_mlx5_alltoall {
ucc_tl_mlx5_a2a_bcast_data_t bcast_data;
} ucc_tl_mlx5_alltoall_t;

ucc_status_t ucc_tl_mlx5_team_alltoall_init_start(ucc_tl_mlx5_team_t *team);
ucc_status_t ucc_tl_mlx5_team_alltoall_init_progress(ucc_tl_mlx5_team_t *team);
void ucc_tl_mlx5_topo_cleanup(ucc_tl_mlx5_team_t *team);
ucc_status_t ucc_tl_mlx5_team_init_alltoall(ucc_tl_mlx5_team_t *team);
ucc_status_t ucc_tl_mlx5_team_test_alltoall_start(ucc_tl_mlx5_team_t *team);
ucc_status_t ucc_tl_mlx5_team_test_alltoall_progress(ucc_tl_mlx5_team_t *team);
ucc_status_t ucc_tl_mlx5_alltoall_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t * team,
ucc_coll_task_t ** task_h);
Expand Down
8 changes: 7 additions & 1 deletion src/components/tl/mlx5/tl_mlx5.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ typedef struct ucc_tl_mlx5_context {
ucc_rcache_t *rcache;
int is_imported;
int ib_port;
int sock;
ucc_mpool_t req_mp;
ucc_tl_mlx5_mcast_context_t mcast;
} ucc_tl_mlx5_context_t;
Expand All @@ -108,15 +109,20 @@ typedef enum
TL_MLX5_TEAM_STATE_ALLTOALL_POSTED
} ucc_tl_mlx5_team_state_t;

/* Per-team a2a initialization status used by the split init/test flow:
 * `local` holds this rank's own outcome (set in
 * ucc_tl_mlx5_team_init_alltoall), while `global` holds the team-wide
 * status consulted in ucc_tl_mlx5_team_test_alltoall_start.
 * NOTE(review): the reduction of local -> global is not visible in this
 * chunk — presumably done via an allreduce/bcast during team create;
 * confirm against the team init code. */
typedef struct ucc_tl_mlx5_team_status {
ucc_status_t local;  /* this rank's init result (may be UCC_ERR_NOT_SUPPORTED) */
ucc_status_t global; /* agreed team-wide init result */
} ucc_tl_mlx5_team_status_t;

typedef struct ucc_tl_mlx5_team {
ucc_tl_team_t super;
ucc_status_t status[2];
ucc_service_coll_req_t *scoll_req;
ucc_tl_mlx5_team_state_t state;
void *dm_offset;
ucc_mpool_t dm_pool;
struct ibv_dm *dm_ptr;
struct ibv_mr *dm_mr;
ucc_tl_mlx5_team_status_t a2a_status;
ucc_tl_mlx5_alltoall_t *a2a;
ucc_topo_t *topo;
ucc_ep_map_t ctx_map;
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/tl_mlx5_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ ucc_status_t ucc_tl_mlx5_bcast_mcast_init(ucc_base_coll_args_t *coll_args,
if (ucc_unlikely(UCC_OK != status)) {
goto free_task;
}

*task_h = &(task->super);

tl_debug(UCC_TASK_LIB(task), "init coll task %p", task);
Expand Down
Loading