tl/cuda: unmap peer memory handles once alltoallv (copy-engine) completes

Adds ucc_tl_cuda_alltoallv_unmap(), which calls ucc_tl_cuda_unmap_memhandle()
with force=0 for every directly reachable peer's source mapping and for every
proxy destination mapping, and calls it from the alltoallv_ce progress function
once the shared-memory barrier test returns UCC_OK.
ucc_tl_cuda_unmap_memhandle() gains an "int force" argument: with the cache
enabled, the region is removed from the page table and the IPC handle closed
only when the refcount reaches 0 and force == 1. Team cleanup passes force=1.

Review notes:
 * This patch was received collapsed onto three lines. Line breaks were
   rebuilt from the hunk line counts; indentation of context lines is a best
   guess - TODO confirm against the tree, or apply ignoring whitespace.
 * "&region->super" was mis-encoded in the received text and is restored.
 * The ucc_debug() trace now follows ucc_assert(pgt_region != NULL), since
   UCS_PGT_REGION_ARG() presumably dereferences the region - TODO confirm.
 * NOTE(review): after ucs_pgtable_remove() nothing visible here releases the
   region object itself - confirm who owns and frees it.
 * NOTE(review): "status" in the progress function is assumed to be declared
   in the part of the function not shown by the hunk - confirm.

diff --git a/src/components/tl/cuda/alltoallv/alltoallv_ce.c b/src/components/tl/cuda/alltoallv/alltoallv_ce.c
index 39d9f18bfc..6ee71f6a0b 100644
--- a/src/components/tl/cuda/alltoallv/alltoallv_ce.c
+++ b/src/components/tl/cuda/alltoallv/alltoallv_ce.c
@@ -238,6 +238,47 @@ ucc_status_t ucc_tl_cuda_alltoallv_ce_post_copies(ucc_tl_cuda_task_t *task)
     return status;
 }
 
+ucc_status_t ucc_tl_cuda_alltoallv_unmap(ucc_tl_cuda_task_t *task)
+{
+    ucc_tl_cuda_team_t          *team = TASK_TEAM(task);
+    ucc_rank_t                   i, dst;
+    volatile ucc_tl_cuda_sync_t *peer_sync;
+    ucc_tl_cuda_cache_t         *cache;
+    ucc_status_t                 status;
+
+    for (i = 0; i < UCC_TL_TEAM_SIZE(team); i++) {
+        if (i == UCC_TL_TEAM_RANK(team) ||
+            !ucc_tl_cuda_team_topo_is_direct(&team->super, team->topo,
+                                             UCC_TL_TEAM_RANK(team), i)) {
+            continue;
+        }
+        peer_sync = TASK_SYNC(task, i);
+        cache     = ucc_tl_cuda_get_cache(team, i);
+
+        status = ucc_tl_cuda_unmap_memhandle(
+            (uintptr_t)peer_sync->mem_info_src.ptr,
+            task->alltoallv_ce.peer_map_addr_src[i], cache, 0);
+        if (ucc_unlikely(status != UCC_OK)) {
+            return status;
+        }
+    }
+
+    for (i = 0; i < team->topo->num_proxies; i++) {
+        dst       = team->topo->proxies[i].dst;
+        peer_sync = TASK_SYNC(task, dst);
+        cache     = ucc_tl_cuda_get_cache(team, dst);
+
+        status = ucc_tl_cuda_unmap_memhandle(
+            (uintptr_t)peer_sync->mem_info_dst.ptr,
+            task->alltoallv_ce.peer_map_addr_dst[dst], cache, 0);
+        if (ucc_unlikely(status != UCC_OK)) {
+            return status;
+        }
+    }
+
+    return UCC_OK;
+}
+
 void ucc_tl_cuda_alltoallv_ce_progress(ucc_coll_task_t *coll_task)
 {
     ucc_tl_cuda_task_t *task = ucc_derived_of(coll_task, ucc_tl_cuda_task_t);
@@ -301,11 +342,14 @@ void ucc_tl_cuda_alltoallv_ce_progress(ucc_coll_task_t *coll_task)
         ucc_assert(task->alltoallv_ce.stage == ALLTOALL_CE_STAGE_BAR);
         break;
     }
-    task->super.status =
-        ucc_tl_cuda_shm_barrier_test(UCC_TL_TEAM_RANK(team), task->bar);
-    if (task->super.status == UCC_OK) {
+
+    status = ucc_tl_cuda_shm_barrier_test(UCC_TL_TEAM_RANK(team), task->bar);
+    if (status == UCC_OK) {
+        status = ucc_tl_cuda_alltoallv_unmap(task);
         ucc_tl_cuda_put_sync(task);
     }
+
+    task->super.status = status;
 }
 
 ucc_status_t ucc_tl_cuda_alltoallv_ce_start(ucc_coll_task_t *coll_task)
diff --git a/src/components/tl/cuda/tl_cuda_cache.c b/src/components/tl/cuda/tl_cuda_cache.c
index ff8539d648..dd2db650fa 100644
--- a/src/components/tl/cuda/tl_cuda_cache.c
+++ b/src/components/tl/cuda/tl_cuda_cache.c
@@ -291,8 +291,13 @@ ucc_tl_cuda_map_memhandle(const void *d_ptr, size_t size,
 }
 
 ucc_status_t ucc_tl_cuda_unmap_memhandle(uintptr_t d_bptr, void *mapped_addr,
-                                         ucc_tl_cuda_cache_t *cache)
+                                         ucc_tl_cuda_cache_t *cache, int force)
 {
+
+    if ((d_bptr == 0) || (mapped_addr == 0)) {
+        return UCC_OK;
+    }
+
 #if ENABLE_CACHE
     ucs_pgt_region_t *pgt_region;
     ucc_tl_cuda_cache_region_t *region;
@@ -300,12 +305,22 @@ ucc_status_t ucc_tl_cuda_unmap_memhandle(uintptr_t d_bptr, void *mapped_addr,
     /* use write lock because cache maybe modified */
     pthread_rwlock_wrlock(&cache->lock);
     pgt_region = ucs_pgtable_lookup(&cache->pgtable, d_bptr);
     ucc_assert(pgt_region != NULL);
+
+    ucc_debug("%s: tl_cuda unmap addr:%p region:"
+              UCS_PGT_REGION_FMT, cache->name, (void*)d_bptr,
+              UCS_PGT_REGION_ARG(pgt_region));
+
     region = ucc_derived_of(pgt_region, ucc_tl_cuda_cache_region_t);
 
     ucc_assert(region->refcount >= 1);
     region->refcount--;
 
+    if ((region->refcount == 0 ) && (force == 1)) {
+        ucs_pgtable_remove(&cache->pgtable, &region->super);
+        CUDA_FUNC(cudaIpcCloseMemHandle(mapped_addr));
+    }
+
     pthread_rwlock_unlock(&cache->lock);
 #else
     CUDA_FUNC(cudaIpcCloseMemHandle(mapped_addr));
diff --git a/src/components/tl/cuda/tl_cuda_cache.h b/src/components/tl/cuda/tl_cuda_cache.h
index fed05e126a..46d394b452 100644
--- a/src/components/tl/cuda/tl_cuda_cache.h
+++ b/src/components/tl/cuda/tl_cuda_cache.h
@@ -39,7 +39,7 @@ ucc_status_t ucc_tl_cuda_map_memhandle(const void *dptr, size_t size,
                                        ucc_tl_cuda_cache_t *cache);
 
 ucc_status_t ucc_tl_cuda_unmap_memhandle(uintptr_t d_bptr, void *mapped_addr,
-                                         ucc_tl_cuda_cache_t *cache);
+                                         ucc_tl_cuda_cache_t *cache, int force);
 
 ucc_tl_cuda_cache_t* ucc_tl_cuda_get_cache(ucc_tl_cuda_team_t *team, ucc_rank_t rank);
 
diff --git a/src/components/tl/cuda/tl_cuda_team.c b/src/components/tl/cuda/tl_cuda_team.c
index 4a7e5478c7..12e83c60cb 100644
--- a/src/components/tl/cuda/tl_cuda_team.c
+++ b/src/components/tl/cuda/tl_cuda_team.c
@@ -185,7 +185,7 @@ UCC_CLASS_CLEANUP_FUNC(ucc_tl_cuda_team_t)
             if (self->scratch.rem[i]) {
                 ucc_tl_cuda_unmap_memhandle((uintptr_t)self->scratch.rem_info[i].ptr,
                                             self->scratch.rem[i],
-                                            ucc_tl_cuda_get_cache(self, i));
+                                            ucc_tl_cuda_get_cache(self, i), 1);
             }
         }
 