Skip to content

Commit

Permalink
Shutdown before dtor.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Mar 25, 2024
1 parent 23787af commit 9a625aa
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 18 deletions.
4 changes: 4 additions & 0 deletions plugin/federated/federated_comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ class FederatedComm : public HostComm {
std::int32_t rank) {
this->Init(host, port, world, rank, {}, {}, {});
}
[[nodiscard]] Result Shutdown() final {
this->ResetState();
return Success();
}
~FederatedComm() override { stub_.reset(); }

[[nodiscard]] std::shared_ptr<Channel> Chan(std::int32_t) const override {
Expand Down
12 changes: 7 additions & 5 deletions src/collective/comm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,7 @@
namespace xgboost::collective {
Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t retry, std::string task_id)
: timeout_{timeout},
retry_{retry},
tracker_{host, port, -1},
task_id_{std::move(task_id)},
loop_{std::shared_ptr<Loop>{new Loop{timeout}}} {}
: timeout_{timeout}, retry_{retry}, tracker_{host, port, -1}, task_id_{std::move(task_id)} {}

Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry,
std::string const& task_id, TCPSocket* out, std::int32_t rank,
Expand Down Expand Up @@ -192,6 +188,7 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se
std::int32_t retry, std::string task_id, StringView nccl_path)
: HostComm{std::move(host), port, timeout, retry, std::move(task_id)},
nccl_path_{std::move(nccl_path)} {
loop_.reset(new Loop{std::chrono::seconds{timeout_}}); // NOLINT
auto rc = this->Bootstrap(timeout_, retry_, task_id_);
if (!rc.OK()) {
SafeColl(Fail("Failed to bootstrap the communication group.", std::move(rc)));
Expand Down Expand Up @@ -309,6 +306,11 @@ RabitComm::~RabitComm() noexcept(false) {
if (n_bytes != scmd.size()) {
return Fail("Faled to send cmd.");
}

this->ResetState();
return Success();
} << [&] {
this->channels_.clear();
return Success();
};
}
Expand Down
4 changes: 4 additions & 0 deletions src/collective/comm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ class NCCLComm : public Comm {
auto rc = this->Stream().Sync(false);
return GetCUDAResult(rc);
}
[[nodiscard]] Result Shutdown() final {
this->ResetState();
return Success();
}
};

class NCCLChannel : public Channel {
Expand Down
16 changes: 12 additions & 4 deletions src/collective/comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,12 @@ class Comm : public std::enable_shared_from_this<Comm> {
std::thread error_worker_;
std::string task_id_;
std::vector<std::shared_ptr<Channel>> channels_;
std::shared_ptr<Loop> loop_{new Loop{std::chrono::seconds{
DefaultTimeoutSec()}}}; // fixme: require federated comm to have a timeout
std::shared_ptr<Loop> loop_{nullptr}; // fixme: require federated comm to have a timeout

void ResetState() {
this->world_ = -1;
this->rank_ = 0;
}

public:
Comm() = default;
Expand All @@ -78,7 +82,10 @@ class Comm : public std::enable_shared_from_this<Comm> {
[[nodiscard]] auto Rank() const { return rank_; }
[[nodiscard]] auto World() const { return IsDistributed() ? world_ : 1; }
[[nodiscard]] bool IsDistributed() const { return world_ != -1; }
void Submit(Loop::Op op) const { loop_->Submit(op); }
void Submit(Loop::Op op) const {
CHECK(loop_);
loop_->Submit(op);
}
[[nodiscard]] virtual Result Block() const { return loop_->Block(); }

[[nodiscard]] virtual std::shared_ptr<Channel> Chan(std::int32_t rank) const {
Expand All @@ -95,6 +102,7 @@ class Comm : public std::enable_shared_from_this<Comm> {
auto rc = GetHostName(out);
return rc;
}
[[nodiscard]] virtual Result Shutdown() = 0;
};

/**
Expand All @@ -112,7 +120,7 @@ class RabitComm : public HostComm {

[[nodiscard]] Result Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
std::string task_id);
[[nodiscard]] Result Shutdown();
[[nodiscard]] Result Shutdown() final;

public:
// bootstrapping construction.
Expand Down
2 changes: 2 additions & 0 deletions src/collective/comm_group.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ void GlobalCommGroupInit(Json config) {

void GlobalCommGroupFinalize() {
auto& sptr = GlobalCommGroup();
auto rc = sptr->Finalize();
sptr.reset();
SafeColl(rc);
}
} // namespace xgboost::collective
11 changes: 11 additions & 0 deletions src/collective/comm_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ class CommGroup {
[[nodiscard]] auto Rank() const { return comm_->Rank(); }
[[nodiscard]] bool IsDistributed() const { return comm_->IsDistributed(); }

[[nodiscard]] Result Finalize() const {
return Success() << [this] {
if (gpu_comm_) {
return gpu_comm_->Shutdown();
}
return Success();
} << [&] {
return comm_->Shutdown();
};
}

[[nodiscard]] static CommGroup* Create(Json config);

[[nodiscard]] std::shared_ptr<Coll> Backend(DeviceOrd device) const;
Expand Down
10 changes: 10 additions & 0 deletions src/collective/loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,16 @@ Result Loop::Stop() {
}
}

void Loop::Submit(Op op) {
std::unique_lock lock{mu_};
if (op.code != Op::kBlock) {
CHECK_NE(op.n, 0);
}
queue_.push(op);
lock.unlock();
cv_.notify_one();
}

Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
timer_.Init(__func__);
worker_ = std::thread{[this] { this->Process(); }};
Expand Down
10 changes: 1 addition & 9 deletions src/collective/loop.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,7 @@ class Loop {
*/
Result Stop();

void Submit(Op op) {
std::unique_lock lock{mu_};
if (op.code != Op::kBlock) {
CHECK_NE(op.n, 0);
}
queue_.push(op);
lock.unlock();
cv_.notify_one();
}
void Submit(Op op);

/**
* @brief Block the event loop until all ops are finished. In the case of failure, this
Expand Down

0 comments on commit 9a625aa

Please sign in to comment.