Release XCCL communicators when the device manager is released.

What this patch does
  * device_manager.cc: DeviceManager::Release() now calls
    phi::distributed::XCCLCommContext::ReleaseAll(), compiled in only when
    PADDLE_WITH_CUSTOM_DEVICE is defined. The call sits after the event and
    stream ReleaseAll() calls and before device_map_ / device_impl_map_ are
    cleared.
  * xccl_comm_context.cc: adds two namespace-scope objects in
    phi::distributed, the list g_xccl_comm_contexts and the mutex
    g_xccl_comm_contexts_mutex. The constructor appends `this` to the list
    as its last step; the destructor removes `this` as its last step. Both,
    and ReleaseAll(), lock the mutex first.
  * XCCLCommContext::ReleaseAll() calls DeviceManager::CCLDestroyComm() for
    every registered context, sets that context's xccl_comm_ to nullptr, and
    then clears the list.
  * XCCLCommContext::~XCCLCommContext() calls CCLDestroyComm() only when
    DeviceManager::HasDeviceType() is true for its device type and
    xccl_comm_ is not nullptr, so a context already handled by ReleaseAll()
    is not destroyed a second time.
  * xccl_comm_context.h: declares the destructor and the static ReleaseAll().

Notes for review
  * This text was rebuilt from a copy in which newlines were collapsed and
    everything between angle brackets was lost. `<list>` and
    `std::list<XCCLCommContext*>` follow from how the list is used
    (push_back(this), remove(this), ->GetDeviceType()). `<dirent.h>` and
    `phi::stream::Stream` are presumed - TODO confirm against the tree.
  * Blank context lines and leading indentation were chosen to satisfy the
    line counts in each hunk header - TODO confirm; `git apply
    --ignore-whitespace` tolerates indentation differences.
  * `std::unique_lock lock(...)` has no template argument, which relies on
    C++17 class template argument deduction.
  * std::mutex is used but the patch adds a single include; presumably
    <mutex> arrives transitively - confirm, or include it explicitly.
  * The two new globals are declared without `static` or an unnamed
    namespace, so they have external linkage; consider making them
    file-local in a follow-up.
  * ReleaseAll() does not repeat the HasDeviceType() / nullptr checks that
    the destructor performs - confirm that every registered context always
    holds a live communicator at that point.

diff --git a/paddle/phi/backends/device_manager.cc b/paddle/phi/backends/device_manager.cc
index 1e57fb736b7c2..87a163b2cb4fa 100644
--- a/paddle/phi/backends/device_manager.cc
+++ b/paddle/phi/backends/device_manager.cc
@@ -14,6 +14,7 @@
 
 #include "paddle/phi/backends/device_manager.h"
 #include "paddle/phi/common/complex.h"
+#include "paddle/phi/core/distributed/xccl_comm_context.h"
 
 #if !defined(_WIN32)
 #include <dirent.h>
@@ -699,6 +700,9 @@ DeviceManager& DeviceManager::Instance() {
 void DeviceManager::Release() {
   event::Event::ReleaseAll();
   stream::Stream::ReleaseAll();
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+  phi::distributed::XCCLCommContext::ReleaseAll();
+#endif
   Instance().device_map_.clear();
   Instance().device_impl_map_.clear();
 }
diff --git a/paddle/phi/core/distributed/xccl_comm_context.cc b/paddle/phi/core/distributed/xccl_comm_context.cc
index ba7e24ab06b9e..3e3608e4d88a5 100644
--- a/paddle/phi/core/distributed/xccl_comm_context.cc
+++ b/paddle/phi/core/distributed/xccl_comm_context.cc
@@ -14,6 +14,8 @@
 
 #include "paddle/phi/core/distributed/xccl_comm_context.h"
 
+#include <list>
+
 #include "glog/logging.h"
 
 #include "paddle/phi/core/dense_tensor.h"
@@ -25,6 +27,29 @@
 namespace phi {
 namespace distributed {
 
+std::list<XCCLCommContext*> g_xccl_comm_contexts;
+std::mutex g_xccl_comm_contexts_mutex;
+
+void XCCLCommContext::ReleaseAll() {
+  std::unique_lock lock(g_xccl_comm_contexts_mutex);
+  for (auto xccl_comm_ctx : g_xccl_comm_contexts) {
+    phi::DeviceManager::CCLDestroyComm(xccl_comm_ctx->GetDeviceType(),
+                                       xccl_comm_ctx->GetXcclComm());
+    xccl_comm_ctx->xccl_comm_ = nullptr;
+  }
+  g_xccl_comm_contexts.clear();
+}
+
+XCCLCommContext::~XCCLCommContext() {
+  std::unique_lock lock(g_xccl_comm_contexts_mutex);
+  if (phi::DeviceManager::HasDeviceType(this->GetDeviceType()) &&
+      xccl_comm_ != nullptr) {
+    phi::DeviceManager::CCLDestroyComm(this->GetDeviceType(), xccl_comm_);
+    xccl_comm_ = nullptr;
+  }
+  g_xccl_comm_contexts.remove(this);
+}
+
 XCCLCommContext::XCCLCommContext(const phi::Place& place,
                                  int rank,
                                  int size,
@@ -38,6 +63,8 @@ XCCLCommContext::XCCLCommContext(const phi::Place& place,
                                       &xccl_comm_);
   stream_ = std::make_shared<phi::stream::Stream>();
   stream_->Init(place_);
+  std::unique_lock lock(g_xccl_comm_contexts_mutex);
+  g_xccl_comm_contexts.push_back(this);
 }
 
 void XCCLCommContext::Broadcast(phi::DenseTensor* out_tensor,
diff --git a/paddle/phi/core/distributed/xccl_comm_context.h b/paddle/phi/core/distributed/xccl_comm_context.h
index 0c253eb925bb4..8cdc7e4153d76 100644
--- a/paddle/phi/core/distributed/xccl_comm_context.h
+++ b/paddle/phi/core/distributed/xccl_comm_context.h
@@ -28,6 +28,9 @@ class XCCLCommContext final : public CommContext {
                   int rank,
                   int size,
                   const ccl::CCLRootId& xccl_id);
+  ~XCCLCommContext();
+
+  static void ReleaseAll();
 
   ccl::CCLComm GetXcclComm() const { return xccl_comm_; }
 