Skip to content

Commit

Permalink
hints infrastructure to negotiate bits for tag-based parallelism
Browse files Browse the repository at this point in the history
  • Loading branch information
rzambre committed Dec 23, 2020
1 parent 71bacd8 commit a4b9455
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 24 deletions.
3 changes: 3 additions & 0 deletions src/include/mpir_comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ enum MPIR_COMM_HINT_PREDEFINED_t {
MPIR_COMM_HINT_ALLOW_OVERTAKING,
MPIR_COMM_HINT_NEW_VCI,
MPIR_COMM_HINT_NUM_VCIS,
MPIR_COMM_HINT_TAG_PAR,
MPIR_COMM_HINT_NUM_TAG_BITS_VCI,
MPIR_COMM_HINT_NUM_TAG_BITS_APP,
/* device specific hints.
* Potentially, we can use macros and configure to hide them */
MPIR_COMM_HINT_EAGER_THRESH, /* ch3 */
Expand Down
9 changes: 9 additions & 0 deletions src/mpi/comm/commutil.c
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,12 @@ void MPIR_Comm_hint_init(void)
NULL, MPIR_COMM_HINT_TYPE_BOOL, 0);
MPIR_Comm_register_hint(MPIR_COMM_HINT_NUM_VCIS, "mpi_num_vcis",
NULL, MPIR_COMM_HINT_TYPE_INT, 0);
MPIR_Comm_register_hint(MPIR_COMM_HINT_TAG_PAR, "mpi_assert_tag_based_parallelism",
NULL, MPIR_COMM_HINT_TYPE_BOOL, 0);
MPIR_Comm_register_hint(MPIR_COMM_HINT_NUM_TAG_BITS_VCI, "mpi_num_tag_bits_for_vci",
NULL, MPIR_COMM_HINT_TYPE_INT, 0);
MPIR_Comm_register_hint(MPIR_COMM_HINT_NUM_TAG_BITS_APP, "mpi_num_tag_bits_for_app",
NULL, MPIR_COMM_HINT_TYPE_INT, 0);
}

/* FIXME :
Expand Down Expand Up @@ -764,6 +770,9 @@ int MPII_Comm_copy(MPIR_Comm * comm_ptr, int size, MPIR_Info * info, MPIR_Comm *
/* Since hints are never propogated starting MPI-3.2, set defaults for hints */
newcomm_ptr->hints[MPIR_COMM_HINT_NEW_VCI] = FALSE;
newcomm_ptr->hints[MPIR_COMM_HINT_NUM_VCIS] = 1;
newcomm_ptr->hints[MPIR_COMM_HINT_TAG_PAR] = FALSE;
newcomm_ptr->hints[MPIR_COMM_HINT_NUM_TAG_BITS_VCI] = 0;
newcomm_ptr->hints[MPIR_COMM_HINT_NUM_TAG_BITS_APP] = 0;
if (info) {
MPII_Comm_set_hints(newcomm_ptr, info);
}
Expand Down
4 changes: 4 additions & 0 deletions src/mpid/ch4/include/mpidpre.h
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,9 @@ typedef struct MPIDI_vci_hash {
} multi;
} u;
int count;
int tag_par;
int num_tag_bits_for_vci;
int num_tag_bits_for_app;
} MPIDI_vci_hash_t;

typedef struct MPIDI_Devcomm_t {
Expand Down Expand Up @@ -532,6 +535,7 @@ typedef struct MPIDI_Devcomm_t {
#define MPIDI_COMM_VCI(comm) ((comm)->dev.ch4).vci_hash.u.single.vci
#define MPIDI_COMM_MULTI_VCI(comm) ((comm)->dev.ch4).vci_hash.u.multi.vci
#define MPIDI_COMM_VCI_COUNT(comm) ((comm)->dev.ch4).vci_hash.count
#define MPIDI_COMM_VCI_HASH(comm) ((comm)->dev.ch4).vci_hash

typedef struct {
union {
Expand Down
4 changes: 3 additions & 1 deletion src/mpid/ch4/netmod/ucx/ucx_recv.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ MPL_STATIC_INLINE_PREFIX void MPIDI_UCX_recv_cmpl_cb(void *request, ucs_status_t
if (MPIDI_COMM_VCI_COUNT(comm_ptr) == 1) {
vci = MPIDI_COMM_VCI(comm_ptr);
} else {
vci = (MPIDI_UCX_get_tag(info->sender_tag) == MPI_ANY_TAG) ? 0 : ((MPIDI_UCX_get_tag(info->sender_tag) >> 5) & 0x1f);
int tag = MPIDI_UCX_get_tag(info->sender_tag);
vci = (tag == MPI_ANY_TAG) ? 0 : ((tag >> MPIDI_COMM_VCI_HASH(comm_ptr).num_tag_bits_for_app)
& ((1 << MPIDI_COMM_VCI_HASH(comm_ptr).num_tag_bits_for_vci) - 1));
}

/* Does not require safety since we are at ucp_tag_recv_nb */
Expand Down
29 changes: 28 additions & 1 deletion src/mpid/ch4/src/ch4_comm.c
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,13 @@ int MPID_Comm_create_hook(MPIR_Comm * comm)
MPIDIU_avt_add_ref(MPIDI_COMM(comm, local_map).avtid);
}

/* Initialize some fields */
MPIDI_COMM_VCI_HASH(comm).tag_par = 0;
MPIDI_COMM_VCI_HASH(comm).num_tag_bits_for_vci = 0;
MPIDI_COMM_VCI_HASH(comm).num_tag_bits_for_app = 0;

if (MPIR_CONTEXT_READ_FIELD(SUBCOMM, comm->context_id)) {
/* If this is a subcommunicator, then use the VCI(s) of the parent*/
/* If this is a subcommunicator, then use the VCI policies of the parent*/
int parent_comm_num_vcis;
MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_POOL(lock));

Expand All @@ -205,6 +210,11 @@ int MPID_Comm_create_hook(MPIR_Comm * comm)
MPIDI_COMM_MULTI_VCI(comm)[vci_i] = MPIDI_COMM_MULTI_VCI(comm->parent_comm)[vci_i];
MPIDI_VCI(MPIDI_COMM_MULTI_VCI(comm->parent_comm)[vci_i]).ref_count++;
}

MPIDI_COMM_VCI_HASH(comm).tag_par = MPIDI_COMM_VCI_HASH(comm->parent_comm).tag_par;
MPIDI_COMM_VCI_HASH(comm).num_tag_bits_for_vci = MPIDI_COMM_VCI_HASH(comm->parent_comm).num_tag_bits_for_vci;
MPIDI_COMM_VCI_HASH(comm).num_tag_bits_for_app = MPIDI_COMM_VCI_HASH(comm->parent_comm).num_tag_bits_for_app;

/* FIXME: this is a hack */
MPIDI_COMM_VCI(comm) = MPIDI_VCI_ROOT;
}
Expand Down Expand Up @@ -267,6 +277,18 @@ int MPID_Comm_create_hook(MPIR_Comm * comm)
}
MPIDI_COMM_MULTI_VCI(comm)[vci_i] = vci;
}

if (comm->hints[MPIR_COMM_HINT_TAG_PAR]) {
printf("Tag parallelism hinted for this comm\n");
MPIDI_COMM_VCI_HASH(comm).tag_par = 1;

MPIDI_COMM_VCI_HASH(comm).num_tag_bits_for_vci = comm->hints[MPIR_COMM_HINT_NUM_TAG_BITS_VCI];
MPIDI_COMM_VCI_HASH(comm).num_tag_bits_for_app = comm->hints[MPIR_COMM_HINT_NUM_TAG_BITS_APP];

MPIR_Assert(((1 << (2*MPIDI_COMM_VCI_HASH(comm).num_tag_bits_for_vci + MPIDI_COMM_VCI_HASH(comm).num_tag_bits_for_app)) - 1)
<= MPIR_Process.attrs.tag_ub);
}

/* FIXME: this is a hack */
MPIDI_COMM_VCI(comm) = MPIDI_VCI_ROOT;
}
Expand All @@ -289,6 +311,11 @@ int MPID_Comm_create_hook(MPIR_Comm * comm)
MPIDI_COMM_MULTI_VCI(comm)[vci_i] = MPIDI_COMM_MULTI_VCI(comm->orig_comm)[vci_i];
MPIDI_VCI(MPIDI_COMM_MULTI_VCI(comm->orig_comm)[vci_i]).ref_count++;
}

MPIDI_COMM_VCI_HASH(comm).tag_par = MPIDI_COMM_VCI_HASH(comm->orig_comm).tag_par;
MPIDI_COMM_VCI_HASH(comm).num_tag_bits_for_vci = MPIDI_COMM_VCI_HASH(comm->orig_comm).num_tag_bits_for_vci;
MPIDI_COMM_VCI_HASH(comm).num_tag_bits_for_app = MPIDI_COMM_VCI_HASH(comm->orig_comm).num_tag_bits_for_app;

/* FIXME: this is a hack */
MPIDI_COMM_VCI(comm) = MPIDI_VCI_ROOT;
}
Expand Down
8 changes: 2 additions & 6 deletions src/mpid/ch4/src/ch4_recv.h
Original file line number Diff line number Diff line change
Expand Up @@ -359,15 +359,13 @@ MPL_STATIC_INLINE_PREFIX int MPID_Recv(void *buf,
MPIR_Request ** request)
{
int mpi_errno = MPI_SUCCESS;
int num_vcis;
int hst_vci, rmt_vci;
MPIDI_av_entry_t *av = NULL;
MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPID_RECV);
MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPID_RECV);

/*TODO: remove this branch to use function pointers */
num_vcis = MPIDI_COMM_VCI_COUNT(comm);
if (num_vcis == 1) {
if (!MPIDI_COMM_VCI_HASH(comm).tag_par) {
hst_vci = MPIDI_vci_get(comm, rank, tag);
rmt_vci = hst_vci;
} else {
Expand Down Expand Up @@ -539,15 +537,13 @@ MPL_STATIC_INLINE_PREFIX int MPID_Irecv(void *buf,
MPIR_Request ** request)
{
int mpi_errno = MPI_SUCCESS;
int num_vcis;
int hst_vci, rmt_vci;
MPIDI_av_entry_t *av = NULL;
MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPID_IRECV);
MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPID_IRECV);

/*TODO: remove this branch to use function pointers */
num_vcis = MPIDI_COMM_VCI_COUNT(comm);
if (num_vcis == 1) {
if (!MPIDI_COMM_VCI_HASH(comm).tag_par) {
hst_vci = MPIDI_vci_get(comm, rank, tag);
rmt_vci = hst_vci;
} else {
Expand Down
16 changes: 4 additions & 12 deletions src/mpid/ch4/src/ch4_send.h
Original file line number Diff line number Diff line change
Expand Up @@ -329,15 +329,13 @@ MPL_STATIC_INLINE_PREFIX int MPID_Send(const void *buf,
MPIR_Request ** request)
{
int mpi_errno = MPI_SUCCESS;
int num_vcis;
int hst_vci, rmt_vci;
MPIDI_av_entry_t *av = NULL;
MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPID_SEND);
MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPID_SEND);

/*TODO: remove this branch to use function pointers */
num_vcis = MPIDI_COMM_VCI_COUNT(comm);
if (num_vcis == 1) {
if (!MPIDI_COMM_VCI_HASH(comm).tag_par) {
hst_vci = MPIDI_vci_get(comm, rank, tag);
rmt_vci = hst_vci;
} else {
Expand Down Expand Up @@ -375,15 +373,13 @@ MPL_STATIC_INLINE_PREFIX int MPID_Isend(const void *buf,
MPIR_Request ** request)
{
int mpi_errno = MPI_SUCCESS;
int num_vcis;
int hst_vci, rmt_vci;
MPIDI_av_entry_t *av = NULL;
MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPID_ISEND);
MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPID_ISEND);

/*TODO: remove this branch to use function pointers */
num_vcis = MPIDI_COMM_VCI_COUNT(comm);
if (num_vcis == 1) {
if (!MPIDI_COMM_VCI_HASH(comm).tag_par) {
hst_vci = MPIDI_vci_get(comm, rank, tag);
rmt_vci = hst_vci;
} else {
Expand Down Expand Up @@ -430,15 +426,13 @@ MPL_STATIC_INLINE_PREFIX int MPID_Rsend(const void *buf,
*/

int mpi_errno = MPI_SUCCESS;
int num_vcis;
int hst_vci, rmt_vci;
MPIDI_av_entry_t *av = NULL;
MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPID_RSEND);
MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPID_RSEND);

/*TODO: remove this branch to use function pointers */
num_vcis = MPIDI_COMM_VCI_COUNT(comm);
if (num_vcis == 1) {
if (!MPIDI_COMM_VCI_HASH(comm).tag_par) {
hst_vci = MPIDI_vci_get(comm, rank, tag);
rmt_vci = hst_vci;
} else {
Expand Down Expand Up @@ -482,15 +476,13 @@ MPL_STATIC_INLINE_PREFIX int MPID_Irsend(const void *buf,
*/

int mpi_errno = MPI_SUCCESS;
int num_vcis;
int hst_vci, rmt_vci;
MPIDI_av_entry_t *av = NULL;
MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPID_IRSEND);
MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPID_IRSEND);

/*TODO: remove this branch to use function pointers */
num_vcis = MPIDI_COMM_VCI_COUNT(comm);
if (num_vcis == 1) {
if (!MPIDI_COMM_VCI_HASH(comm).tag_par) {
hst_vci = MPIDI_vci_get(comm, rank, tag);
rmt_vci = hst_vci;
} else {
Expand Down
12 changes: 8 additions & 4 deletions src/mpid/ch4/src/ch4_vci.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,26 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_vci_get(MPIR_Comm * comm_ptr, int rank, int t

MPL_STATIC_INLINE_PREFIX int MPIDI_vci_get_host_sends(MPIR_Comm * comm_ptr, int rank, int tag)
{
return (tag == MPI_ANY_TAG) ? 0 : ((tag >> 10) & 0x1f);
return (tag == MPI_ANY_TAG) ? 0 : (tag >> (MPIDI_COMM_VCI_HASH(comm_ptr).num_tag_bits_for_app
+ MPIDI_COMM_VCI_HASH(comm_ptr).num_tag_bits_for_vci));
}

MPL_STATIC_INLINE_PREFIX int MPIDI_vci_get_remote_sends(MPIR_Comm * comm_ptr, int rank, int tag)
{
return (tag == MPI_ANY_TAG) ? 0 : ((tag >> 5) & 0x1f);
return (tag == MPI_ANY_TAG) ? 0 : ((tag >> MPIDI_COMM_VCI_HASH(comm_ptr).num_tag_bits_for_app)
& ((1 << MPIDI_COMM_VCI_HASH(comm_ptr).num_tag_bits_for_vci) - 1));
}

MPL_STATIC_INLINE_PREFIX int MPIDI_vci_get_host_recvs(MPIR_Comm * comm_ptr, int rank, int tag)
{
return (tag == MPI_ANY_TAG) ? 0 : ((tag >> 5) & 0x1f);
return (tag == MPI_ANY_TAG) ? 0 : ((tag >> MPIDI_COMM_VCI_HASH(comm_ptr).num_tag_bits_for_app)
& ((1 << MPIDI_COMM_VCI_HASH(comm_ptr).num_tag_bits_for_vci) - 1));
}

MPL_STATIC_INLINE_PREFIX int MPIDI_vci_get_remote_recvs(MPIR_Comm * comm_ptr, int rank, int tag)
{
return (tag == MPI_ANY_TAG) ? 0 : ((tag >> 10) & 0x1f);
return (tag == MPI_ANY_TAG) ? 0 : (tag >> (MPIDI_COMM_VCI_HASH(comm_ptr).num_tag_bits_for_app
+ MPIDI_COMM_VCI_HASH(comm_ptr).num_tag_bits_for_vci));
}

MPL_STATIC_INLINE_PREFIX int MPIDI_vci_get_with_tid(MPIR_Comm * comm)
Expand Down

0 comments on commit a4b9455

Please sign in to comment.