From c1c09cb7bd371662742222a817b1b8f685d750d9 Mon Sep 17 00:00:00 2001 From: Zach Dworkin Date: Fri, 6 Sep 2024 12:41:16 -0700 Subject: [PATCH] prov/shm: Add unmap_region function This function is mainly for the niche case where on progress_connreq a peer is added to the map with its region needing to be mapped, and then after mapping it, it's discovered that the newly mapped peer's process died. In this case we need to unmap them and free any resources that were opened for communicating with them. With the creation of this function we can rework smr_map_del to use it as common code. This requires changes to smr_av.c where smr_map_del is called. smr_map_del is now an iterable function. This is to optimize smr_map_cleanup to use ofi_rbmap_foreach to only cleanup peers that exist. Signed-off-by: Zach Dworkin --- prov/shm/src/smr_av.c | 35 ++++++++++++---- prov/shm/src/smr_progress.c | 7 +--- prov/shm/src/smr_util.c | 83 ++++++++++++++++++++++++------------- prov/shm/src/smr_util.h | 5 ++- 4 files changed, 87 insertions(+), 43 deletions(-) diff --git a/prov/shm/src/smr_av.c b/prov/shm/src/smr_av.c index 355d3bcad64..91e06801079 100644 --- a/prov/shm/src/smr_av.c +++ b/prov/shm/src/smr_av.c @@ -67,10 +67,13 @@ static int smr_map_init(const struct fi_provider *prov, struct smr_map *map, static void smr_map_cleanup(struct smr_map *map) { - int64_t i; + int ret; - for (i = 0; i < SMR_MAX_PEERS; i++) - smr_map_del(map, i); + ret = ofi_rbmap_foreach(&map->rbmap, map->rbmap.root, smr_map_del, + NULL); + if (ret) + FI_WARN(&smr_prov, FI_LOG_AV, + "Failed to remove all entries from the map\n"); ofi_rbmap_cleanup(&map->rbmap); } @@ -115,6 +118,7 @@ static int smr_av_insert(struct fid_av *av_fid, const void *addr, size_t count, struct smr_ep *smr_ep; struct fid_peer_srx *srx; struct dlist_entry *av_entry; + struct ofi_rbnode *node; fi_addr_t util_addr; int64_t shm_id = -1; int i, ret; @@ -148,8 +152,17 @@ static int smr_av_insert(struct fid_av *av_fid, const void *addr, size_t count, if (ret) { if (fi_addr) fi_addr[i] = util_addr; - if (shm_id >= 0) - smr_map_del(&smr_av->smr_map, shm_id); + if (shm_id >= 0) { + node = ofi_rbmap_find(&smr_av->smr_map.rbmap, + &smr_av->smr_map.peers[shm_id].peer.name); + assert(node); + ret = smr_map_del(&smr_av->smr_map.rbmap, + node, NULL); + if (ret) + FI_WARN(&smr_prov, FI_LOG_AV, + "Failed to remove shm_id %ld\n", + shm_id); + } continue; } @@ -190,6 +203,7 @@ static int smr_av_remove(struct fid_av *av_fid, fi_addr_t *fi_addr, size_t count struct smr_av *smr_av; struct smr_ep *smr_ep; struct dlist_entry *av_entry; + struct ofi_rbnode *node; int i, ret = 0; int64_t id; @@ -207,11 +221,18 @@ static int smr_av_remove(struct fid_av *av_fid, fi_addr_t *fi_addr, size_t count break; } - smr_map_del(&smr_av->smr_map, id); + node = ofi_rbmap_find(&smr_av->smr_map.rbmap, + &smr_av->smr_map.peers[id].peer.name); + assert(node); + ret = smr_map_del(&smr_av->smr_map.rbmap, node, NULL); + if (ret) { + FI_WARN(&smr_prov, FI_LOG_AV, + "Failed to remove shm_id %ld\n", id); + break; + } dlist_foreach(&util_av->ep_list, av_entry) { util_ep = container_of(av_entry, struct util_ep, av_entry); smr_ep = container_of(util_ep, struct smr_ep, util_ep); - smr_unmap_from_endpoint(smr_ep->region, id); if (smr_av->smr_map.num_peers > 0) smr_ep->region->max_sar_buf_per_peer = SMR_MAX_PEERS / diff --git a/prov/shm/src/smr_progress.c b/prov/shm/src/smr_progress.c index 141826b9bba..dfd219afcb8 100644 --- a/prov/shm/src/smr_progress.c +++ b/prov/shm/src/smr_progress.c @@ -891,13 +891,8 @@ static void smr_progress_connreq(struct smr_ep *ep, struct smr_cmd *cmd) if (peer_smr->pid != (int) cmd->msg.hdr.data) { /* TODO track and update/complete in error any transfers * to or from old mapping - * - * TODO create smr_unmap_region - * this needs to close peer_smr->map->peers[idx].pid_fd - * This case will also return an unmapped region because the idx - * is valid but the region was unmapped */ - munmap(peer_smr, peer_smr->total_size); + smr_unmap_region(&smr_prov, ep->region->map, idx); smr_map_to_region(&smr_prov, ep->region->map, idx); peer_smr = smr_peer_region(ep->region, idx); } diff --git a/prov/shm/src/smr_util.c b/prov/shm/src/smr_util.c index 2924ddaa6f2..047752f23c4 100644 --- a/prov/shm/src/smr_util.c +++ b/prov/shm/src/smr_util.c @@ -479,24 +479,67 @@ void smr_map_to_endpoint(struct smr_region *region, int64_t id) return; } +void smr_unmap_region(const struct fi_provider *prov, struct smr_map *map, + int64_t peer_id) +{ + struct smr_region *peer_region; + struct smr_peer *peer; + struct util_ep *util_ep; + struct smr_ep *smr_ep; + struct smr_av *av; + int ret = 0; + + ofi_spin_lock(&map->lock); + peer_region = map->peers[peer_id].region; + if (!peer_region) + goto unlock; + + peer = &map->peers[peer_id]; + av = container_of(map, struct smr_av, smr_map); + dlist_foreach_container(&av->util_av.ep_list, struct util_ep, util_ep, + av_entry) { + smr_ep = container_of(util_ep, struct smr_ep, util_ep); + smr_unmap_from_endpoint(smr_ep->region, peer_id); + } + + if (map->flags & SMR_FLAG_HMEM_ENABLED) { + ret = ofi_hmem_host_unregister(peer_region); + if (ret) + FI_WARN(prov, FI_LOG_EP_CTRL, + "unable to unregister shm with iface\n"); + + if (peer->pid_fd != -1) { + close(peer->pid_fd); + peer->pid_fd = -1; + } + } + + munmap(peer_region, peer_region->total_size); + peer_region = NULL; + peer->region = NULL; + +unlock: + ofi_spin_unlock(&map->lock); +} + void smr_unmap_from_endpoint(struct smr_region *region, int64_t id) { struct smr_region *peer_smr; struct smr_peer_data *local_peers, *peer_peers; int64_t peer_id; - local_peers = smr_peer_data(region); if (region->map->peers[id].peer.id < 0) return; peer_smr = smr_peer_region(region, id); - peer_id = smr_peer_data(region)[id].addr.id; - + assert(peer_smr); peer_peers = smr_peer_data(peer_smr); + peer_id = smr_peer_data(region)[id].addr.id; peer_peers[peer_id].addr.id = -1; peer_peers[peer_id].name_sent = 0; + local_peers = smr_peer_data(region); ofi_xpmem_release(&local_peers[peer_id].xpmem); } @@ -544,40 +587,22 @@ int smr_map_add(const struct fi_provider *prov, struct smr_map *map, return FI_SUCCESS; } -void smr_map_del(struct smr_map *map, int64_t id) +int smr_map_del(struct ofi_rbmap *rbmap, struct ofi_rbnode *node, + void *context) { - struct dlist_entry *entry; + struct smr_map *map = container_of(rbmap, struct smr_map, rbmap); + int64_t id = (uintptr_t) node->data; assert(id >= 0 && id < SMR_MAX_PEERS); - - pthread_mutex_lock(&ep_list_lock); - entry = dlist_find_first_match(&ep_name_list, smr_match_name, - smr_no_prefix(map->peers[id].peer.name)); - pthread_mutex_unlock(&ep_list_lock); - + smr_unmap_region(&smr_prov, map, id); ofi_spin_lock(&map->lock); - (void) ofi_rbmap_find_delete(&map->rbmap, - (void *) map->peers[id].peer.name); - + ofi_rbmap_delete(rbmap, node); map->peers[id].fiaddr = FI_ADDR_NOTAVAIL; map->peers[id].peer.id = -1; map->num_peers--; - - if (!map->peers[id].region) - goto unlock; - - if (!entry) { - if (map->flags & SMR_FLAG_HMEM_ENABLED) { - if (map->peers[id].pid_fd != -1) - close(map->peers[id].pid_fd); - - (void) ofi_hmem_host_unregister(map->peers[id].region); - } - munmap(map->peers[id].region, map->peers[id].region->total_size); - map->peers[id].region = NULL; - } -unlock: ofi_spin_unlock(&map->lock); + + return FI_SUCCESS; } struct smr_region *smr_map_get(struct smr_map *map, int64_t id) diff --git a/prov/shm/src/smr_util.h b/prov/shm/src/smr_util.h index c5bf8124873..25d25531672 100644 --- a/prov/shm/src/smr_util.h +++ b/prov/shm/src/smr_util.h @@ -356,11 +356,14 @@ void smr_cleanup(void); int smr_map_to_region(const struct fi_provider *prov, struct smr_map *map, int64_t id); void smr_map_to_endpoint(struct smr_region *region, int64_t id); +void smr_unmap_region(const struct fi_provider *prov, struct smr_map *map, + int64_t id); void smr_unmap_from_endpoint(struct smr_region *region, int64_t id); void smr_exchange_all_peers(struct smr_region *region); int smr_map_add(const struct fi_provider *prov, struct smr_map *map, const char *name, int64_t *id); -void smr_map_del(struct smr_map *map, int64_t id); +int smr_map_del(struct ofi_rbmap *rbmap, struct ofi_rbnode *node, + void *context); struct smr_region *smr_map_get(struct smr_map *map, int64_t id);