diff --git a/prov/tcp/src/xnet.h b/prov/tcp/src/xnet.h index f822fb41f6d..ffece256b36 100644 --- a/prov/tcp/src/xnet.h +++ b/prov/tcp/src/xnet.h @@ -399,6 +399,10 @@ static inline void xnet_signal_progress(struct xnet_progress *progress) #define XNET_NEED_CTS BIT(11) #define XNET_MULTI_RECV FI_MULTI_RECV /* BIT(16) */ +struct xnet_mrecv { + size_t ref_cnt; +}; + struct xnet_xfer_entry { struct slist_entry entry; void *user_buf; @@ -409,9 +413,10 @@ struct xnet_xfer_entry { struct util_cntr *cntr; uint64_t tag_seq_no; uint64_t tag; - struct { - uint64_t ignore; - size_t rts_iov_cnt; + union { + uint64_t ignore; + size_t rts_iov_cnt; + struct xnet_mrecv *mrecv; }; fi_addr_t src_addr; uint64_t cq_flags; diff --git a/prov/tcp/src/xnet_cq.c b/prov/tcp/src/xnet_cq.c index 8e1e4274837..0702e84ba9e 100644 --- a/prov/tcp/src/xnet_cq.c +++ b/prov/tcp/src/xnet_cq.c @@ -151,6 +151,15 @@ void xnet_report_success(struct xnet_xfer_entry *xfer_entry) flags = xfer_entry->cq_flags & ~FI_COMPLETION; if (flags & FI_RECV) { len = xnet_msg_len(&xfer_entry->hdr); + if (xfer_entry->mrecv) { + xfer_entry->mrecv->ref_cnt--; + if (!xfer_entry->mrecv->ref_cnt) { + flags |= FI_MULTI_RECV; + free(xfer_entry->mrecv); + } + } else if (xfer_entry->ctrl_flags & XNET_MULTI_RECV) { + flags |= FI_MULTI_RECV; + } xnet_get_cq_info(xfer_entry, &flags, &data, &tag); } else if (flags & FI_REMOTE_CQ_DATA) { assert(flags & FI_REMOTE_WRITE); @@ -194,6 +203,15 @@ void xnet_report_error(struct xnet_xfer_entry *xfer_entry, int err) err_entry.flags = xfer_entry->cq_flags & ~FI_COMPLETION; if (err_entry.flags & FI_RECV) { + if (xfer_entry->mrecv) { + xfer_entry->mrecv->ref_cnt--; + if (!xfer_entry->mrecv->ref_cnt) { + err_entry.flags |= FI_MULTI_RECV; + free(xfer_entry->mrecv); + } + } else if (xfer_entry->ctrl_flags & XNET_MULTI_RECV) { + err_entry.flags |= FI_MULTI_RECV; + } xnet_get_cq_info(xfer_entry, &err_entry.flags, &err_entry.data, &err_entry.tag); } else if (err_entry.flags & FI_REMOTE_CQ_DATA) { diff --git a/prov/tcp/src/xnet_progress.c b/prov/tcp/src/xnet_progress.c index 342dcf72e28..13c244ba5e9 100644 --- a/prov/tcp/src/xnet_progress.c +++ b/prov/tcp/src/xnet_progress.c @@ -648,11 +648,22 @@ static int xnet_alter_mrecv(struct xnet_ep *ep, struct xnet_xfer_entry *xfer, if (!recv_entry) goto complete; + if (!xfer->mrecv) { + xfer->mrecv = calloc(1, sizeof(struct xnet_mrecv)); + if (!xfer->mrecv) { + xfer->cq_flags |= FI_MULTI_RECV; + return FI_SUCCESS; + } + xfer->mrecv->ref_cnt = 1; + } + recv_entry->ctrl_flags = XNET_MULTI_RECV; recv_entry->cq_flags = FI_MSG | FI_RECV; recv_entry->cntr = xfer->cntr; recv_entry->cq = xfer->cq; recv_entry->context = xfer->context; + recv_entry->mrecv = xfer->mrecv; + recv_entry->mrecv->ref_cnt++; recv_entry->iov_cnt = 1; recv_entry->user_buf = (char *) xfer->iov[0].iov_base + msg_len; @@ -663,7 +674,8 @@ static int xnet_alter_mrecv(struct xnet_ep *ep, struct xnet_xfer_entry *xfer, return 0; complete: - xfer->cq_flags |= FI_MULTI_RECV; + if (!xfer->mrecv) + xfer->cq_flags |= FI_MULTI_RECV; return ret; } diff --git a/prov/tcp/src/xnet_srx.c b/prov/tcp/src/xnet_srx.c index 9592ff69474..210b573c046 100644 --- a/prov/tcp/src/xnet_srx.c +++ b/prov/tcp/src/xnet_srx.c @@ -59,6 +59,7 @@ xnet_alloc_srx_xfer(struct xnet_srx *srx) if (xfer) { xfer->cntr = srx->cntr; xfer->cq = srx->cq; + xfer->mrecv = NULL; } return xfer; @@ -131,6 +132,7 @@ xnet_srx_recv(struct fid_ep *ep_fid, void *buf, size_t len, void *desc, srx = container_of(ep_fid, struct xnet_srx, rx_fid); + ofi_genlock_lock(xnet_srx2_progress(srx)->active_lock); recv_entry = xnet_alloc_srx_xfer(srx); if (!recv_entry) {