From a4b945501f18d04f39d7ffd1860c87298ba7e691 Mon Sep 17 00:00:00 2001 From: Rohit Zambre Date: Fri, 4 Sep 2020 01:17:28 -0500 Subject: [PATCH] hints infrastructure to negotiate bits for tag-based parallelism --- src/include/mpir_comm.h | 3 +++ src/mpi/comm/commutil.c | 9 +++++++++ src/mpid/ch4/include/mpidpre.h | 4 ++++ src/mpid/ch4/netmod/ucx/ucx_recv.h | 4 +++- src/mpid/ch4/src/ch4_comm.c | 29 ++++++++++++++++++++++++++++- src/mpid/ch4/src/ch4_recv.h | 8 ++------ src/mpid/ch4/src/ch4_send.h | 16 ++++------------ src/mpid/ch4/src/ch4_vci.h | 12 ++++++++---- 8 files changed, 61 insertions(+), 24 deletions(-) diff --git a/src/include/mpir_comm.h b/src/include/mpir_comm.h index 63b7db64415..cb4e554a053 100644 --- a/src/include/mpir_comm.h +++ b/src/include/mpir_comm.h @@ -88,6 +88,9 @@ enum MPIR_COMM_HINT_PREDEFINED_t { MPIR_COMM_HINT_ALLOW_OVERTAKING, MPIR_COMM_HINT_NEW_VCI, MPIR_COMM_HINT_NUM_VCIS, + MPIR_COMM_HINT_TAG_PAR, + MPIR_COMM_HINT_NUM_TAG_BITS_VCI, + MPIR_COMM_HINT_NUM_TAG_BITS_APP, /* device specific hints. * Potentially, we can use macros and configure to hide them */ MPIR_COMM_HINT_EAGER_THRESH, /* ch3 */ diff --git a/src/mpi/comm/commutil.c b/src/mpi/comm/commutil.c index d4a87060942..0158a4c2165 100644 --- a/src/mpi/comm/commutil.c +++ b/src/mpi/comm/commutil.c @@ -164,6 +164,12 @@ void MPIR_Comm_hint_init(void) NULL, MPIR_COMM_HINT_TYPE_BOOL, 0); MPIR_Comm_register_hint(MPIR_COMM_HINT_NUM_VCIS, "mpi_num_vcis", NULL, MPIR_COMM_HINT_TYPE_INT, 0); + MPIR_Comm_register_hint(MPIR_COMM_HINT_TAG_PAR, "mpi_assert_tag_based_parallelism", + NULL, MPIR_COMM_HINT_TYPE_BOOL, 0); + MPIR_Comm_register_hint(MPIR_COMM_HINT_NUM_TAG_BITS_VCI, "mpi_num_tag_bits_for_vci", + NULL, MPIR_COMM_HINT_TYPE_INT, 0); + MPIR_Comm_register_hint(MPIR_COMM_HINT_NUM_TAG_BITS_APP, "mpi_num_tag_bits_for_app", + NULL, MPIR_COMM_HINT_TYPE_INT, 0); } /* FIXME : @@ -764,6 +770,9 @@ int MPII_Comm_copy(MPIR_Comm * comm_ptr, int size, MPIR_Info * info, MPIR_Comm * /* Since hints are never propogated starting MPI-3.2, set defaults for hints */ newcomm_ptr->hints[MPIR_COMM_HINT_NEW_VCI] = FALSE; newcomm_ptr->hints[MPIR_COMM_HINT_NUM_VCIS] = 1; + newcomm_ptr->hints[MPIR_COMM_HINT_TAG_PAR] = FALSE; + newcomm_ptr->hints[MPIR_COMM_HINT_NUM_TAG_BITS_VCI] = 0; + newcomm_ptr->hints[MPIR_COMM_HINT_NUM_TAG_BITS_APP] = 0; if (info) { MPII_Comm_set_hints(newcomm_ptr, info); } diff --git a/src/mpid/ch4/include/mpidpre.h b/src/mpid/ch4/include/mpidpre.h index d3bf86f4246..d4d7781ad8b 100644 --- a/src/mpid/ch4/include/mpidpre.h +++ b/src/mpid/ch4/include/mpidpre.h @@ -505,6 +505,9 @@ typedef struct MPIDI_vci_hash { } multi; } u; int count; + int tag_par; + int num_tag_bits_for_vci; + int num_tag_bits_for_app; } MPIDI_vci_hash_t; typedef struct MPIDI_Devcomm_t { @@ -532,6 +535,7 @@ typedef struct MPIDI_Devcomm_t { #define MPIDI_COMM_VCI(comm) ((comm)->dev.ch4).vci_hash.u.single.vci #define MPIDI_COMM_MULTI_VCI(comm) ((comm)->dev.ch4).vci_hash.u.multi.vci #define MPIDI_COMM_VCI_COUNT(comm) ((comm)->dev.ch4).vci_hash.count +#define MPIDI_COMM_VCI_HASH(comm) ((comm)->dev.ch4).vci_hash typedef struct { union { diff --git a/src/mpid/ch4/netmod/ucx/ucx_recv.h b/src/mpid/ch4/netmod/ucx/ucx_recv.h index 2da4fd6e558..6e6ef4d5120 100644 --- a/src/mpid/ch4/netmod/ucx/ucx_recv.h +++ b/src/mpid/ch4/netmod/ucx/ucx_recv.h @@ -36,7 +36,9 @@ MPL_STATIC_INLINE_PREFIX void MPIDI_UCX_recv_cmpl_cb(void *request, ucs_status_t if (MPIDI_COMM_VCI_COUNT(comm_ptr) == 1) { vci = MPIDI_COMM_VCI(comm_ptr); } else { - vci = (MPIDI_UCX_get_tag(info->sender_tag) == MPI_ANY_TAG) ? 0 : ((MPIDI_UCX_get_tag(info->sender_tag) >> 5) & 0x1f); + int tag = MPIDI_UCX_get_tag(info->sender_tag); + vci = (tag == MPI_ANY_TAG) ? 0 : ((tag >> MPIDI_COMM_VCI_HASH(comm_ptr).num_tag_bits_for_app) + & ((1 << MPIDI_COMM_VCI_HASH(comm_ptr).num_tag_bits_for_vci) - 1)); } /* Does not require safety since we are at ucp_tag_recv_nb */ diff --git a/src/mpid/ch4/src/ch4_comm.c b/src/mpid/ch4/src/ch4_comm.c index 304e45b03a3..829f9ecd4d4 100644 --- a/src/mpid/ch4/src/ch4_comm.c +++ b/src/mpid/ch4/src/ch4_comm.c @@ -188,8 +188,13 @@ int MPID_Comm_create_hook(MPIR_Comm * comm) MPIDIU_avt_add_ref(MPIDI_COMM(comm, local_map).avtid); } + /* Initialize some fields */ + MPIDI_COMM_VCI_HASH(comm).tag_par = 0; + MPIDI_COMM_VCI_HASH(comm).num_tag_bits_for_vci = 0; + MPIDI_COMM_VCI_HASH(comm).num_tag_bits_for_app = 0; + if (MPIR_CONTEXT_READ_FIELD(SUBCOMM, comm->context_id)) { - /* If this is a subcommunicator, then use the VCI(s) of the parent*/ + /* If this is a subcommunicator, then use the VCI policies of the parent*/ int parent_comm_num_vcis; MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_POOL(lock)); @@ -205,6 +210,11 @@ int MPID_Comm_create_hook(MPIR_Comm * comm) MPIDI_COMM_MULTI_VCI(comm)[vci_i] = MPIDI_COMM_MULTI_VCI(comm->parent_comm)[vci_i]; MPIDI_VCI(MPIDI_COMM_MULTI_VCI(comm->parent_comm)[vci_i]).ref_count++; } + + MPIDI_COMM_VCI_HASH(comm).tag_par = MPIDI_COMM_VCI_HASH(comm->parent_comm).tag_par; + MPIDI_COMM_VCI_HASH(comm).num_tag_bits_for_vci = MPIDI_COMM_VCI_HASH(comm->parent_comm).num_tag_bits_for_vci; + MPIDI_COMM_VCI_HASH(comm).num_tag_bits_for_app = MPIDI_COMM_VCI_HASH(comm->parent_comm).num_tag_bits_for_app; + /* FIXME: this is a hack */ MPIDI_COMM_VCI(comm) = MPIDI_VCI_ROOT; } @@ -267,6 +277,18 @@ int MPID_Comm_create_hook(MPIR_Comm * comm) } MPIDI_COMM_MULTI_VCI(comm)[vci_i] = vci; } + + if (comm->hints[MPIR_COMM_HINT_TAG_PAR]) { + printf("Tag parallelism hinted for this comm\n"); + MPIDI_COMM_VCI_HASH(comm).tag_par = 1; + + MPIDI_COMM_VCI_HASH(comm).num_tag_bits_for_vci = comm->hints[MPIR_COMM_HINT_NUM_TAG_BITS_VCI]; + MPIDI_COMM_VCI_HASH(comm).num_tag_bits_for_app = comm->hints[MPIR_COMM_HINT_NUM_TAG_BITS_APP]; + + MPIR_Assert(((1 << (2*MPIDI_COMM_VCI_HASH(comm).num_tag_bits_for_vci + MPIDI_COMM_VCI_HASH(comm).num_tag_bits_for_app)) - 1) + <= MPIR_Process.attrs.tag_ub); + } + /* FIXME: this is a hack */ MPIDI_COMM_VCI(comm) = MPIDI_VCI_ROOT; } @@ -289,6 +311,11 @@ int MPID_Comm_create_hook(MPIR_Comm * comm) MPIDI_COMM_MULTI_VCI(comm)[vci_i] = MPIDI_COMM_MULTI_VCI(comm->orig_comm)[vci_i]; MPIDI_VCI(MPIDI_COMM_MULTI_VCI(comm->orig_comm)[vci_i]).ref_count++; } + + MPIDI_COMM_VCI_HASH(comm).tag_par = MPIDI_COMM_VCI_HASH(comm->orig_comm).tag_par; + MPIDI_COMM_VCI_HASH(comm).num_tag_bits_for_vci = MPIDI_COMM_VCI_HASH(comm->orig_comm).num_tag_bits_for_vci; + MPIDI_COMM_VCI_HASH(comm).num_tag_bits_for_app = MPIDI_COMM_VCI_HASH(comm->orig_comm).num_tag_bits_for_app; + /* FIXME: this is a hack */ MPIDI_COMM_VCI(comm) = MPIDI_VCI_ROOT; } diff --git a/src/mpid/ch4/src/ch4_recv.h b/src/mpid/ch4/src/ch4_recv.h index d2939051783..bbf556dcf2c 100644 --- a/src/mpid/ch4/src/ch4_recv.h +++ b/src/mpid/ch4/src/ch4_recv.h @@ -359,15 +359,13 @@ MPL_STATIC_INLINE_PREFIX int MPID_Recv(void *buf, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; - int num_vcis; int hst_vci, rmt_vci; MPIDI_av_entry_t *av = NULL; MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPID_RECV); MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPID_RECV); /*TODO: remove this branch to use function pointers */ - num_vcis = MPIDI_COMM_VCI_COUNT(comm); - if (num_vcis == 1) { + if (!MPIDI_COMM_VCI_HASH(comm).tag_par) { hst_vci = MPIDI_vci_get(comm, rank, tag); rmt_vci = hst_vci; } else { @@ -539,15 +537,13 @@ MPL_STATIC_INLINE_PREFIX int MPID_Irecv(void *buf, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; - int num_vcis; int hst_vci, rmt_vci; MPIDI_av_entry_t *av = NULL; MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPID_IRECV); MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPID_IRECV); /*TODO: remove this branch to use function pointers */ - num_vcis = MPIDI_COMM_VCI_COUNT(comm); - if (num_vcis == 1) { + if (!MPIDI_COMM_VCI_HASH(comm).tag_par) { hst_vci = MPIDI_vci_get(comm, rank, tag); rmt_vci = hst_vci; } else { diff --git a/src/mpid/ch4/src/ch4_send.h b/src/mpid/ch4/src/ch4_send.h index e14aeb1847f..759d1ba4b8a 100644 --- a/src/mpid/ch4/src/ch4_send.h +++ b/src/mpid/ch4/src/ch4_send.h @@ -329,15 +329,13 @@ MPL_STATIC_INLINE_PREFIX int MPID_Send(const void *buf, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; - int num_vcis; int hst_vci, rmt_vci; MPIDI_av_entry_t *av = NULL; MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPID_SEND); MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPID_SEND); /*TODO: remove this branch to use function pointers */ - num_vcis = MPIDI_COMM_VCI_COUNT(comm); - if (num_vcis == 1) { + if (!MPIDI_COMM_VCI_HASH(comm).tag_par) { hst_vci = MPIDI_vci_get(comm, rank, tag); rmt_vci = hst_vci; } else { @@ -375,15 +373,13 @@ MPL_STATIC_INLINE_PREFIX int MPID_Isend(const void *buf, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; - int num_vcis; int hst_vci, rmt_vci; MPIDI_av_entry_t *av = NULL; MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPID_ISEND); MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPID_ISEND); /*TODO: remove this branch to use function pointers */ - num_vcis = MPIDI_COMM_VCI_COUNT(comm); - if (num_vcis == 1) { + if (!MPIDI_COMM_VCI_HASH(comm).tag_par) { hst_vci = MPIDI_vci_get(comm, rank, tag); rmt_vci = hst_vci; } else { @@ -430,15 +426,13 @@ MPL_STATIC_INLINE_PREFIX int MPID_Rsend(const void *buf, */ int mpi_errno = MPI_SUCCESS; - int num_vcis; int hst_vci, rmt_vci; MPIDI_av_entry_t *av = NULL; MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPID_RSEND); MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPID_RSEND); /*TODO: remove this branch to use function pointers */ - num_vcis = MPIDI_COMM_VCI_COUNT(comm); - if (num_vcis == 1) { + if (!MPIDI_COMM_VCI_HASH(comm).tag_par) { hst_vci = MPIDI_vci_get(comm, rank, tag); rmt_vci = hst_vci; } else { @@ -482,15 +476,13 @@ MPL_STATIC_INLINE_PREFIX int MPID_Irsend(const void *buf, */ int mpi_errno = MPI_SUCCESS; - int num_vcis; int hst_vci, rmt_vci; MPIDI_av_entry_t *av = NULL; MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPID_IRSEND); MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPID_IRSEND); /*TODO: remove this branch to use function pointers */ - num_vcis = MPIDI_COMM_VCI_COUNT(comm); - if (num_vcis == 1) { + if (!MPIDI_COMM_VCI_HASH(comm).tag_par) { hst_vci = MPIDI_vci_get(comm, rank, tag); rmt_vci = hst_vci; } else { diff --git a/src/mpid/ch4/src/ch4_vci.h b/src/mpid/ch4/src/ch4_vci.h index 3682c1d15fd..d0867334e88 100644 --- a/src/mpid/ch4/src/ch4_vci.h +++ b/src/mpid/ch4/src/ch4_vci.h @@ -29,22 +29,26 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_vci_get(MPIR_Comm * comm_ptr, int rank, int t MPL_STATIC_INLINE_PREFIX int MPIDI_vci_get_host_sends(MPIR_Comm * comm_ptr, int rank, int tag) { - return (tag == MPI_ANY_TAG) ? 0 : ((tag >> 10) & 0x1f); + return (tag == MPI_ANY_TAG) ? 0 : (tag >> (MPIDI_COMM_VCI_HASH(comm_ptr).num_tag_bits_for_app + + MPIDI_COMM_VCI_HASH(comm_ptr).num_tag_bits_for_vci)); } MPL_STATIC_INLINE_PREFIX int MPIDI_vci_get_remote_sends(MPIR_Comm * comm_ptr, int rank, int tag) { - return (tag == MPI_ANY_TAG) ? 0 : ((tag >> 5) & 0x1f); + return (tag == MPI_ANY_TAG) ? 0 : ((tag >> MPIDI_COMM_VCI_HASH(comm_ptr).num_tag_bits_for_app) + & ((1 << MPIDI_COMM_VCI_HASH(comm_ptr).num_tag_bits_for_vci) - 1)); } MPL_STATIC_INLINE_PREFIX int MPIDI_vci_get_host_recvs(MPIR_Comm * comm_ptr, int rank, int tag) { - return (tag == MPI_ANY_TAG) ? 0 : ((tag >> 5) & 0x1f); + return (tag == MPI_ANY_TAG) ? 0 : ((tag >> MPIDI_COMM_VCI_HASH(comm_ptr).num_tag_bits_for_app) + & ((1 << MPIDI_COMM_VCI_HASH(comm_ptr).num_tag_bits_for_vci) - 1)); } MPL_STATIC_INLINE_PREFIX int MPIDI_vci_get_remote_recvs(MPIR_Comm * comm_ptr, int rank, int tag) { - return (tag == MPI_ANY_TAG) ? 0 : ((tag >> 10) & 0x1f); + return (tag == MPI_ANY_TAG) ? 0 : (tag >> (MPIDI_COMM_VCI_HASH(comm_ptr).num_tag_bits_for_app + + MPIDI_COMM_VCI_HASH(comm_ptr).num_tag_bits_for_vci)); } MPL_STATIC_INLINE_PREFIX int MPIDI_vci_get_with_tid(MPIR_Comm * comm)