diff --git a/plugin/federated/federated_comm.h b/plugin/federated/federated_comm.h index d847d8655aa0..b39e1878a8ea 100644 --- a/plugin/federated/federated_comm.h +++ b/plugin/federated/federated_comm.h @@ -50,6 +50,10 @@ class FederatedComm : public HostComm { std::int32_t rank) { this->Init(host, port, world, rank, {}, {}, {}); } + [[nodiscard]] Result Shutdown() final { + this->ResetState(); + return Success(); + } ~FederatedComm() override { stub_.reset(); } [[nodiscard]] std::shared_ptr Chan(std::int32_t) const override { diff --git a/src/collective/comm.cc b/src/collective/comm.cc index bc8cf61d96d9..088f845510d7 100644 --- a/src/collective/comm.cc +++ b/src/collective/comm.cc @@ -22,11 +22,7 @@ namespace xgboost::collective { Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds timeout, std::int32_t retry, std::string task_id) - : timeout_{timeout}, - retry_{retry}, - tracker_{host, port, -1}, - task_id_{std::move(task_id)}, - loop_{std::shared_ptr{new Loop{timeout}}} {} + : timeout_{timeout}, retry_{retry}, tracker_{host, port, -1}, task_id_{std::move(task_id)} {} Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry, std::string const& task_id, TCPSocket* out, std::int32_t rank, @@ -192,6 +188,7 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se std::int32_t retry, std::string task_id, StringView nccl_path) : HostComm{std::move(host), port, timeout, retry, std::move(task_id)}, nccl_path_{std::move(nccl_path)} { + loop_.reset(new Loop{std::chrono::seconds{timeout_}}); // NOLINT auto rc = this->Bootstrap(timeout_, retry_, task_id_); if (!rc.OK()) { SafeColl(Fail("Failed to bootstrap the communication group.", std::move(rc))); @@ -309,6 +306,11 @@ RabitComm::~RabitComm() noexcept(false) { if (n_bytes != scmd.size()) { return Fail("Faled to send cmd."); } + + this->ResetState(); + return Success(); + } << [&] { + this->channels_.clear(); return Success(); }; } diff --git a/src/collective/comm.cuh b/src/collective/comm.cuh index a818d95f8134..4add9ca612e0 100644 --- a/src/collective/comm.cuh +++ b/src/collective/comm.cuh @@ -50,6 +50,10 @@ class NCCLComm : public Comm { auto rc = this->Stream().Sync(false); return GetCUDAResult(rc); } + [[nodiscard]] Result Shutdown() final { + this->ResetState(); + return Success(); + } }; class NCCLChannel : public Channel { diff --git a/src/collective/comm.h b/src/collective/comm.h index 4b948beb027e..6ad5bc5c1a6f 100644 --- a/src/collective/comm.h +++ b/src/collective/comm.h @@ -54,8 +54,12 @@ class Comm : public std::enable_shared_from_this { std::thread error_worker_; std::string task_id_; std::vector> channels_; - std::shared_ptr loop_{new Loop{std::chrono::seconds{ - DefaultTimeoutSec()}}}; // fixme: require federated comm to have a timeout + std::shared_ptr loop_{nullptr}; // fixme: require federated comm to have a timeout + + void ResetState() { + this->world_ = -1; + this->rank_ = 0; + } public: Comm() = default; @@ -78,7 +82,10 @@ class Comm : public std::enable_shared_from_this { [[nodiscard]] auto Rank() const { return rank_; } [[nodiscard]] auto World() const { return IsDistributed() ? world_ : 1; } [[nodiscard]] bool IsDistributed() const { return world_ != -1; } - void Submit(Loop::Op op) const { loop_->Submit(op); } + void Submit(Loop::Op op) const { + CHECK(loop_); + loop_->Submit(op); + } [[nodiscard]] virtual Result Block() const { return loop_->Block(); } [[nodiscard]] virtual std::shared_ptr Chan(std::int32_t rank) const { @@ -95,6 +102,7 @@ class Comm : public std::enable_shared_from_this { auto rc = GetHostName(out); return rc; } + [[nodiscard]] virtual Result Shutdown() = 0; }; /** @@ -112,7 +120,7 @@ class RabitComm : public HostComm { [[nodiscard]] Result Bootstrap(std::chrono::seconds timeout, std::int32_t retry, std::string task_id); - [[nodiscard]] Result Shutdown(); + [[nodiscard]] Result Shutdown() final; public: // bootstrapping construction. diff --git a/src/collective/comm_group.cc b/src/collective/comm_group.cc index 2936a14914ec..7408882f6d2a 100644 --- a/src/collective/comm_group.cc +++ b/src/collective/comm_group.cc @@ -116,6 +116,8 @@ void GlobalCommGroupInit(Json config) { void GlobalCommGroupFinalize() { auto& sptr = GlobalCommGroup(); + auto rc = sptr->Finalize(); sptr.reset(); + SafeColl(rc); } } // namespace xgboost::collective diff --git a/src/collective/comm_group.h b/src/collective/comm_group.h index 1f403ae6b592..61a58ba56bed 100644 --- a/src/collective/comm_group.h +++ b/src/collective/comm_group.h @@ -34,6 +34,17 @@ class CommGroup { [[nodiscard]] auto Rank() const { return comm_->Rank(); } [[nodiscard]] bool IsDistributed() const { return comm_->IsDistributed(); } + [[nodiscard]] Result Finalize() const { + return Success() << [this] { + if (gpu_comm_) { + return gpu_comm_->Shutdown(); + } + return Success(); + } << [&] { + return comm_->Shutdown(); + }; + } + [[nodiscard]] static CommGroup* Create(Json config); [[nodiscard]] std::shared_ptr Backend(DeviceOrd device) const; diff --git a/src/collective/loop.cc b/src/collective/loop.cc index b51749fcdad5..cd4859f333a8 100644 --- a/src/collective/loop.cc +++ b/src/collective/loop.cc @@ -243,6 +243,16 @@ Result Loop::Stop() { } } +void Loop::Submit(Op op) { + std::unique_lock lock{mu_}; + if (op.code != Op::kBlock) { + CHECK_NE(op.n, 0); + } + queue_.push(op); + lock.unlock(); + cv_.notify_one(); +} + Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} { timer_.Init(__func__); worker_ = std::thread{[this] { this->Process(); }}; diff --git a/src/collective/loop.h b/src/collective/loop.h index 5e9f38c2933f..5405f6b13e1d 100644 --- a/src/collective/loop.h +++ b/src/collective/loop.h @@ -64,15 +64,7 @@ class Loop { */ Result Stop(); - void Submit(Op op) { - std::unique_lock lock{mu_}; - if (op.code != Op::kBlock) { - CHECK_NE(op.n, 0); - } - queue_.push(op); - lock.unlock(); - cv_.notify_one(); - } + void Submit(Op op); /** * @brief Block the event loop until all ops are finished. In the case of failure, this