diff --git a/prov/shm/src/smr_av.c b/prov/shm/src/smr_av.c
index 355d3bcad64..dc69de901a7 100644
--- a/prov/shm/src/smr_av.c
+++ b/prov/shm/src/smr_av.c
@@ -67,10 +67,16 @@ static int smr_map_init(const struct fi_provider *prov, struct smr_map *map,
 
 static void smr_map_cleanup(struct smr_map *map)
 {
-	int64_t i;
+	int ret;
 
-	for (i = 0; i < SMR_MAX_PEERS; i++)
-		smr_map_del(map, i);
+	/* smr_map_unmap() expects map->lock to be held by its caller. */
+	ofi_spin_lock(&map->lock);
+	ret = ofi_rbmap_foreach(&map->rbmap, map->rbmap.root, smr_map_unmap,
+				NULL);
+	ofi_spin_unlock(&map->lock);
+	if (ret)
+		FI_WARN(&smr_prov, FI_LOG_AV,
+			"Failed to remove all entries from the map\n");
 
 	ofi_rbmap_cleanup(&map->rbmap);
 }
@@ -115,6 +121,7 @@ static int smr_av_insert(struct fid_av *av_fid, const void *addr, size_t count,
 	struct smr_ep *smr_ep;
 	struct fid_peer_srx *srx;
 	struct dlist_entry *av_entry;
+	struct ofi_rbnode *node;
 	fi_addr_t util_addr;
 	int64_t shm_id = -1;
 	int i, ret;
@@ -148,8 +155,17 @@ static int smr_av_insert(struct fid_av *av_fid, const void *addr, size_t count,
 		if (ret) {
 			if (fi_addr)
 				fi_addr[i] = util_addr;
-			if (shm_id >= 0)
-				smr_map_del(&smr_av->smr_map, shm_id);
+			if (shm_id >= 0) {
+				node = ofi_rbmap_find(&smr_av->smr_map.rbmap,
+					&smr_av->smr_map.peers[shm_id].peer.name);
+				assert(node);
+				ret = smr_map_del(&smr_av->smr_map.rbmap,
+						  node, NULL);
+				if (ret)
+					FI_WARN(&smr_prov, FI_LOG_AV,
+						"Failed to remove shm_id %lld\n",
+						(long long) shm_id);
+			}
 			continue;
 		}
 
@@ -190,6 +206,7 @@ static int smr_av_remove(struct fid_av *av_fid, fi_addr_t *fi_addr, size_t count
 	struct smr_av *smr_av;
 	struct smr_ep *smr_ep;
 	struct dlist_entry *av_entry;
+	struct ofi_rbnode *node;
 	int i, ret = 0;
 	int64_t id;
 
@@ -207,11 +224,19 @@ static int smr_av_remove(struct fid_av *av_fid, fi_addr_t *fi_addr, size_t count
 			break;
 		}
 
-		smr_map_del(&smr_av->smr_map, id);
+		node = ofi_rbmap_find(&smr_av->smr_map.rbmap,
+				      &smr_av->smr_map.peers[id].peer.name);
+		assert(node);
+		ret = smr_map_del(&smr_av->smr_map.rbmap, node, NULL);
+		if (ret) {
+			FI_WARN(&smr_prov, FI_LOG_AV,
+				"Failed to remove shm_id %lld\n",
+				(long long) id);
+			break;
+		}
 		dlist_foreach(&util_av->ep_list, av_entry) {
 			util_ep = container_of(av_entry, struct util_ep, av_entry);
 			smr_ep = container_of(util_ep, struct smr_ep, util_ep);
-			smr_unmap_from_endpoint(smr_ep->region, id);
 			if (smr_av->smr_map.num_peers > 0)
 				smr_ep->region->max_sar_buf_per_peer =
 					SMR_MAX_PEERS /
diff --git a/prov/shm/src/smr_progress.c b/prov/shm/src/smr_progress.c
index 141826b9bba..dfd219afcb8 100644
--- a/prov/shm/src/smr_progress.c
+++ b/prov/shm/src/smr_progress.c
@@ -891,13 +891,10 @@ static void smr_progress_connreq(struct smr_ep *ep, struct smr_cmd *cmd)
 	if (peer_smr->pid != (int) cmd->msg.hdr.data) {
 		/* TODO track and update/complete in error any transfers
 		 * to or from old mapping
-		 *
-		 * TODO create smr_unmap_region
-		 * this needs to close peer_smr->map->peers[idx].pid_fd
-		 * This case will also return an unmapped region because the idx
-		 * is valid but the region was unmapped
 		 */
-		munmap(peer_smr, peer_smr->total_size);
+		ofi_spin_lock(&ep->region->map->lock);
+		smr_unmap_region(&smr_prov, ep->region->map, idx);
+		ofi_spin_unlock(&ep->region->map->lock);
 		smr_map_to_region(&smr_prov, ep->region->map, idx);
 		peer_smr = smr_peer_region(ep->region, idx);
 	}
diff --git a/prov/shm/src/smr_util.c b/prov/shm/src/smr_util.c
index 2924ddaa6f2..b87a32be495 100644
--- a/prov/shm/src/smr_util.c
+++ b/prov/shm/src/smr_util.c
@@ -479,24 +479,74 @@ void smr_map_to_endpoint(struct smr_region *region, int64_t id)
 	return;
 }
 
+/*
+ * Undo the local mapping of a peer's shared memory region: detach the peer
+ * from every endpoint bound to the owning AV, drop the HMEM host registration
+ * and pid fd when HMEM is enabled, then munmap the region.  Does nothing if
+ * the peer has no mapped region.
+ *
+ * The caller must hold map->lock; it is deliberately not taken here so that
+ * smr_map_del() can unmap and erase the rbmap node in one critical section
+ * without trying to re-acquire a spinlock it already owns.
+ *
+ * NOTE(review): the smr_map_del() this replaces skipped the munmap when the
+ * peer name was found in ep_name_list; that distinction is not made here.
+ * Confirm such peers never reach this path.
+ */
+void smr_unmap_region(const struct fi_provider *prov, struct smr_map *map,
+		      int64_t peer_id)
+{
+	struct smr_peer *peer = &map->peers[peer_id];
+	struct smr_region *peer_region = peer->region;
+	struct util_ep *util_ep;
+	struct smr_ep *smr_ep;
+	struct smr_av *av;
+	int ret;
+
+	if (!peer_region)
+		return;
+
+	av = container_of(map, struct smr_av, smr_map);
+	dlist_foreach_container(&av->util_av.ep_list, struct util_ep, util_ep,
+				av_entry) {
+		smr_ep = container_of(util_ep, struct smr_ep, util_ep);
+		smr_unmap_from_endpoint(smr_ep->region, peer_id);
+	}
+
+	if (map->flags & SMR_FLAG_HMEM_ENABLED) {
+		ret = ofi_hmem_host_unregister(peer_region);
+		if (ret)
+			FI_WARN(prov, FI_LOG_EP_CTRL,
+				"unable to unregister shm with iface\n");
+
+		if (peer->pid_fd != -1) {
+			close(peer->pid_fd);
+			peer->pid_fd = -1;
+		}
+	}
+
+	munmap(peer_region, peer_region->total_size);
+	peer->region = NULL;
+}
+
 void smr_unmap_from_endpoint(struct smr_region *region, int64_t id)
 {
 	struct smr_region *peer_smr;
 	struct smr_peer_data *local_peers, *peer_peers;
 	int64_t peer_id;
 
-	local_peers = smr_peer_data(region);
 	if (region->map->peers[id].peer.id < 0)
 		return;
 
 	peer_smr = smr_peer_region(region, id);
-	peer_id = smr_peer_data(region)[id].addr.id;
-
+	assert(peer_smr);
 	peer_peers = smr_peer_data(peer_smr);
+	peer_id = smr_peer_data(region)[id].addr.id;
 
 	peer_peers[peer_id].addr.id = -1;
 	peer_peers[peer_id].name_sent = 0;
 
+	local_peers = smr_peer_data(region);
 	ofi_xpmem_release(&local_peers[peer_id].xpmem);
 }
 
@@ -544,40 +594,47 @@ int smr_map_add(const struct fi_provider *prov, struct smr_map *map,
 	return FI_SUCCESS;
 }
 
-void smr_map_del(struct smr_map *map, int64_t id)
+/*
+ * Release everything held for the peer stored in an rbmap node: unmap its
+ * region and mark its slot in map->peers as unused.  The node itself stays
+ * in the tree.  The signature lets this be passed to ofi_rbmap_foreach().
+ *
+ * The caller must hold map->lock.
+ */
+int smr_map_unmap(struct ofi_rbmap *rbmap, struct ofi_rbnode *node,
+		  void *context)
 {
-	struct dlist_entry *entry;
+	struct smr_map *map = container_of(rbmap, struct smr_map, rbmap);
+	int64_t id = (uintptr_t) node->data;
 
 	assert(id >= 0 && id < SMR_MAX_PEERS);
-
-	pthread_mutex_lock(&ep_list_lock);
-	entry = dlist_find_first_match(&ep_name_list, smr_match_name,
-				       smr_no_prefix(map->peers[id].peer.name));
-	pthread_mutex_unlock(&ep_list_lock);
-
-	ofi_spin_lock(&map->lock);
-	(void) ofi_rbmap_find_delete(&map->rbmap,
-				     (void *) map->peers[id].peer.name);
-
+	smr_unmap_region(&smr_prov, map, id);
 	map->peers[id].fiaddr = FI_ADDR_NOTAVAIL;
 	map->peers[id].peer.id = -1;
 	map->num_peers--;
 
-	if (!map->peers[id].region)
-		goto unlock;
+	return FI_SUCCESS;
+}
 
-	if (!entry) {
-		if (map->flags & SMR_FLAG_HMEM_ENABLED) {
-			if (map->peers[id].pid_fd != -1)
-				close(map->peers[id].pid_fd);
+/*
+ * Remove a single peer from the map: unmap it and erase its rbmap node, both
+ * under map->lock.  Takes the owning rbmap rather than the smr_map so that
+ * the definition matches the prototype in smr_util.h and its callers;
+ * context is unused.
+ *
+ * Always returns FI_SUCCESS.
+ */
+int smr_map_del(struct ofi_rbmap *rbmap, struct ofi_rbnode *node,
+		void *context)
+{
+	struct smr_map *map = container_of(rbmap, struct smr_map, rbmap);
 
-			(void) ofi_hmem_host_unregister(map->peers[id].region);
-		}
-		munmap(map->peers[id].region, map->peers[id].region->total_size);
-		map->peers[id].region = NULL;
-	}
-unlock:
+	ofi_spin_lock(&map->lock);
+	smr_map_unmap(rbmap, node, NULL);
+	ofi_rbmap_delete(rbmap, node);
 	ofi_spin_unlock(&map->lock);
+
+	return FI_SUCCESS;
 }
 
 struct smr_region *smr_map_get(struct smr_map *map, int64_t id)
diff --git a/prov/shm/src/smr_util.h b/prov/shm/src/smr_util.h
index c5bf8124873..b63dc15ee23 100644
--- a/prov/shm/src/smr_util.h
+++ b/prov/shm/src/smr_util.h
@@ -356,11 +356,16 @@ void smr_cleanup(void);
 int smr_map_to_region(const struct fi_provider *prov, struct smr_map *map,
 		      int64_t id);
 void smr_map_to_endpoint(struct smr_region *region, int64_t id);
+void smr_unmap_region(const struct fi_provider *prov, struct smr_map *map,
+		      int64_t id);
 void smr_unmap_from_endpoint(struct smr_region *region, int64_t id);
 void smr_exchange_all_peers(struct smr_region *region);
 int smr_map_add(const struct fi_provider *prov, struct smr_map *map,
 		const char *name, int64_t *id);
-void smr_map_del(struct smr_map *map, int64_t id);
+int smr_map_unmap(struct ofi_rbmap *rbmap, struct ofi_rbnode *node,
+		  void *context);
+int smr_map_del(struct ofi_rbmap *rbmap, struct ofi_rbnode *node,
+		void *context);
 struct smr_region *smr_map_get(struct smr_map *map, int64_t id);