TL/MLX5: revise team and ctx init
samnordmann committed Jul 31, 2023
1 parent 6bdb758 commit 1575241
Showing 5 changed files with 80 additions and 57 deletions.
21 changes: 17 additions & 4 deletions src/components/tl/mlx5/alltoall/alltoall.c
@@ -75,6 +75,13 @@ ucc_status_t ucc_tl_mlx5_team_alltoall_init_start(ucc_tl_mlx5_team_t *team)
     ucc_topo_t   *topo;
     ucc_status_t  status;

+    if (team->dm_status[1] != UCC_OK) {
+        tl_debug(ctx->super.super.lib,
+                 "node leader failed during device memory init: %s",
+                 ucc_status_string(team->dm_status[1]));
+        return team->dm_status[1];
+    }
+
     a2a = ucc_calloc(1, sizeof(*a2a), "mlx5_a2a");
     if (!a2a) {
         return UCC_ERR_NO_MEMORY;
@@ -277,19 +284,25 @@ ucc_tl_mlx5_team_alltoall_init_progress(ucc_tl_mlx5_team_t *tl_team)
                                                  ucc_tl_mlx5_team_t);
     ucc_tl_mlx5_context_t     *ctx = UCC_TL_MLX5_TEAM_CTX(team);
     ucc_tl_mlx5_alltoall_t    *a2a = team->a2a;
-    ucc_rank_t                 node_size = a2a->node.sbgp->group_size;
-    ucc_rank_t                 node_rank = a2a->node.sbgp->group_rank;
     ucc_base_lib_t            *lib = UCC_TL_TEAM_LIB(team);
-    size_t                     op_seg_size = OP_SEGMENT_SIZE(a2a);
     int                        i = 0;
     net_exchange_t            *local_data = NULL;
+    ucc_rank_t                 node_size, node_rank;
     ucc_status_t               status;
     ucc_tl_mlx5_alltoall_op_t *op;
     int                        j, asr_cq_size, net_size, ret;
     struct ibv_port_attr       port_attr;
-    size_t                     local_data_size, umr_buf_size;
+    size_t                     op_seg_size, local_data_size, umr_buf_size;
     net_exchange_t            *global_data, *remote_data;

+    if (tl_team->a2a_status < 0) {
+        return tl_team->a2a_status;
+    }
+
+    node_size   = a2a->node.sbgp->group_size;
+    node_rank   = a2a->node.sbgp->group_rank;
+    op_seg_size = OP_SEGMENT_SIZE(a2a);
+
     switch (a2a->state) {
     case TL_MLX5_ALLTOALL_STATE_SHMID:
         status = ucc_service_coll_test(team->scoll_req);
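Both hunks serve one pattern: alltoall setup may now fail without aborting team creation (its result is recorded in the team's a2a_status, see tl_mlx5_team.c below), so init_progress must re-check that record before dereferencing state a failed init_start never built, which is why the sbgp-derived locals move out of the declarations and behind the guard. The first hunk likewise makes init_start refuse to run when the reduced device-memory status in dm_status[1] reports a node-leader failure. A minimal, self-contained sketch of that start/progress split; all names here are hypothetical stand-ins, not the UCC API:

#include <stdio.h>

typedef enum { OK = 0, INPROGRESS = 1, ERR = -1 } status_t;

typedef struct {
    status_t a2a_status; /* sticky result of the asynchronous init */
    int      step;       /* state-machine position */
} team_t;

static status_t init_start(team_t *t)
{
    t->step = 0;
    return OK; /* or an error, which the caller records in a2a_status */
}

static status_t init_progress(team_t *t)
{
    if (t->a2a_status < 0) {
        return t->a2a_status; /* bail out before touching half-built state */
    }
    return (++t->step < 3) ? INPROGRESS : OK;
}

int main(void)
{
    team_t t = {0};

    t.a2a_status = init_start(&t);
    while ((t.a2a_status = init_progress(&t)) == INPROGRESS) {
        ; /* caller polls until the init completes or fails */
    }
    printf("a2a init finished with status %d\n", t.a2a_status);
    return 0;
}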
3 changes: 2 additions & 1 deletion src/components/tl/mlx5/tl_mlx5.h
@@ -110,13 +110,14 @@ typedef enum

 typedef struct ucc_tl_mlx5_team {
     ucc_tl_team_t            super;
-    ucc_status_t             status[2];
     ucc_service_coll_req_t  *scoll_req;
     ucc_tl_mlx5_team_state_t state;
+    ucc_status_t             dm_status[2];
     void                    *dm_offset;
     ucc_mpool_t              dm_pool;
     struct ibv_dm           *dm_ptr;
     struct ibv_mr           *dm_mr;
+    ucc_status_t             a2a_status;
     ucc_tl_mlx5_alltoall_t  *a2a;
     ucc_topo_t              *topo;
     ucc_ep_map_t             ctx_map;
6 changes: 5 additions & 1 deletion src/components/tl/mlx5/tl_mlx5_coll.c
@@ -31,7 +31,7 @@ ucc_status_t ucc_tl_mlx5_bcast_mcast_init(ucc_base_coll_args_t *coll_args,
     if (ucc_unlikely(UCC_OK != status)) {
         goto free_task;
     }
-
+
     *task_h = &(task->super);

     tl_debug(UCC_TASK_LIB(task), "init coll task %p", task);
@@ -71,6 +71,10 @@ ucc_status_t ucc_tl_mlx5_coll_init(ucc_base_coll_args_t *coll_args,

     switch (coll_args->args.coll_type) {
     case UCC_COLL_TYPE_ALLTOALL:
+        status = ucc_derived_of(team, ucc_tl_mlx5_team_t)->a2a_status;
+        if (status != UCC_OK) {
+            return status;
+        }
         status = ucc_tl_mlx5_alltoall_init(coll_args, team, task_h);
         break;
     case UCC_COLL_TYPE_BCAST:
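The gate added here makes a team whose alltoall setup failed reject new collectives in O(1), before any task is allocated, instead of letting ucc_tl_mlx5_alltoall_init build a task only to tear it down again. A minimal sketch of the same fail-fast shape, with hypothetical names rather than the UCC types:

#include <stdio.h>
#include <stdlib.h>

typedef int status_t;
enum { OK = 0, ERR_NO_MEMORY = -2 };

typedef struct { status_t a2a_status; } team_t;
typedef struct { team_t *team; } task_t;

static status_t coll_init(team_t *team, task_t **task_h)
{
    if (team->a2a_status != OK) {
        return team->a2a_status; /* fail fast: nothing allocated yet */
    }
    *task_h = malloc(sizeof(**task_h));
    if (!*task_h) {
        return ERR_NO_MEMORY;
    }
    (*task_h)->team = team;
    return OK;
}

int main(void)
{
    team_t  broken = { .a2a_status = -1 };
    task_t *task   = NULL;

    printf("init on broken team: %d\n", coll_init(&broken, &task));
    return 0;
}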
64 changes: 39 additions & 25 deletions src/components/tl/mlx5/tl_mlx5_context.c
@@ -41,15 +41,26 @@ UCC_CLASS_INIT_FUNC(ucc_tl_mlx5_context_t,
         return status;
     }

+    status = tl_mlx5_rcache_create(self);
+    if (UCC_OK != status) {
+        tl_error(self->super.super.lib, "failed to create rcache");
+        goto err_rcache;
+    }
+
     status = ucc_tl_mlx5_mcast_context_init(&(self->mcast), &(self->cfg.mcast_ctx_conf));
     if (UCC_OK != status) {
         tl_error(self->super.super.lib,
                  "failed to initialize mcast context");
-        return status;
+        goto err_mcast_context;
     }

     tl_debug(self->super.super.lib, "initialized tl context: %p", self);
     return UCC_OK;
+
+err_mcast_context:
+    ucc_rcache_destroy(self->rcache);
+err_rcache:
+    ucc_mpool_cleanup(&self->req_mp, 1);
+    return status;
 }

 UCC_CLASS_CLEANUP_FUNC(ucc_tl_mlx5_context_t)
@@ -60,7 +71,7 @@ UCC_CLASS_CLEANUP_FUNC(ucc_tl_mlx5_context_t)
     }

     if (ucc_tl_mlx5_remove_shared_ctx_pd(self) != UCC_OK) {
-        tl_error(self->super.super.lib, "failed to free ib ctx and pd");
+        tl_debug(self->super.super.lib, "failed to free ib ctx and pd");
     };

     ucc_mpool_cleanup(&self->req_mp, 1);
@@ -142,7 +153,7 @@ typedef struct ucc_tl_mlx5_context_create_sbcast_data {
     char           sock_path[];
 } ucc_tl_mlx5_context_create_sbcast_data_t;

-ucc_status_t ucc_tl_mlx5_context_create_epilog(ucc_base_context_t *context)
+ucc_status_t ucc_tl_mlx5_context_ib_ctx_pd_setup(ucc_base_context_t *context)
 {
     ucc_tl_mlx5_context_t *ctx      = ucc_derived_of(context, ucc_tl_mlx5_context_t);
     ucc_context_t         *core_ctx = context->ucc_context;
@@ -182,37 +193,36 @@ ucc_status_t ucc_tl_mlx5_context_create_epilog(ucc_base_context_t *context)

     status = ucc_topo_init(s, core_ctx->topo, &topo);
     if (UCC_OK != status) {
-        tl_error(context->lib, "failed to init mlx5 ctx topo");
+        tl_debug(context->lib, "failed to init mlx5 ctx topo");
         goto err_topo;
     }

     sbgp = ucc_topo_get_sbgp(topo, UCC_SBGP_NODE);
-    if (sbgp->status != UCC_SBGP_ENABLED) {
-        status = UCC_OK;
-        goto err;
-    }

     ctx->shared_ctx  = NULL;
     ctx->shared_pd   = NULL;
     ctx->is_imported = sbgp->group_rank != PD_OWNER_RANK;

     if (!ctx->is_imported) {
+        status = ucc_tl_mlx5_ib_ctx_pd_init(ctx);
+        if (status != UCC_OK) {
+            tl_debug(context->lib, "failed to init ib_ctx and pd");
+            goto err_ib_ctx_pd_init;
+        }
+        if (sbgp->status == UCC_SBGP_NOT_EXISTS) {
+            goto topo_ppn_1;
+        }
         ucc_strncpy_safe(sock_path, template, sock_dir_len);
         if (mkdtemp(sock_path) != NULL) {
-            status = ucc_tl_mlx5_ib_ctx_pd_init(ctx);
-            if (status != UCC_OK) {
-                goto err;
-            }
-
             strncat(sock_path, sockname, sizeof(sock_path) - strlen(sock_path) - 1);
             status = ucc_tl_mlx5_socket_init(ctx, sbgp->group_size, &sock,
                                              sock_path);
             if (UCC_OK != status) {
                 sock_path[0] = '\0';
-                tl_error(context->lib, "failed to init socket to share ib_ctx");
+                tl_debug(context->lib, "failed to init socket to share ib_ctx");
             }
         } else {
-            tl_error(context->lib, "failed to create tmp file for socket path");
+            tl_debug(context->lib, "failed to create tmp file for socket path");
             sock_path[0] = '\0';
         }
         sbcast_data->ib_port = ctx->ib_port;
@@ -244,6 +254,7 @@ ucc_status_t ucc_tl_mlx5_context_create_epilog(ucc_base_context_t *context)
     memcpy(sock_path, sbcast_data->sock_path, sizeof(sock_path));

     if (strlen(sock_path) == 0) {
+        tl_debug(context->lib, "failed to share ctx and pd");
         status = UCC_ERR_NO_MESSAGE;
         goto err;
     }
@@ -255,26 +266,29 @@ ucc_status_t ucc_tl_mlx5_context_create_epilog(ucc_base_context_t *context)
         rmdir(sock_path);
     }
     if (status != UCC_OK) {
-        tl_error(context->lib, "failed to share ctx and pd");
-        goto err;
-    }
-
-    status = tl_mlx5_rcache_create(ctx);
-    if (UCC_OK != status) {
-        tl_error(context->lib, "failed to create rcache");
+        tl_debug(context->lib, "failed to share ctx and pd");
         goto err;
     }

     ucc_free(sbcast_data);
-    ucc_topo_cleanup(topo);
     close(sock);
+topo_ppn_1:
+    ucc_topo_cleanup(topo);
+    tl_debug(ctx->super.super.lib, "initialized tl context: %p", ctx);
     return UCC_OK;

 err:
     ucc_tl_mlx5_remove_shared_ctx_pd(ctx);
-    ucc_topo_cleanup(topo);
     close(sock);
+err_ib_ctx_pd_init:
+    ucc_topo_cleanup(topo);
 err_topo:
     ucc_free(sbcast_data);
+    tl_debug(ctx->super.super.lib, "failed initialize tl context: %p", ctx);
     return status;
 }
+
+ucc_status_t ucc_tl_mlx5_context_create_epilog(ucc_base_context_t *context)
+{
+    return ucc_tl_mlx5_context_ib_ctx_pd_setup(context);
+}
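Most of this file is the classic goto-unwind idiom the init paths now use: each acquired resource gets a label, a failure jumps to the label for the step that failed, and the labels fall through so teardown runs in reverse order of acquisition. The new topo_ppn_1 label also gives the one-process-per-node case (UCC_SBGP_NOT_EXISTS) a direct success exit that still cleans up the topology. A compressed, self-contained sketch of the unwind shape, with hypothetical resources standing in for the mpool, rcache, and mcast context:

#include <stdlib.h>

typedef int status_t;
enum { OK = 0, ERR_NO_RESOURCE = -2 };

static status_t acquire(void **p)
{
    *p = malloc(16);
    return *p ? OK : ERR_NO_RESOURCE;
}

static status_t ctx_init(void **pool, void **rcache, void **mcast)
{
    status_t st;

    if ((st = acquire(pool)) != OK) {
        return st;       /* nothing to unwind yet */
    }
    if ((st = acquire(rcache)) != OK) {
        goto err_rcache; /* unwind the pool */
    }
    if ((st = acquire(mcast)) != OK) {
        goto err_mcast;  /* unwind rcache, then fall through to the pool */
    }
    return OK;

err_mcast:
    free(*rcache);
err_rcache:
    free(*pool);
    return st;
}

int main(void)
{
    void *pool, *rcache, *mcast;

    if (ctx_init(&pool, &rcache, &mcast) != OK) {
        return 1;
    }
    free(mcast); free(rcache); free(pool); /* normal teardown */
    return 0;
}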
43 changes: 17 additions & 26 deletions src/components/tl/mlx5/tl_mlx5_team.c
@@ -73,8 +73,8 @@ UCC_CLASS_INIT_FUNC(ucc_tl_mlx5_team_t, ucc_base_context_t *tl_context,
         }
     }

-    self->status[0] = status;
-    self->state     = TL_MLX5_TEAM_STATE_INIT;
+    self->dm_status[0] = status;
+    self->state        = TL_MLX5_TEAM_STATE_INIT;

     self->mcast = NULL;
     status = ucc_tl_mlx5_mcast_team_init(tl_context, &(self->mcast), &(ctx->mcast), params,
@@ -116,9 +116,9 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team)

     switch (tl_team->state) {
     case TL_MLX5_TEAM_STATE_INIT:
-        status = ucc_service_allreduce(
-            core_team, &tl_team->status[0], &tl_team->status[1],
-            UCC_DT_INT32, 1, UCC_OP_MIN, subset, &tl_team->scoll_req);
+        status = ucc_service_allreduce(core_team, &tl_team->dm_status[0],
+                                       &tl_team->dm_status[1], UCC_DT_INT32, 1,
+                                       UCC_OP_MIN, subset, &tl_team->scoll_req);
         if (status < 0) {
             tl_error(UCC_TL_TEAM_LIB(tl_team),
                      "failed to collect global status");
@@ -136,33 +136,24 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team)
         if (UCC_INPROGRESS == status) {
             return status;
         }
         ucc_assert(status == UCC_OK);
         ucc_service_coll_finalize(tl_team->scoll_req);
-        if (tl_team->status[1] != UCC_OK) {
-            tl_debug(UCC_TL_TEAM_LIB(tl_team),
-                     "node leader failed during device memory init: %s",
-                     ucc_status_string(tl_team->status[1]));
-            ucc_tl_mlx5_team_destroy(team);
-            return tl_team->status[1];
-        }
         tl_team->state = TL_MLX5_TEAM_STATE_ALLTOALL_INIT;
     case TL_MLX5_TEAM_STATE_ALLTOALL_INIT:
-        status = ucc_tl_mlx5_team_alltoall_init_start(tl_team);
-        if (status != UCC_OK) {
-            tl_debug(UCC_TL_TEAM_LIB(tl_team), "failed to init a2a: %s",
-                     ucc_status_string(status));
-            return status;
-        }
+        tl_team->a2a_status = ucc_tl_mlx5_team_alltoall_init_start(tl_team);
         tl_team->state = TL_MLX5_TEAM_STATE_ALLTOALL_POSTED;
     case TL_MLX5_TEAM_STATE_ALLTOALL_POSTED:
-        status = ucc_tl_mlx5_team_alltoall_init_progress(tl_team);
-    }
-    if (status < 0) {
-        tl_debug(team->context->lib, "failed creating tl team: %p", tl_team);
-    } else if (status == UCC_OK) {
-        tl_debug(team->context->lib, "initialized tl team: %p", tl_team);
+        tl_team->a2a_status = ucc_tl_mlx5_team_alltoall_init_progress(tl_team);
+        if (tl_team->a2a_status == UCC_INPROGRESS) {
+            return UCC_INPROGRESS;
+        }
+        if (tl_team->a2a_status != UCC_OK) {
+            tl_debug(UCC_TL_TEAM_LIB(tl_team), "failed to init a2a: %s",
+                     ucc_status_string(tl_team->a2a_status));
+        }
     }
-    return status;
+
+    tl_debug(team->context->lib, "initialized tl team: %p", tl_team);
+    return UCC_OK;
 }

 ucc_status_t ucc_tl_mlx5_team_get_scores(ucc_base_team_t * tl_team,
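The service allreduce above is how every rank learns whether any node leader failed device-memory init: UCC_OK is 0 and UCC error codes are negative, so reducing each rank's local dm_status[0] with UCC_OP_MIN over int32 leaves the worst status in the job in dm_status[1] on every rank, and the alltoall init_start hunk at the top of this page rejects that reduced value. A minimal sketch of the invariant, with a plain loop standing in for the service allreduce:

#include <stdio.h>

#define N_RANKS 4

int main(void)
{
    /* each rank's local device-memory init status: UCC_OK == 0, errors < 0 */
    int dm_status_local[N_RANKS] = { 0, 0, -2 /* one rank failed */, 0 };
    int reduced = dm_status_local[0];

    for (int i = 1; i < N_RANKS; i++) { /* emulate allreduce(MIN) */
        if (dm_status_local[i] < reduced) {
            reduced = dm_status_local[i];
        }
    }

    /* every rank agrees on the result: 0 only if nobody failed */
    printf("global dm_status: %d\n", reduced); /* prints -2 */
    return 0;
}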
