From 8e55c2697e9f5720735e62bd44e2e7a2c75f0b0c Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 11 Oct 2023 12:57:09 +0800 Subject: [PATCH 01/11] [coll] Implement a new tracker and a communicator. The new tracker and communicators communicates through the use of JSON documents. Along with which, communicators are aware of each others. --- include/xgboost/collective/socket.h | 33 ++- src/collective/allgather.cc | 42 ++++ src/collective/allgather.h | 23 ++ src/collective/comm.cc | 313 +++++++++++++++++++++++++++ src/collective/comm.h | 148 +++++++++++++ src/collective/protocol.h | 213 ++++++++++++++++++ src/collective/tracker.cc | 222 +++++++++++++++++++ src/collective/tracker.h | 126 ++++++++++- tests/cpp/collective/test_comm.cc | 47 ++++ tests/cpp/collective/test_socket.cc | 2 +- tests/cpp/collective/test_tracker.cc | 65 +++++- tests/cpp/collective/test_worker.h | 110 ++++++++++ 12 files changed, 1330 insertions(+), 14 deletions(-) create mode 100644 src/collective/allgather.cc create mode 100644 src/collective/allgather.h create mode 100644 src/collective/comm.cc create mode 100644 src/collective/comm.h create mode 100644 src/collective/protocol.h create mode 100644 tests/cpp/collective/test_comm.cc create mode 100644 tests/cpp/collective/test_worker.h diff --git a/include/xgboost/collective/socket.h b/include/xgboost/collective/socket.h index a16dd05c07d8..0e8c5faeab57 100644 --- a/include/xgboost/collective/socket.h +++ b/include/xgboost/collective/socket.h @@ -380,11 +380,18 @@ class TCPSocket { } [[nodiscard]] bool NonBlocking() const { return non_blocking_; } [[nodiscard]] Result RecvTimeout(std::chrono::seconds timeout) { - timeval tv; + // https://stackoverflow.com/questions/2876024/linux-is-there-a-read-or-recv-from-socket-with-timeout +#if defined(_WIN32) + DWORD tv = timeout.count() * 1000; + auto rc = setsockopt(socket, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&tv), + sizeof(tv)); +#else + struct timeval tv; tv.tv_sec = timeout.count(); tv.tv_usec = 0; auto rc = setsockopt(Handle(), SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&tv), sizeof(tv)); +#endif if (rc != 0) { return system::FailWithCode("Failed to set timeout on recv."); } @@ -468,7 +475,7 @@ class TCPSocket { /** * \brief Bind socket to INADDR_ANY, return the port selected by the OS. */ - in_port_t BindHost() { + [[nodiscard]] in_port_t BindHost() { if (Domain() == SockDomain::kV6) { auto addr = SockAddrV6::InaddrAny(); auto handle = reinterpret_cast(&addr.Handle()); @@ -539,7 +546,7 @@ class TCPSocket { /** * \brief Send data, without error then all data should be sent. */ - auto SendAll(void const *buf, std::size_t len) { + [[nodiscard]] auto SendAll(void const *buf, std::size_t len) { char const *_buf = reinterpret_cast(buf); std::size_t ndone = 0; while (ndone < len) { @@ -558,7 +565,7 @@ class TCPSocket { /** * \brief Receive data, without error then all data should be received. */ - auto RecvAll(void *buf, std::size_t len) { + [[nodiscard]] auto RecvAll(void *buf, std::size_t len) { char *_buf = reinterpret_cast(buf); std::size_t ndone = 0; while (ndone < len) { @@ -634,6 +641,24 @@ class TCPSocket { socket.domain_ = domain; #endif // defined(__APPLE__) return socket; +#endif // defined(xgboost_IS_MINGW) + } + + static TCPSocket *CreatePtr(SockDomain domain) { +#if defined(xgboost_IS_MINGW) + MingWError(); + return nullptr; +#else + auto fd = socket(static_cast(domain), SOCK_STREAM, 0); + if (fd == InvalidSocket()) { + system::ThrowAtError("socket"); + } + auto socket = new TCPSocket{fd}; + +#if defined(__APPLE__) + socket->domain_ = domain; +#endif // defined(__APPLE__) + return socket; #endif // defined(xgboost_IS_MINGW) } }; diff --git a/src/collective/allgather.cc b/src/collective/allgather.cc new file mode 100644 index 000000000000..9b64d36114c8 --- /dev/null +++ b/src/collective/allgather.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#include "allgather.h" + +#include // for min +#include // for int8_t +#include // for shared_ptr + +#include "comm.h" +#include "rabit/internal/socket.h" +#include "xgboost/collective/socket.h" + +namespace xgboost::collective::cpu_impl { +Result RingAllgather(Comm const& comm, common::Span data, std::size_t segment_size, + std::int32_t worker_off, std::shared_ptr prev_ch, + std::shared_ptr next_ch) { + auto world = comm.World(); + auto rank = comm.Rank(); + CHECK_LT(worker_off, world); + + for (std::int32_t r = 0; r < world; ++r) { + auto send_rank = (rank + world - r + worker_off) % world; + auto send_off = send_rank * segment_size; + send_off = std::min(send_off, data.size_bytes()); + auto send_seg = data.subspan(send_off, std::min(segment_size, data.size_bytes() - send_off)); + next_ch->SendAll(send_seg.data(), send_seg.size_bytes()); + + auto recv_rank = (rank + world - r - 1 + worker_off) % world; + auto recv_off = recv_rank * segment_size; + recv_off = std::min(recv_off, data.size_bytes()); + auto recv_seg = data.subspan(recv_off, std::min(segment_size, data.size_bytes() - recv_off)); + prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes()); + auto rc = prev_ch->Block(); + if (!rc.OK()) { + return rc; + } + } + + return Success(); +} +} // namespace xgboost::collective::cpu_impl diff --git a/src/collective/allgather.h b/src/collective/allgather.h new file mode 100644 index 000000000000..31a9a36b31e7 --- /dev/null +++ b/src/collective/allgather.h @@ -0,0 +1,23 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#pragma once +#include // for size_t +#include // for int32_t +#include // for shared_ptr + +#include "comm.h" // for Comm, Channel +#include "xgboost/span.h" // for Span + +namespace xgboost::collective { +namespace cpu_impl { +/** + * @param worker_off Segment offset. For example, if the rank 2 worker specifis worker_off + * = 1, then it owns the third segment. + */ +[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span data, + std::size_t segment_size, std::int32_t worker_off, + std::shared_ptr prev_ch, + std::shared_ptr next_ch); +} // namespace cpu_impl +} // namespace xgboost::collective diff --git a/src/collective/comm.cc b/src/collective/comm.cc new file mode 100644 index 000000000000..b1884d58eb2f --- /dev/null +++ b/src/collective/comm.cc @@ -0,0 +1,313 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#include "comm.h" + +#include // for unique_lock +#include // for move, forward + +#include "../c_api/c_api_utils.h" +#include "../common/common.h" +#include "allgather.h" +#include "protocol.h" // for kMagic +#include "tracker.h" +#include "xgboost/json.h" // for Json, Object + +namespace xgboost::collective { +namespace { +// We don't have monad, a simple helper would do. +template +Result operator<<(Result&& r, Fn&& fn) { + if (!r.OK()) { + return std::forward(r); + } + return fn(); +} +} // namespace + +Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds timeout, + std::int32_t retry, std::string task_id) + : timeout_{timeout}, + retry_{retry}, + tracker_{host, port, -1}, + task_id_{std::move(task_id)}, + loop_{std::make_shared(timeout)} {} + +Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry, + std::string const& task_id, xgboost::collective::TCPSocket* out, + std::int32_t rank, std::int32_t world) { + // get information from tracker + CHECK(!info.host.empty()); + auto rc = Connect(info.host, info.port, retry, timeout, out); + if (!rc.OK()) { + return xgboost::collective::Fail("Failed to connect to the tracker.", std::move(rc)); + } + + TCPSocket& tracker = *out; + proto::Magic magic; + proto::Connect conn; + + return std::move(rc) + << [&] { return tracker.NonBlocking(false); } + << [&] { return tracker.RecvTimeout(timeout); } + << [&] { return magic.Verify(&tracker); } + << [&] { return conn.WorkerSend(&tracker, world, rank, task_id); }; +} + +[[nodiscard]] Result Comm::ConnectTracker(TCPSocket* out) const { + return ConnectTrackerImpl(this->TrackerInfo(), this->Timeout(), this->retry_, this->task_id_, out, + this->Rank(), this->World()); +} + +[[nodiscard]] Result ConnectWorkers(Comm const& comm, TCPSocket* listener, std::int32_t lport, + proto::PeerInfo ninfo, std::chrono::seconds timeout, + std::int32_t retry, + std::vector>* out_workers) { + auto next = std::make_shared(); + auto rc = Success() << [&] { + auto rc = Connect(StringView{ninfo.host}, ninfo.port, retry, timeout, next.get()); + if (!rc.OK()) { + return Fail("Bootstrap failed to connect to ring next.", std::move(rc)); + } + return rc; + } << [&] { return next->NonBlocking(true); }; + if (!rc.OK()) { + return rc; + } + + auto prev = std::make_shared(); + SockAddrV4 addr; + rc = std::move(rc) << [&] { return listener->Accept(prev.get(), &addr); } + << [&] { return prev->NonBlocking(true); }; + if (!rc.OK()) { + return rc; + } + + // exchange host name and port + std::vector buffer(HOST_NAME_MAX * comm.World(), 0); + auto s_buffer = common::Span{buffer.data(), buffer.size()}; + auto next_host = s_buffer.subspan(HOST_NAME_MAX * comm.Rank(), HOST_NAME_MAX); + if (next_host.size() < ninfo.host.size()) { + return Fail("Got an invalid host name."); + } + std::copy(ninfo.host.cbegin(), ninfo.host.cend(), next_host.begin()); + + auto prev_ch = std::make_shared(comm, prev); + auto next_ch = std::make_shared(comm, next); + rc = cpu_impl::RingAllgather(comm, s_buffer, HOST_NAME_MAX, 0, prev_ch, next_ch); + if (!rc.OK()) { + return Fail("Failed to get host names from peers.", std::move(rc)); + } + for (auto ch : {prev_ch, next_ch}) { + rc = ch->Block(); + if (!rc.OK()) { + return rc; + } + } + + std::vector peers_port(comm.World(), -1); + peers_port[comm.Rank()] = ninfo.port; + auto s_ports = common::Span{reinterpret_cast(peers_port.data()), + peers_port.size() * sizeof(ninfo.port)}; + rc = cpu_impl::RingAllgather(comm, s_ports, sizeof(ninfo.port), 0, prev_ch, next_ch); + if (!rc.OK()) { + return Fail("Failed to get the port from peers.", std::move(rc)); + } + for (auto ch : {prev_ch, next_ch}) { + rc = ch->Block(); + if (!rc.OK()) { + return rc; + } + } + + std::vector peers(comm.World()); + for (auto r = 0; r < comm.World(); ++r) { + auto nhost = s_buffer.subspan(HOST_NAME_MAX * r, HOST_NAME_MAX); + auto nport = peers_port[r]; + auto nrank = BootstrapNext(r, comm.World()); + + peers[nrank] = {std::string{reinterpret_cast(nhost.data())}, nport, nrank}; + } + CHECK_EQ(peers[comm.Rank()].port, lport); + for (auto const& p : peers) { + CHECK_NE(p.port, -1); + } + + std::vector>& workers = *out_workers; + workers.resize(comm.World()); + + for (std::int32_t r = (comm.Rank() + 1); r < comm.World(); ++r) { + auto const& peer = peers[r]; + std::shared_ptr worker{TCPSocket::CreatePtr(comm.Domain())}; + rc = std::move(rc) + << [&] { return Connect(peer.host, peer.port, retry, timeout, worker.get()); } + << [&] { return worker->RecvTimeout(timeout); }; + if (!rc.OK()) { + return rc; + } + + auto rank = comm.Rank(); + auto n_bytes = worker->SendAll(&rank, sizeof(comm.Rank())); + if (n_bytes != sizeof(comm.Rank())) { + return Fail("Failed to send rank."); + } + workers[r] = std::move(worker); + } + + for (std::int32_t r = 0; r < comm.Rank(); ++r) { + SockAddrV4 addr; + auto peer = std::shared_ptr(TCPSocket::CreatePtr(comm.Domain())); + rc = std::move(rc) << [&] { return listener->Accept(peer.get(), &addr); } + << [&] { return peer->RecvTimeout(timeout); }; + if (!rc.OK()) { + return rc; + } + std::int32_t rank; + auto n_bytes = peer->RecvAll(&rank, sizeof(rank)); + if (n_bytes != sizeof(comm.Rank())) { + return Fail("Failed to recv rank."); + } + workers[rank] = std::move(peer); + } + + for (std::int32_t r = 0; r < comm.World(); ++r) { + if (r == comm.Rank()) { + continue; + } + CHECK(workers[r]); + } + + return Success(); +} + +RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout, + std::int32_t retry, std::string task_id) + : Comm{std::move(host), port, timeout, retry, std::move(task_id)} { + auto rc = this->Bootstrap(timeout_, retry_, task_id_); + CHECK(rc.OK()) << rc.Report(); +} + +[[nodiscard]] Result RabitComm::Bootstrap(std::chrono::seconds timeout, std::int32_t retry, + std::string task_id) { + TCPSocket tracker; + std::int32_t world{-1}; + auto rc = ConnectTrackerImpl(this->TrackerInfo(), timeout, retry, task_id, &tracker, this->Rank(), + world); + if (!rc.OK()) { + return Fail("Bootstrap failed.", std::move(rc)); + } + + this->domain_ = tracker.Domain(); + + // Start command + TCPSocket listener = TCPSocket::Create(tracker.Domain()); + std::int32_t lport = listener.BindHost(); + listener.Listen(); + + // create worker for listening to error notice. + auto domain = tracker.Domain(); + std::shared_ptr error_sock{TCPSocket::CreatePtr(domain)}; + auto eport = error_sock->BindHost(); + error_sock->Listen(); + error_worker_ = std::thread{[this, error_sock = std::move(error_sock)] { + auto conn = error_sock->Accept(); + LOG(WARNING) << "Another worker is running into error."; + std::string scmd; + conn.Recv(&scmd); + auto jcmd = Json::Load(scmd); + auto rc = this->Shutdown(); + if (!rc.OK()) { + LOG(WARNING) << "Fail to shutdown worker:" << rc.Report(); + } +#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0 + exit(-1); +#else + LOG(FATAL) << rc.Report(); +#endif + }}; + error_worker_.detach(); + + proto::Start start; + rc = std::move(rc) << [&] { return start.WorkerSend(lport, &tracker, eport); } + << [&] { return start.WorkerRecv(&tracker, &world); }; + if (!rc.OK()) { + return rc; + } + this->world_ = world; + + // get ring neighbors + std::string snext; + tracker.Recv(&snext); + auto jnext = Json::Load(StringView{snext}); + + proto::PeerInfo ninfo{jnext}; + + // get the rank of this worker + this->rank_ = BootstrapPrev(ninfo.rank, world); + this->tracker_.rank = rank_; + + std::vector> workers; + rc = ConnectWorkers(*this, &listener, lport, ninfo, timeout, retry, &workers); + + CHECK(this->channels_.empty()); + for (auto& w : workers) { + if (w) { + w->SetNoDelay(); + rc = w->NonBlocking(true); + } + if (!rc.OK()) { + return rc; + } + this->channels_.emplace_back(std::make_shared(*this, w)); + } + return rc; +} + +RabitComm::~RabitComm() noexcept(false) { + if (!IsDistributed()) { + return; + } + auto rc = this->Shutdown(); + if (!rc.OK()) { + LOG(WARNING) << rc.Report(); + } +} + +[[nodiscard]] Result RabitComm::Shutdown() { + TCPSocket tracker; + auto rc = Success() << [&] { + return ConnectTrackerImpl(tracker_, timeout_, retry_, task_id_, &tracker, Rank(), World()); + } << [&] { return this->Block(); }; + if (!rc.OK()) { + return rc; + } + + Json jcmd{Object{}}; + jcmd["cmd"] = Integer{static_cast(proto::CMD::kShutdown)}; + auto scmd = Json::Dump(jcmd); + tracker.Send(scmd); + + return Success(); +} + +[[nodiscard]] Result RabitComm::LogTracker(std::string msg) const { + TCPSocket out; + proto::Print print; + auto rc = Success() << [&] { return this->ConnectTracker(&out); } + << [&] { return print.WorkerSend(&out, msg); }; + if (!rc.OK()) { + return Fail("Logging to tracker failed.", std::move(rc)); + } + return rc; +} + +[[nodiscard]] Result RabitComm::SignalError(Result const& res) { + TCPSocket out; + auto rc = this->ConnectTracker(&out); + if (!rc.OK()) { + return Fail("Logging to tracker failed.", std::move(rc)); + } + proto::ErrorCMD cmd; + return cmd.WorkerSend(&out, res); +} +} // namespace xgboost::collective diff --git a/src/collective/comm.h b/src/collective/comm.h new file mode 100644 index 000000000000..2756bebb52b1 --- /dev/null +++ b/src/collective/comm.h @@ -0,0 +1,148 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#pragma once +#include // for condition_variable +#include // for size_t +#include // for int32_t +#include // for mutex +#include // for queue +#include // for thread +#include // for remove_const_t + +#include "../common/timer.h" +#include "loop.h" // for Loop +#include "protocol.h" // for PeerInfo +#include "xgboost/collective/result.h" // for Result +#include "xgboost/collective/socket.h" // for TCPSocket +#include "xgboost/span.h" // for Span + +namespace xgboost::collective { + +inline constexpr std::int32_t DefaultTimeoutSec() { return 300; } // 5min +inline constexpr std::int32_t DefaultRetry() { return 3; } + +inline std::int32_t BootstrapNext(std::int32_t r, std::int32_t world) { + auto nrank = (r + world + 1) % world; + return nrank; +} + +inline std::int32_t BootstrapPrev(std::int32_t r, std::int32_t world) { + auto nrank = (r + world - 1) % world; + return nrank; +} + +class Channel; + +class Comm { + protected: + std::int32_t world_{1}; + std::int32_t rank_{0}; + std::chrono::seconds timeout_{DefaultTimeoutSec()}; + std::int32_t retry_{DefaultRetry()}; + + proto::PeerInfo tracker_; + SockDomain domain_{SockDomain::kV4}; + std::thread error_worker_; + std::string task_id_; + std::vector> channels_; + std::shared_ptr loop_{new Loop{std::chrono::seconds{ + DefaultTimeoutSec()}}}; // fixme: require federated comm to have a timeout + + public: + Comm() = default; + Comm(std::string const& host, std::int32_t port, std::chrono::seconds timeout, std::int32_t retry, + std::string task_id); + virtual ~Comm() noexcept(false) {} // NOLINT + + Comm(Comm const& that) = delete; + Comm& operator=(Comm const& that) = delete; + Comm(Comm&& that) = delete; + Comm& operator=(Comm&& that) = delete; + + [[nodiscard]] auto TrackerInfo() const { return tracker_; } + [[nodiscard]] Result ConnectTracker(TCPSocket* out) const; + [[nodiscard]] auto Domain() const { return domain_; } + [[nodiscard]] auto Timeout() const { return timeout_; } + + [[nodiscard]] auto Rank() const { return rank_; } + [[nodiscard]] auto World() const { return world_; } + [[nodiscard]] bool IsDistributed() const { return World() > 1; } + void Submit(Loop::Op op) const { loop_->Submit(op); } + [[nodiscard]] Result Block() const { return loop_->Block(); } + + [[nodiscard]] virtual std::shared_ptr Chan(std::int32_t rank) const { + return channels_.at(rank); + } + [[nodiscard]] virtual bool IsFederated() const = 0; + [[nodiscard]] virtual Result LogTracker(std::string msg) const = 0; + + [[nodiscard]] virtual Result SignalError(Result const&) { return Success(); } +}; + +class RabitComm : public Comm { + [[nodiscard]] Result Bootstrap(std::chrono::seconds timeout, std::int32_t retry, + std::string task_id); + [[nodiscard]] Result Shutdown(); + + public: + // bootstrapping construction. + RabitComm() = default; + // ctor for testing where environment is known. + RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout, + std::int32_t retry, std::string task_id); + ~RabitComm() noexcept(false) override; + + [[nodiscard]] bool IsFederated() const override { return false; } + [[nodiscard]] Result LogTracker(std::string msg) const override; + + [[nodiscard]] Result SignalError(Result const&) override; +}; + +class Channel { + std::shared_ptr sock_{nullptr}; + Result rc_; + Comm const& comm_; + + public: + explicit Channel(Comm const& comm, std::shared_ptr sock) + : sock_{std::move(sock)}, comm_{comm} {} + + void SendAll(std::int8_t const* ptr, std::size_t n) { + Loop::Op op{Loop::Op::kWrite, comm_.Rank(), const_cast(ptr), n, sock_.get(), 0}; + CHECK(sock_.get()); + comm_.Submit(std::move(op)); + } + void SendAll(common::Span data) { + this->SendAll(data.data(), data.size_bytes()); + } + + void RecvAll(std::int8_t* ptr, std::size_t n) { + Loop::Op op{Loop::Op::kRead, comm_.Rank(), ptr, n, sock_.get(), 0}; + CHECK(sock_.get()); + comm_.Submit(std::move(op)); + } + void RecvAll(common::Span data) { this->RecvAll(data.data(), data.size_bytes()); } + + [[nodiscard]] auto Socket() const { return sock_; } + [[nodiscard]] Result Block() { return comm_.Block(); } +}; + +enum class Op { kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5 }; + +template , + std::add_const_t, std::int8_t>> +common::Span EraseType(common::Span data) { + auto n_total_bytes = data.size_bytes(); + auto erased = common::Span{reinterpret_cast>(data.data()), n_total_bytes}; + return erased; +} + +template +common::Span RestoreType(common::Span data) { + static_assert(std::is_same_v, std::int8_t>); + auto n_total_bytes = data.size_bytes(); + auto restored = common::Span{reinterpret_cast(data.data()), n_total_bytes / sizeof(T)}; + return restored; +} +} // namespace xgboost::collective diff --git a/src/collective/protocol.h b/src/collective/protocol.h new file mode 100644 index 000000000000..5ed72fafacc4 --- /dev/null +++ b/src/collective/protocol.h @@ -0,0 +1,213 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#pragma once +#include // for int32_t +#include // for string + +#include "xgboost/collective/result.h" // for Result +#include "xgboost/collective/socket.h" // for TCPSocket +#include "xgboost/json.h" // for Json + +namespace xgboost::collective::proto { +struct PeerInfo { + std::string host; + std::int32_t port{-1}; + std::int32_t rank{-1}; + + PeerInfo() = default; + PeerInfo(std::string host, std::int32_t port, std::int32_t rank) + : host{std::move(host)}, port{port}, rank{rank} {} + + explicit PeerInfo(Json const& peer) + : host{get(peer["host"])}, + port{static_cast(get(peer["port"]))}, + rank{static_cast(get(peer["rank"]))} {} + + [[nodiscard]] Json ToJson() const { + Json info{Object{}}; + info["rank"] = rank; + info["host"] = String{host}; + info["port"] = Integer{port}; + return info; + } + + [[nodiscard]] auto HostPort() const { return host + ":" + std::to_string(this->port); } +}; + +struct Magic { + static constexpr std::int32_t kMagic = 0xff99; + + [[nodiscard]] Result Verify(xgboost::collective::TCPSocket* p_sock) { + std::int32_t magic{kMagic}; + auto n_bytes = p_sock->SendAll(&magic, sizeof(magic)); + if (n_bytes != sizeof(magic)) { + return Fail("Failed to verify."); + } + + magic = 0; + n_bytes = p_sock->RecvAll(&magic, sizeof(magic)); + if (n_bytes != sizeof(magic)) { + return Fail("Failed to verify."); + } + if (magic != kMagic) { + return xgboost::collective::Fail("Invalid verification number."); + } + return Success(); + } +}; + +enum class CMD : std::int32_t { + kInvalid = 0, + kStart = 1, + kShutdown = 2, + kError = 3, + kPrint = 4, +}; + +struct Connect { + [[nodiscard]] Result WorkerSend(TCPSocket* tracker, std::int32_t world, std::int32_t rank, + std::string task_id) const { + Json jinit{Object{}}; + jinit["world_size"] = Integer{world}; + jinit["rank"] = Integer{rank}; + jinit["task_id"] = String{task_id}; + std::string msg; + Json::Dump(jinit, &msg); + auto n_bytes = tracker->Send(msg); + if (n_bytes != msg.size()) { + return Fail("Failed to send init command from worker."); + } + return Success(); + } + [[nodiscard]] Result TrackerRecv(TCPSocket* sock, std::int32_t* world, std::int32_t* rank, + std::string* task_id) const { + std::string init; + sock->Recv(&init); + auto jinit = Json::Load(StringView{init}); + *world = get(jinit["world_size"]); + *rank = get(jinit["rank"]); + *task_id = get(jinit["task_id"]); + return Success(); + } +}; + +class Start { + private: + [[nodiscard]] Result TrackerSend(std::int32_t world, TCPSocket* worker) const { + Json jcmd{Object{}}; + jcmd["world_size"] = Integer{world}; + auto scmd = Json::Dump(jcmd); + auto n_bytes = worker->Send(scmd); + if (n_bytes != scmd.size()) { + return Fail("Failed to send init command from tracker."); + } + return Success(); + } + + public: + [[nodiscard]] Result WorkerSend(std::int32_t lport, TCPSocket* tracker, + std::int32_t eport) const { + Json jcmd{Object{}}; + jcmd["cmd"] = Integer{static_cast(CMD::kStart)}; + jcmd["port"] = Integer{lport}; + jcmd["error_port"] = Integer{eport}; + auto scmd = Json::Dump(jcmd); + auto n_bytes = tracker->Send(scmd); + if (n_bytes != scmd.size()) { + return Fail("Failed to send init command from worker."); + } + return Success(); + } + [[nodiscard]] Result WorkerRecv(TCPSocket* tracker, std::int32_t* p_world) const { + std::string scmd; + auto n_bytes = tracker->Recv(&scmd); + if (n_bytes <= 0) { + return Fail("Failed to recv init command from tracker."); + } + auto jcmd = Json::Load(scmd); + auto world = get(jcmd["world_size"]); + if (world <= 0) { + return Fail("Invalid world size."); + } + *p_world = world; + return Success(); + } + [[nodiscard]] Result TrackerHandle(Json jcmd, std::int32_t* recv_world, std::int32_t world, + std::int32_t* p_port, TCPSocket* p_sock, + std::int32_t* eport) const { + *p_port = get(jcmd["port"]); + if (*p_port <= 0) { + return Fail("Invalid port."); + } + if (*recv_world != -1) { + return Fail("Invalid initialization sequence."); + } + *recv_world = world; + *eport = get(jcmd["error_port"]); + return TrackerSend(world, p_sock); + } +}; + +struct Print { + [[nodiscard]] Result WorkerSend(TCPSocket* tracker, std::string msg) const { + Json jcmd{Object{}}; + jcmd["cmd"] = Integer{static_cast(CMD::kPrint)}; + jcmd["msg"] = String{std::move(msg)}; + auto scmd = Json::Dump(jcmd); + auto n_bytes = tracker->Send(scmd); + if (n_bytes != scmd.size()) { + return Fail("Failed to send print command from worker."); + } + return Success(); + } + [[nodiscard]] Result TrackerHandle(Json jcmd, std::string* p_msg) const { + if (!IsA(jcmd["msg"])) { + return Fail("Invalid print command."); + } + auto msg = get(jcmd["msg"]); + *p_msg = msg; + return Success(); + } +}; + +struct ErrorCMD { + [[nodiscard]] Result WorkerSend(TCPSocket* tracker, Result const& res) const { + auto msg = res.Report(); + auto code = res.Code().value(); + Json jcmd{Object{}}; + jcmd["msg"] = String{std::move(msg)}; + jcmd["code"] = Integer{code}; + jcmd["cmd"] = Integer{static_cast(CMD::kError)}; + auto scmd = Json::Dump(jcmd); + auto n_bytes = tracker->Send(scmd); + if (n_bytes != scmd.size()) { + return Fail("Failed to send error command from worker."); + } + return Success(); + } + [[nodiscard]] Result TrackerHandle(Json jcmd, std::string* p_msg, int* p_code) const { + if (!IsA(jcmd["msg"]) || !IsA(jcmd["code"])) { + return Fail("Invalid error command."); + } + auto msg = get(jcmd["msg"]); + auto code = get(jcmd["code"]); + *p_msg = msg; + *p_code = code; + return Success(); + } +}; + +struct ShutdownCMD { + [[nodiscard]] Result Send(TCPSocket* peer) const { + Json jcmd{Object{}}; + jcmd["cmd"] = Integer{static_cast(proto::CMD::kShutdown)}; + auto scmd = Json::Dump(jcmd); + auto n_bytes = peer->Send(scmd); + if (n_bytes != scmd.size()) { + return Fail("Failed to send shutdown command from worker."); + } + return Success(); + } +}; +} // namespace xgboost::collective::proto diff --git a/src/collective/tracker.cc b/src/collective/tracker.cc index 598b41ddd8ba..9ead41859d18 100644 --- a/src/collective/tracker.cc +++ b/src/collective/tracker.cc @@ -17,10 +17,232 @@ #include // for string +#include "../common/json_utils.h" +#include "comm.h" +#include "protocol.h" // for kMagic, PeerInfo +#include "tracker.h" #include "xgboost/collective/result.h" // for Result, Fail, Success #include "xgboost/collective/socket.h" // for GetHostName, FailWithCode, MakeSockAddress, ... +#include "xgboost/json.h" namespace xgboost::collective { +Tracker::Tracker(Json const& config) + : n_workers_{static_cast( + RequiredArg(config, "n_workers", __func__))}, + port_{static_cast(OptionalArg(config, "port", Integer::Int{0}))}, + timeout_{std::chrono::seconds{OptionalArg( + config, "timeout", static_cast(collective::DefaultTimeoutSec()))}} {} + +RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr) + : sock_{std::move(sock)} { + auto host = addr.Addr(); + + proto::Magic magic; + rc_ = magic.Verify(&sock_); + if (!rc_.OK()) { + return; + } + + proto::Connect conn; + std::int32_t rank{0}; + rc_ = conn.TrackerRecv(&sock_, &world_, &rank, &task_id_); + if (!rc_.OK()) { + return; + } + + std::string cmd; + sock_.Recv(&cmd); + auto jcmd = Json::Load(StringView{cmd}); + cmd_ = static_cast(get(jcmd["cmd"])); + std::int32_t port{0}; + if (cmd_ == proto::CMD::kStart) { + proto::Start start; + rc_ = start.TrackerHandle(jcmd, &world_, world, &port, &sock_, &eport_); + } else if (cmd_ == proto::CMD::kPrint) { + proto::Print print; + rc_ = print.TrackerHandle(jcmd, &msg_); + } else if (cmd_ == proto::CMD::kError) { + proto::ErrorCMD error; + rc_ = error.TrackerHandle(jcmd, &msg_, &code_); + } + + if (!rc_.OK()) { + return; + } + + info_ = proto::PeerInfo{host, port, rank}; +} + +RabitTracker::RabitTracker(Json const& config) : Tracker{config} { + std::string self; + auto rc = collective::GetHostAddress(&self); + auto host = OptionalArg(config, "host", self); + + listener_ = TCPSocket::Create(SockDomain::kV4); + rc = listener_.Bind(host, &this->port_); + CHECK(rc.OK()) << rc.Report(); + listener_.Listen(); +} + +Result RabitTracker::Bootstrap(std::vector* p_workers) { + auto& workers = *p_workers; + + std::sort(workers.begin(), workers.end(), WorkerCmp{}); + + std::vector bootstrap_threads; + for (std::int32_t r = 0; r < n_workers_; ++r) { + auto& worker = workers[r]; + auto next = BootstrapNext(r, n_workers_); + auto const& next_w = workers[next]; + bootstrap_threads.emplace_back([next, &worker, &next_w] { + auto jnext = proto::PeerInfo{next_w.Host(), next_w.Port(), next}.ToJson(); + std::string str; + Json::Dump(jnext, &str); + worker.Send(StringView{str}); + }); + } + + for (auto& t : bootstrap_threads) { + t.join(); // fixme: check exception + } + + for (auto const& w : workers) { + worker_error_handles_.emplace_back(w.Host(), w.ErrorPort()); + } + return Success(); +} + +[[nodiscard]] std::future RabitTracker::Run() { + // a state machine to keep track of consistency. + struct State { + std::int32_t const n_workers; + + std::int32_t n_shutdown{0}; + bool during_restart{false}; + std::vector pending; + + explicit State(std::int32_t world) : n_workers{world} {} + State(State const& that) = delete; + State& operator=(State&& that) = delete; + + void Start(WorkerProxy&& worker) { + CHECK_LT(pending.size(), n_workers); + CHECK_LE(n_shutdown, n_workers); + + pending.emplace_back(std::forward(worker)); + + CHECK_LE(pending.size(), n_workers); + } + void Shutdown() { + CHECK_GE(n_shutdown, 0); + CHECK_LT(n_shutdown, n_workers); + + ++n_shutdown; + + CHECK_LE(n_shutdown, n_workers); + } + void Error() { + CHECK_LE(pending.size(), n_workers); + CHECK_LE(n_shutdown, n_workers); + + during_restart = true; + } + [[nodiscard]] bool Ready() const { + CHECK_LE(pending.size(), n_workers); + return static_cast(pending.size()) == n_workers; + } + void Bootstrap() { + CHECK_EQ(pending.size(), n_workers); + CHECK_LE(n_shutdown, n_workers); + + // A reset. + n_shutdown = 0; + during_restart = false; + pending.clear(); + } + [[nodiscard]] bool ShouldContinue() const { + CHECK_LE(pending.size(), n_workers); + CHECK_LE(n_shutdown, n_workers); + // - Without error, we should shutdown after all workers are offline. + // - With error, all workers are offline, and we have during_restart as true. + return n_shutdown != n_workers || during_restart; + } + }; + + return std::async(std::launch::async, [this] { + State state{this->n_workers_}; + + while (state.ShouldContinue()) { + TCPSocket sock; + SockAddrV4 addr; + auto rc = listener_.Accept(&sock, &addr); + if (!rc.OK()) { + return Fail("Failed to accept connection.", std::move(rc)); + } + + auto worker = WorkerProxy{n_workers_, std::move(sock), std::move(addr)}; + if (!worker.Status().OK()) { + return Fail("Failed to initialize worker proxy.", std::move(worker.Status())); + } + switch (worker.Command()) { + case proto::CMD::kStart: { + state.Start(std::move(worker)); + if (state.Ready()) { + rc = this->Bootstrap(&state.pending); + state.Bootstrap(); + } + if (!rc.OK()) { + return rc; + } + continue; + } + case proto::CMD::kShutdown: { + state.Shutdown(); + continue; + } + case proto::CMD::kError: { + if (state.during_restart) { + continue; + } + state.Error(); + auto msg = worker.Msg(); + auto code = worker.Code(); + LOG(WARNING) << "Recieved error from [" << worker.Host() << ":" << worker.Rank() + << "]: " << msg << " code:" << code; + auto host = worker.Host(); + // We signal all workers for the error, if they haven't aborted already. + for (auto& w : worker_error_handles_) { + if (w.first == host) { + continue; + } + TCPSocket out; + // retry is set to 1, just let the worker timeout or error. Otherwise the + // tracker and the worker might be waiting for each other. + auto rc = Connect(w.first, w.second, 1, timeout_, &out); + // send signal to stop the worker. + proto::ShutdownCMD shutdown; + rc = shutdown.Send(&out); + if (!rc.OK()) { + return Fail("Failed to inform workers to stop."); + } + } + + continue; + } + case proto::CMD::kPrint: { + LOG(CONSOLE) << worker.Msg(); + continue; + } + case proto::CMD::kInvalid: + default: { + return Fail("Invalid command received."); + } + } + } + return Success(); + }); +} + [[nodiscard]] Result GetHostAddress(std::string* out) { auto rc = GetHostName(out); if (!rc.OK()) { diff --git a/src/collective/tracker.h b/src/collective/tracker.h index ec52f6a628d1..dbdf5c719c96 100644 --- a/src/collective/tracker.h +++ b/src/collective/tracker.h @@ -2,11 +2,135 @@ * Copyright 2023, XGBoost Contributors */ #pragma once -#include // for string +#include // for seconds +#include // for int32_t +#include // for future +#include // for string +#include "protocol.h" #include "xgboost/collective/result.h" // for Result +#include "xgboost/collective/socket.h" // for TCPSocket +#include "xgboost/json.h" // for Json namespace xgboost::collective { +/** + * + * @brief Implementation of RABIT tracker. + * + * * What is a tracker + * + * The implementation of collective follows what RABIT did in the past. It requires a + * tracker to coordinate initialization and error recovery of workers. While the + * original implementation attempted to attain error resislient inside the collective + * module, which turned out be too challenging due to large amount of external + * states. The new implementation here differs from RABIT in the way that neither state + * recovery nor resislient is handled inside the collective, it merely provides the + * mechanism to signal error to other workers through the use of a centralized tracker. + * + * There are three major functionalities provided the a tracker, namely: + * - Initialization. Share the node addresses among all workers. + * - Logging. + * - Signal error. If an exception is thrown in one (or many) of the workers, it can + * signal an error to the tracker and the tracker will notify other workers. + */ +class Tracker { + protected: + std::int32_t n_workers_{0}; + std::int32_t port_{-1}; + std::chrono::seconds timeout_{0}; + + public: + explicit Tracker(Json const& config); + Tracker(std::int32_t n_worders, std::int32_t port, std::chrono::seconds timeout) + : n_workers_{n_worders}, port_{port}, timeout_{timeout} {} + + virtual ~Tracker() noexcept(false){}; // NOLINT + [[nodiscard]] virtual std::future Run() = 0; + [[nodiscard]] virtual Json WorkerArgs() const = 0; + [[nodiscard]] std::chrono::seconds Timeout() const { return timeout_; } +}; + +class RabitTracker : public Tracker { + // a wrapper for connected worker socket. + class WorkerProxy { + TCPSocket sock_; + proto::PeerInfo info_; + std::int32_t eport_{0}; + std::int32_t world_{-1}; + std::string task_id_; + + proto::CMD cmd_{proto::CMD::kInvalid}; + std::string msg_; + std::int32_t code_{0}; + Result rc_; + + public: + explicit WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr); + WorkerProxy(WorkerProxy const& that) = delete; + WorkerProxy(WorkerProxy&& that) = default; + WorkerProxy& operator=(WorkerProxy const&) = delete; + WorkerProxy& operator=(WorkerProxy&&) = default; + + [[nodiscard]] auto Host() const { return info_.host; } + [[nodiscard]] auto TaskID() const { return task_id_; } + [[nodiscard]] auto Port() const { return info_.port; } + [[nodiscard]] auto Rank() const { return info_.rank; } + [[nodiscard]] auto ErrorPort() const { return eport_; } + [[nodiscard]] auto Command() const { return cmd_; } + [[nodiscard]] auto Msg() const { return msg_; } + [[nodiscard]] auto Code() const { return code_; } + + [[nodiscard]] Result const& Status() const { return rc_; } + [[nodiscard]] Result& Status() { return rc_; } + + void Send(StringView value) { this->sock_.Send(value); } + }; + // provide an ordering for workers, this helps us get deterministic topology. + struct WorkerCmp { + [[nodiscard]] bool operator()(WorkerProxy const& lhs, WorkerProxy const& rhs) { + auto const& lh = lhs.Host(); + auto const& rh = rhs.Host(); + + if (lh != rh) { + return lh < rh; + } + return lhs.TaskID() < rhs.TaskID(); + } + }; + + private: + std::string host_; + // record for how to reach out to workers if error happens. + std::vector> worker_error_handles_; + // listening socket for incoming workers. + TCPSocket listener_; + + Result Bootstrap(std::vector* p_workers); + + public: + explicit RabitTracker(StringView host, std::int32_t n_worders, std::int32_t port, + std::chrono::seconds timeout) + : Tracker{n_worders, port, timeout}, host_{host.c_str(), host.size()} { + listener_ = TCPSocket::Create(SockDomain::kV4); + auto rc = listener_.Bind(host, &this->port_); + CHECK(rc.OK()) << rc.Report(); + listener_.Listen(); + } + + explicit RabitTracker(Json const& config); + ~RabitTracker() noexcept(false) override = default; + + std::future Run() override; + + [[nodiscard]] std::int32_t Port() const { return port_; } + [[nodiscard]] Json WorkerArgs() const override { + Json args{Object{}}; + args["DMLC_TRACKER_URI"] = String{host_}; + args["DMLC_TRACKER_PORT"] = this->Port(); + return args; + } +}; + // Prob the public IP address of the host, need a better method. // // This is directly translated from the previous Python implementation, we should find a diff --git a/tests/cpp/collective/test_comm.cc b/tests/cpp/collective/test_comm.cc new file mode 100644 index 000000000000..759e499aa652 --- /dev/null +++ b/tests/cpp/collective/test_comm.cc @@ -0,0 +1,47 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#include + +#include "../../../src/collective/comm.h" +#include "test_worker.h" +namespace xgboost::collective { +namespace { +class CommTest : public TrackerTest {}; +} // namespace + +TEST_F(CommTest, Channel) { + auto n_workers = 4; + RabitTracker tracker{host, n_workers, 0, timeout}; + auto fut = tracker.Run(); + + std::vector workers; + std::int32_t port = tracker.Port(); + + for (std::int32_t i = 0; i < n_workers; ++i) { + workers.emplace_back([=] { + WorkerForTest worker{host, port, timeout, n_workers, i}; + if (i % 2 == 0) { + auto p_chan = worker.Comm().Chan(i + 1); + p_chan->SendAll( + EraseType(common::Span{&i, static_cast(1)})); + auto rc = p_chan->Block(); + ASSERT_TRUE(rc.OK()); + } else { + auto p_chan = worker.Comm().Chan(i - 1); + std::int32_t r{-1}; + p_chan->RecvAll(EraseType(common::Span{&r, static_cast(1)})); + auto rc = p_chan->Block(); + ASSERT_TRUE(rc.OK()); + ASSERT_EQ(r, i - 1); + } + }); + } + + for (auto &w : workers) { + w.join(); + } + + ASSERT_TRUE(fut.get().OK()); +} +} // namespace xgboost::collective diff --git a/tests/cpp/collective/test_socket.cc b/tests/cpp/collective/test_socket.cc index 7802acda8c01..ced795fef9a9 100644 --- a/tests/cpp/collective/test_socket.cc +++ b/tests/cpp/collective/test_socket.cc @@ -7,7 +7,7 @@ #include // EADDRNOTAVAIL #include // std::error_code, std::system_category -#include "net_test.h" // for SocketTest +#include "test_worker.h" // for SocketTest namespace xgboost::collective { TEST_F(SocketTest, Basic) { diff --git a/tests/cpp/collective/test_tracker.cc b/tests/cpp/collective/test_tracker.cc index 0e60cfb68bad..8fc5f0b3f028 100644 --- a/tests/cpp/collective/test_tracker.cc +++ b/tests/cpp/collective/test_tracker.cc @@ -1,18 +1,67 @@ /** * Copyright 2023, XGBoost Contributors */ -#include "../../../src/collective/tracker.h" // for GetHostAddress -#include "net_test.h" // for SocketTest +#include + +#include // for seconds +#include // for int32_t +#include // for string +#include // for thread +#include // for vector + +#include "../../../src/collective/comm.h" +#include "test_worker.h" namespace xgboost::collective { namespace { -class TrackerTest : public SocketTest {}; +class PrintWorker : public WorkerForTest { + public: + using WorkerForTest::WorkerForTest; + + void Print() { + auto rc = comm_.LogTracker("ack:" + std::to_string(this->comm_.Rank())); + ASSERT_TRUE(rc.OK()) << rc.Report(); + } +}; } // namespace -TEST_F(TrackerTest, GetHostAddress) { - std::string host; - auto rc = GetHostAddress(&host); - ASSERT_TRUE(rc.OK()); - ASSERT_TRUE(host.find("127.") == std::string::npos); +TEST_F(TrackerTest, Bootstrap) { + RabitTracker tracker{host, n_workers, 0, timeout}; + auto fut = tracker.Run(); + + std::vector workers; + std::int32_t port = tracker.Port(); + + for (std::int32_t i = 0; i < n_workers; ++i) { + workers.emplace_back([=] { WorkerForTest worker{host, port, timeout, n_workers, i}; }); + } + for (auto &w : workers) { + w.join(); + } + + ASSERT_TRUE(fut.get().OK()); } + +TEST_F(TrackerTest, Print) { + RabitTracker tracker{host, n_workers, 0, timeout}; + auto fut = tracker.Run(); + + std::vector workers; + std::int32_t port = tracker.Port(); + + for (std::int32_t i = 0; i < n_workers; ++i) { + workers.emplace_back([=] { + PrintWorker worker{host, port, timeout, n_workers, i}; + worker.Print(); + }); + } + + for (auto &w : workers) { + w.join(); + } + + ASSERT_TRUE(fut.get().OK()); +} + +TEST_F(TrackerTest, GetHostAddress) { ASSERT_TRUE(host.find("127.") == std::string::npos); } } // namespace xgboost::collective diff --git a/tests/cpp/collective/test_worker.h b/tests/cpp/collective/test_worker.h new file mode 100644 index 000000000000..9cd352e2eaf8 --- /dev/null +++ b/tests/cpp/collective/test_worker.h @@ -0,0 +1,110 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#include + +#include // for seconds +#include // for int32_t +#include // for string +#include // for thread +#include // for move +#include // for vector + +#include "../../../src/collective/comm.h" +#include "../../../src/collective/tracker.h" // for GetHostAddress +#include "../helpers.h" // for FileExists + +namespace xgboost::collective { +class WorkerForTest { + std::string tracker_host_; + std::int32_t tracker_port_; + std::int32_t world_size_; + + protected: + std::int32_t retry_{1}; + std::string task_id_; + RabitComm comm_; + + public: + WorkerForTest(std::string host, std::int32_t port, std::chrono::seconds timeout, + std::int32_t world, std::int32_t rank) + : tracker_host_{std::move(host)}, + tracker_port_{port}, + world_size_{world}, + task_id_{"t:" + std::to_string(rank)}, + comm_{tracker_host_, tracker_port_, timeout, retry_, task_id_} { + CHECK_EQ(world_size_, comm_.World()); + } + virtual ~WorkerForTest() = default; + auto& Comm() { return comm_; } + + void LimitSockBuf(std::int32_t n_bytes) { + for (std::int32_t i = 0; i < comm_.World(); ++i) { + if (i != comm_.Rank()) { + ASSERT_TRUE(comm_.Chan(i)->Socket()->NonBlocking()); + ASSERT_TRUE(comm_.Chan(i)->Socket()->SetBufSize(n_bytes).OK()); + } + } + } +}; + +class SocketTest : public ::testing::Test { + protected: + std::string skip_msg_{"Skipping IPv6 test"}; + + bool SkipTest() { + std::string path{"/sys/module/ipv6/parameters/disable"}; + if (FileExists(path)) { + std::ifstream fin(path); + if (!fin) { + return true; + } + std::string s_value; + fin >> s_value; + auto value = std::stoi(s_value); + if (value != 0) { + return true; + } + } else { + return true; + } + return false; + } + + protected: + void SetUp() override { system::SocketStartup(); } + void TearDown() override { system::SocketFinalize(); } +}; + +class TrackerTest : public SocketTest { + public: + std::int32_t n_workers{2}; + std::chrono::seconds timeout{1}; + std::string host; + + void SetUp() override { ASSERT_TRUE(GetHostAddress(&host).OK()); } +}; + +template +void TestDistributed(std::int32_t n_workers, WorkerFn worker_fn) { + std::chrono::seconds timeout{1}; + + std::string host; + ASSERT_TRUE(GetHostAddress(&host).OK()); + RabitTracker tracker{StringView{host}, n_workers, 0, timeout}; + auto fut = tracker.Run(); + + std::vector workers; + std::int32_t port = tracker.Port(); + + for (std::int32_t i = 0; i < n_workers; ++i) { + workers.emplace_back([=] { worker_fn(host, port, timeout, i); }); + } + + for (auto& t : workers) { + t.join(); + } + + ASSERT_TRUE(fut.get().OK()); +} +} // namespace xgboost::collective From 6eb04cac6a1287c14597634be507138d48eab86e Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 11 Oct 2023 13:06:00 +0800 Subject: [PATCH 02/11] let it crash. --- src/collective/tracker.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/collective/tracker.cc b/src/collective/tracker.cc index 9ead41859d18..56888c5961f4 100644 --- a/src/collective/tracker.cc +++ b/src/collective/tracker.cc @@ -103,7 +103,7 @@ Result RabitTracker::Bootstrap(std::vector* p_workers) { } for (auto& t : bootstrap_threads) { - t.join(); // fixme: check exception + t.join(); } for (auto const& w : workers) { From 96979665bc5e56de034c7943ccc20d58c00a7ad4 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 11 Oct 2023 13:28:40 +0800 Subject: [PATCH 03/11] cleanups. --- R-package/src/Makevars.in | 3 +++ R-package/src/Makevars.win | 3 +++ src/collective/comm.cc | 14 +++++++------- src/collective/comm.h | 12 ++++++++++++ src/collective/protocol.h | 1 + src/collective/tracker.cc | 6 +++++- src/collective/tracker.h | 2 ++ 7 files changed, 33 insertions(+), 8 deletions(-) diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index e3af418e32ea..63edc775a8c9 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -98,6 +98,9 @@ OBJECTS= \ $(PKGROOT)/src/context.o \ $(PKGROOT)/src/logging.o \ $(PKGROOT)/src/global_config.o \ + $(PKGROOT)/src/collective/allgather.o \ + $(PKGROOT)/src/collective/comm.o \ + $(PKGROOT)/src/collective/tracker.o \ $(PKGROOT)/src/collective/communicator.o \ $(PKGROOT)/src/collective/in_memory_communicator.o \ $(PKGROOT)/src/collective/in_memory_handler.o \ diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win index 8f003403fbcb..6f5ee5fb7afd 100644 --- a/R-package/src/Makevars.win +++ b/R-package/src/Makevars.win @@ -98,6 +98,9 @@ OBJECTS= \ $(PKGROOT)/src/context.o \ $(PKGROOT)/src/logging.o \ $(PKGROOT)/src/global_config.o \ + $(PKGROOT)/src/collective/allgather.o \ + $(PKGROOT)/src/collective/comm.o \ + $(PKGROOT)/src/collective/tracker.o \ $(PKGROOT)/src/collective/communicator.o \ $(PKGROOT)/src/collective/in_memory_communicator.o \ $(PKGROOT)/src/collective/in_memory_handler.o \ diff --git a/src/collective/comm.cc b/src/collective/comm.cc index b1884d58eb2f..cf647122e428 100644 --- a/src/collective/comm.cc +++ b/src/collective/comm.cc @@ -3,8 +3,11 @@ */ #include "comm.h" -#include // for unique_lock -#include // for move, forward +#include // for copy +#include // for seconds +#include // for shared_ptr +#include // for string +#include // for move, forward #include "../c_api/c_api_utils.h" #include "../common/common.h" @@ -44,14 +47,11 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st } TCPSocket& tracker = *out; - proto::Magic magic; - proto::Connect conn; - return std::move(rc) << [&] { return tracker.NonBlocking(false); } << [&] { return tracker.RecvTimeout(timeout); } - << [&] { return magic.Verify(&tracker); } - << [&] { return conn.WorkerSend(&tracker, world, rank, task_id); }; + << [&] { return proto::Magic{}.Verify(&tracker); } + << [&] { return proto::Connect{}.WorkerSend(&tracker, world, rank, task_id); }; } [[nodiscard]] Result Comm::ConnectTracker(TCPSocket* out) const { diff --git a/src/collective/comm.h b/src/collective/comm.h index 2756bebb52b1..f23810034e2a 100644 --- a/src/collective/comm.h +++ b/src/collective/comm.h @@ -2,13 +2,18 @@ * Copyright 2023, XGBoost Contributors */ #pragma once +#include // for seconds #include // for condition_variable #include // for size_t #include // for int32_t +#include // for shared_ptr #include // for mutex #include // for queue +#include // for string #include // for thread #include // for remove_const_t +#include // for move +#include // for vector #include "../common/timer.h" #include "loop.h" // for Loop @@ -22,6 +27,7 @@ namespace xgboost::collective { inline constexpr std::int32_t DefaultTimeoutSec() { return 300; } // 5min inline constexpr std::int32_t DefaultRetry() { return 3; } +// indexing into the ring inline std::int32_t BootstrapNext(std::int32_t r, std::int32_t world) { auto nrank = (r + world + 1) % world; return nrank; @@ -34,6 +40,9 @@ inline std::int32_t BootstrapPrev(std::int32_t r, std::int32_t world) { class Channel; +/** + * @brief Base communicator storing info about the tracker and other communicators. + */ class Comm { protected: std::int32_t world_{1}; @@ -99,6 +108,9 @@ class RabitComm : public Comm { [[nodiscard]] Result SignalError(Result const&) override; }; +/** + * @brief Communication channel between workers. + */ class Channel { std::shared_ptr sock_{nullptr}; Result rc_; diff --git a/src/collective/protocol.h b/src/collective/protocol.h index 5ed72fafacc4..96edf4e29bcf 100644 --- a/src/collective/protocol.h +++ b/src/collective/protocol.h @@ -4,6 +4,7 @@ #pragma once #include // for int32_t #include // for string +#include // for move #include "xgboost/collective/result.h" // for Result #include "xgboost/collective/socket.h" // for TCPSocket diff --git a/src/collective/tracker.cc b/src/collective/tracker.cc index 56888c5961f4..ce589f7dfbde 100644 --- a/src/collective/tracker.cc +++ b/src/collective/tracker.cc @@ -15,7 +15,11 @@ #include #endif // defined(_WIN32) -#include // for string +#include // for sort +#include // for seconds +#include // for int32_t +#include // for string +#include // for move, forward #include "../common/json_utils.h" #include "comm.h" diff --git a/src/collective/tracker.h b/src/collective/tracker.h index dbdf5c719c96..7bbee3c8d6e1 100644 --- a/src/collective/tracker.h +++ b/src/collective/tracker.h @@ -6,6 +6,8 @@ #include // for int32_t #include // for future #include // for string +#include // for pair +#include // for vector #include "protocol.h" #include "xgboost/collective/result.h" // for Result From 07b296aa43ecc94517c356b55d88d96ebfbfcae2 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 11 Oct 2023 13:39:04 +0800 Subject: [PATCH 04/11] check. --- src/collective/comm.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/collective/comm.cc b/src/collective/comm.cc index cf647122e428..cc062a83fbe3 100644 --- a/src/collective/comm.cc +++ b/src/collective/comm.cc @@ -248,6 +248,9 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se std::vector> workers; rc = ConnectWorkers(*this, &listener, lport, ninfo, timeout, retry, &workers); + if (!rc.OK()) { + return rc; + } CHECK(this->channels_.empty()); for (auto& w : workers) { From 59201203bfcdef7619ec6dee21043ad205721a5f Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 11 Oct 2023 13:56:55 +0800 Subject: [PATCH 05/11] cleanups. --- src/collective/comm.cc | 90 ++++++++++++++++++++---------------------- 1 file changed, 43 insertions(+), 47 deletions(-) diff --git a/src/collective/comm.cc b/src/collective/comm.cc index cc062a83fbe3..d26ad3df382a 100644 --- a/src/collective/comm.cc +++ b/src/collective/comm.cc @@ -64,21 +64,21 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st std::int32_t retry, std::vector>* out_workers) { auto next = std::make_shared(); + auto prev = std::make_shared(); + auto rc = Success() << [&] { auto rc = Connect(StringView{ninfo.host}, ninfo.port, retry, timeout, next.get()); if (!rc.OK()) { return Fail("Bootstrap failed to connect to ring next.", std::move(rc)); } return rc; - } << [&] { return next->NonBlocking(true); }; - if (!rc.OK()) { - return rc; - } + } << [&] { + return next->NonBlocking(true); + } << [&] { + SockAddrV4 addr; + return listener->Accept(prev.get(), &addr); + } << [&] { return prev->NonBlocking(true); }; - auto prev = std::make_shared(); - SockAddrV4 addr; - rc = std::move(rc) << [&] { return listener->Accept(prev.get(), &addr); } - << [&] { return prev->NonBlocking(true); }; if (!rc.OK()) { return rc; } @@ -94,31 +94,34 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st auto prev_ch = std::make_shared(comm, prev); auto next_ch = std::make_shared(comm, next); - rc = cpu_impl::RingAllgather(comm, s_buffer, HOST_NAME_MAX, 0, prev_ch, next_ch); + + auto block = [&] { + for (auto ch : {prev_ch, next_ch}) { + auto rc = ch->Block(); + if (!rc.OK()) { + return rc; + } + } + return Success(); + }; + + rc = std::move(rc) + << [&] { return cpu_impl::RingAllgather(comm, s_buffer, HOST_NAME_MAX, 0, prev_ch, next_ch); } + << [&] { return block(); }; if (!rc.OK()) { return Fail("Failed to get host names from peers.", std::move(rc)); } - for (auto ch : {prev_ch, next_ch}) { - rc = ch->Block(); - if (!rc.OK()) { - return rc; - } - } std::vector peers_port(comm.World(), -1); peers_port[comm.Rank()] = ninfo.port; auto s_ports = common::Span{reinterpret_cast(peers_port.data()), peers_port.size() * sizeof(ninfo.port)}; - rc = cpu_impl::RingAllgather(comm, s_ports, sizeof(ninfo.port), 0, prev_ch, next_ch); + rc = std::move(rc) + << [&] { return cpu_impl::RingAllgather(comm, s_ports, sizeof(ninfo.port), 0, prev_ch, next_ch); } + << [&] { return block(); }; if (!rc.OK()) { return Fail("Failed to get the port from peers.", std::move(rc)); } - for (auto ch : {prev_ch, next_ch}) { - rc = ch->Block(); - if (!rc.OK()) { - return rc; - } - } std::vector peers(comm.World()); for (auto r = 0; r < comm.World(); ++r) { @@ -162,7 +165,7 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st if (!rc.OK()) { return rc; } - std::int32_t rank; + std::int32_t rank{-1}; auto n_bytes = peer->RecvAll(&rank, sizeof(rank)); if (n_bytes != sizeof(comm.Rank())) { return Fail("Failed to recv rank."); @@ -278,39 +281,32 @@ RabitComm::~RabitComm() noexcept(false) { [[nodiscard]] Result RabitComm::Shutdown() { TCPSocket tracker; - auto rc = Success() << [&] { + return Success() << [&] { return ConnectTrackerImpl(tracker_, timeout_, retry_, task_id_, &tracker, Rank(), World()); - } << [&] { return this->Block(); }; - if (!rc.OK()) { - return rc; - } - - Json jcmd{Object{}}; - jcmd["cmd"] = Integer{static_cast(proto::CMD::kShutdown)}; - auto scmd = Json::Dump(jcmd); - tracker.Send(scmd); - - return Success(); + } << [&] { + return this->Block(); + } << [&] { + Json jcmd{Object{}}; + jcmd["cmd"] = Integer{static_cast(proto::CMD::kShutdown)}; + auto scmd = Json::Dump(jcmd); + auto n_bytes = tracker.Send(scmd); + if (n_bytes != scmd.size()) { + return Fail("Faled to send cmd."); + } + return Success(); + }; } [[nodiscard]] Result RabitComm::LogTracker(std::string msg) const { TCPSocket out; proto::Print print; - auto rc = Success() << [&] { return this->ConnectTracker(&out); } - << [&] { return print.WorkerSend(&out, msg); }; - if (!rc.OK()) { - return Fail("Logging to tracker failed.", std::move(rc)); - } - return rc; + return Success() << [&] { return this->ConnectTracker(&out); } + << [&] { return print.WorkerSend(&out, msg); }; } [[nodiscard]] Result RabitComm::SignalError(Result const& res) { TCPSocket out; - auto rc = this->ConnectTracker(&out); - if (!rc.OK()) { - return Fail("Logging to tracker failed.", std::move(rc)); - } - proto::ErrorCMD cmd; - return cmd.WorkerSend(&out, res); + return Success() << [&] { return this->ConnectTracker(&out); } + << [&] { return proto::ErrorCMD{}.WorkerSend(&out, res); }; } } // namespace xgboost::collective From 3c2a73ced2bf73bb7973f3b7f73a98c1d2704367 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 11 Oct 2023 15:32:01 +0800 Subject: [PATCH 06/11] R. --- R-package/src/Makevars.in | 1 + R-package/src/Makevars.win | 1 + 2 files changed, 2 insertions(+) diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index 63edc775a8c9..541c0fb5238f 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -104,6 +104,7 @@ OBJECTS= \ $(PKGROOT)/src/collective/communicator.o \ $(PKGROOT)/src/collective/in_memory_communicator.o \ $(PKGROOT)/src/collective/in_memory_handler.o \ + $(PKGROOT)/src/collective/loop.o \ $(PKGROOT)/src/collective/socket.o \ $(PKGROOT)/src/common/charconv.o \ $(PKGROOT)/src/common/column_matrix.o \ diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win index 6f5ee5fb7afd..faacd6d8d7bc 100644 --- a/R-package/src/Makevars.win +++ b/R-package/src/Makevars.win @@ -104,6 +104,7 @@ OBJECTS= \ $(PKGROOT)/src/collective/communicator.o \ $(PKGROOT)/src/collective/in_memory_communicator.o \ $(PKGROOT)/src/collective/in_memory_handler.o \ + $(PKGROOT)/src/collective/loop.o \ $(PKGROOT)/src/collective/socket.o \ $(PKGROOT)/src/common/charconv.o \ $(PKGROOT)/src/common/column_matrix.o \ From 95e1c1638d4ad1779f30150a5fa99a5082d20774 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 11 Oct 2023 15:37:26 +0800 Subject: [PATCH 07/11] lint. --- src/collective/comm.cc | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/src/collective/comm.cc b/src/collective/comm.cc index d26ad3df382a..8c618a69d510 100644 --- a/src/collective/comm.cc +++ b/src/collective/comm.cc @@ -9,11 +9,8 @@ #include // for string #include // for move, forward -#include "../c_api/c_api_utils.h" -#include "../common/common.h" #include "allgather.h" -#include "protocol.h" // for kMagic -#include "tracker.h" +#include "protocol.h" // for kMagic #include "xgboost/json.h" // for Json, Object namespace xgboost::collective { @@ -105,20 +102,20 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st return Success(); }; - rc = std::move(rc) - << [&] { return cpu_impl::RingAllgather(comm, s_buffer, HOST_NAME_MAX, 0, prev_ch, next_ch); } - << [&] { return block(); }; + rc = std::move(rc) << [&] { + return cpu_impl::RingAllgather(comm, s_buffer, HOST_NAME_MAX, 0, prev_ch, next_ch); + } << [&] { return block(); }; if (!rc.OK()) { return Fail("Failed to get host names from peers.", std::move(rc)); } std::vector peers_port(comm.World(), -1); peers_port[comm.Rank()] = ninfo.port; - auto s_ports = common::Span{reinterpret_cast(peers_port.data()), - peers_port.size() * sizeof(ninfo.port)}; - rc = std::move(rc) - << [&] { return cpu_impl::RingAllgather(comm, s_ports, sizeof(ninfo.port), 0, prev_ch, next_ch); } - << [&] { return block(); }; + rc = std::move(rc) << [&] { + auto s_ports = common::Span{reinterpret_cast(peers_port.data()), + peers_port.size() * sizeof(ninfo.port)}; + return cpu_impl::RingAllgather(comm, s_ports, sizeof(ninfo.port), 0, prev_ch, next_ch); + } << [&] { return block(); }; if (!rc.OK()) { return Fail("Failed to get the port from peers.", std::move(rc)); } From aa38e11976c93ba131e37c8d1cc6ad1f35ab3d5e Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 11 Oct 2023 15:45:10 +0800 Subject: [PATCH 08/11] cleanup. --- include/xgboost/collective/result.h | 9 +++++++++ include/xgboost/collective/socket.h | 4 ++-- src/collective/allgather.cc | 6 +++--- src/collective/comm.cc | 25 +++++++------------------ src/collective/tracker.cc | 12 +++--------- 5 files changed, 24 insertions(+), 32 deletions(-) diff --git a/include/xgboost/collective/result.h b/include/xgboost/collective/result.h index 209362505fc5..507171dd4ff1 100644 --- a/include/xgboost/collective/result.h +++ b/include/xgboost/collective/result.h @@ -157,4 +157,13 @@ struct Result { [[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev) { return Result{std::move(msg), std::move(errc), std::forward(prev)}; } + +// We don't have monad, a simple helper would do. +template +Result operator<<(Result&& r, Fn&& fn) { + if (!r.OK()) { + return std::forward(r); + } + return fn(); +} } // namespace xgboost::collective diff --git a/include/xgboost/collective/socket.h b/include/xgboost/collective/socket.h index 0e8c5faeab57..4178d7cc7ec2 100644 --- a/include/xgboost/collective/socket.h +++ b/include/xgboost/collective/socket.h @@ -383,8 +383,8 @@ class TCPSocket { // https://stackoverflow.com/questions/2876024/linux-is-there-a-read-or-recv-from-socket-with-timeout #if defined(_WIN32) DWORD tv = timeout.count() * 1000; - auto rc = setsockopt(socket, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&tv), - sizeof(tv)); + auto rc = + setsockopt(Handle(), SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&tv), sizeof(tv)); #else struct timeval tv; tv.tv_sec = timeout.count(); diff --git a/src/collective/allgather.cc b/src/collective/allgather.cc index 9b64d36114c8..dba36c88c314 100644 --- a/src/collective/allgather.cc +++ b/src/collective/allgather.cc @@ -4,12 +4,12 @@ #include "allgather.h" #include // for min +#include // for size_t #include // for int8_t #include // for shared_ptr -#include "comm.h" -#include "rabit/internal/socket.h" -#include "xgboost/collective/socket.h" +#include "comm.h" // for Comm, Channel +#include "xgboost/span.h" // for Span namespace xgboost::collective::cpu_impl { Result RingAllgather(Comm const& comm, common::Span data, std::size_t segment_size, diff --git a/src/collective/comm.cc b/src/collective/comm.cc index 8c618a69d510..7de6beba5e2a 100644 --- a/src/collective/comm.cc +++ b/src/collective/comm.cc @@ -10,21 +10,11 @@ #include // for move, forward #include "allgather.h" -#include "protocol.h" // for kMagic -#include "xgboost/json.h" // for Json, Object +#include "protocol.h" // for kMagic +#include "xgboost/collective/socket.h" // for TCPSocket +#include "xgboost/json.h" // for Json, Object namespace xgboost::collective { -namespace { -// We don't have monad, a simple helper would do. -template -Result operator<<(Result&& r, Fn&& fn) { - if (!r.OK()) { - return std::forward(r); - } - return fn(); -} -} // namespace - Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds timeout, std::int32_t retry, std::string task_id) : timeout_{timeout}, @@ -34,13 +24,13 @@ Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds time loop_{std::make_shared(timeout)} {} Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry, - std::string const& task_id, xgboost::collective::TCPSocket* out, - std::int32_t rank, std::int32_t world) { + std::string const& task_id, TCPSocket* out, std::int32_t rank, + std::int32_t world) { // get information from tracker CHECK(!info.host.empty()); auto rc = Connect(info.host, info.port, retry, timeout, out); if (!rc.OK()) { - return xgboost::collective::Fail("Failed to connect to the tracker.", std::move(rc)); + return Fail("Failed to connect to the tracker.", std::move(rc)); } TCPSocket& tracker = *out; @@ -64,7 +54,7 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st auto prev = std::make_shared(); auto rc = Success() << [&] { - auto rc = Connect(StringView{ninfo.host}, ninfo.port, retry, timeout, next.get()); + auto rc = Connect(ninfo.host, ninfo.port, retry, timeout, next.get()); if (!rc.OK()) { return Fail("Bootstrap failed to connect to ring next.", std::move(rc)); } @@ -75,7 +65,6 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st SockAddrV4 addr; return listener->Accept(prev.get(), &addr); } << [&] { return prev->NonBlocking(true); }; - if (!rc.OK()) { return rc; } diff --git a/src/collective/tracker.cc b/src/collective/tracker.cc index ce589f7dfbde..043e93359581 100644 --- a/src/collective/tracker.cc +++ b/src/collective/tracker.cc @@ -41,15 +41,10 @@ RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockA : sock_{std::move(sock)} { auto host = addr.Addr(); - proto::Magic magic; - rc_ = magic.Verify(&sock_); - if (!rc_.OK()) { - return; - } - - proto::Connect conn; std::int32_t rank{0}; - rc_ = conn.TrackerRecv(&sock_, &world_, &rank, &task_id_); + rc_ = Success() + << [&] { return proto::Magic{}.Verify(&sock_); } + << [&] { return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_); }; if (!rc_.OK()) { return; } @@ -69,7 +64,6 @@ RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockA proto::ErrorCMD error; rc_ = error.TrackerHandle(jcmd, &msg_, &code_); } - if (!rc_.OK()) { return; } From a905f310fae6e818eeaec38cbc40f48296c523c0 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 11 Oct 2023 18:27:00 +0800 Subject: [PATCH 09/11] fix windows --- tests/cpp/collective/test_worker.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/cpp/collective/test_worker.h b/tests/cpp/collective/test_worker.h index 9cd352e2eaf8..a3d6de8751df 100644 --- a/tests/cpp/collective/test_worker.h +++ b/tests/cpp/collective/test_worker.h @@ -82,7 +82,11 @@ class TrackerTest : public SocketTest { std::chrono::seconds timeout{1}; std::string host; - void SetUp() override { ASSERT_TRUE(GetHostAddress(&host).OK()); } + void SetUp() override { + SocketTest::SetUp(); + auto rc = GetHostAddress(&host); + ASSERT_TRUE(rc.OK()) << rc.Report(); + } }; template From 1e19392dcde2da083eddc0e59f14ce422716e05a Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 11 Oct 2023 18:27:36 +0800 Subject: [PATCH 10/11] not used yet --- tests/cpp/collective/test_worker.h | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/tests/cpp/collective/test_worker.h b/tests/cpp/collective/test_worker.h index a3d6de8751df..3c9d02f036ba 100644 --- a/tests/cpp/collective/test_worker.h +++ b/tests/cpp/collective/test_worker.h @@ -88,27 +88,4 @@ class TrackerTest : public SocketTest { ASSERT_TRUE(rc.OK()) << rc.Report(); } }; - -template -void TestDistributed(std::int32_t n_workers, WorkerFn worker_fn) { - std::chrono::seconds timeout{1}; - - std::string host; - ASSERT_TRUE(GetHostAddress(&host).OK()); - RabitTracker tracker{StringView{host}, n_workers, 0, timeout}; - auto fut = tracker.Run(); - - std::vector workers; - std::int32_t port = tracker.Port(); - - for (std::int32_t i = 0; i < n_workers; ++i) { - workers.emplace_back([=] { worker_fn(host, port, timeout, i); }); - } - - for (auto& t : workers) { - t.join(); - } - - ASSERT_TRUE(fut.get().OK()); -} } // namespace xgboost::collective From 51386da0a330795fceeb89f8d6bb323cbfd53c81 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 11 Oct 2023 21:40:57 +0800 Subject: [PATCH 11/11] windows --- include/xgboost/collective/socket.h | 15 ++++++++++++++- src/collective/comm.cc | 4 ++++ tests/cpp/collective/test_comm.cc | 4 ++-- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/include/xgboost/collective/socket.h b/include/xgboost/collective/socket.h index 4178d7cc7ec2..5dd1b9ffaff2 100644 --- a/include/xgboost/collective/socket.h +++ b/include/xgboost/collective/socket.h @@ -432,7 +432,12 @@ class TCPSocket { */ TCPSocket Accept() { HandleT newfd = accept(Handle(), nullptr, nullptr); - if (newfd == InvalidSocket()) { +#if defined(_WIN32) + auto interrupt = WSAEINTR; +#else + auto interrupt = EINTR; +#endif + if (newfd == InvalidSocket() && system::LastError() != interrupt) { system::ThrowAtError("accept"); } TCPSocket newsock{newfd}; @@ -619,7 +624,15 @@ class TCPSocket { */ void Close() { if (InvalidSocket() != handle_) { +#if defined(_WIN32) + auto rc = system::CloseSocket(handle_); + // it's possible that we close TCP sockets after finalizing WSA due to detached thread. + if (rc != 0 && system::LastError() != WSANOTINITIALISED) { + system::ThrowAtError("close", rc); + } +#else xgboost_CHECK_SYS_CALL(system::CloseSocket(handle_), 0); +#endif handle_ = InvalidSocket(); } } diff --git a/src/collective/comm.cc b/src/collective/comm.cc index 7de6beba5e2a..7e0af9c18450 100644 --- a/src/collective/comm.cc +++ b/src/collective/comm.cc @@ -200,6 +200,10 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se error_sock->Listen(); error_worker_ = std::thread{[this, error_sock = std::move(error_sock)] { auto conn = error_sock->Accept(); + // On Windows accept returns an invalid socket after network is shutdown. + if (conn.IsClosed()) { + return; + } LOG(WARNING) << "Another worker is running into error."; std::string scmd; conn.Recv(&scmd); diff --git a/tests/cpp/collective/test_comm.cc b/tests/cpp/collective/test_comm.cc index 759e499aa652..7792c4c25059 100644 --- a/tests/cpp/collective/test_comm.cc +++ b/tests/cpp/collective/test_comm.cc @@ -26,13 +26,13 @@ TEST_F(CommTest, Channel) { p_chan->SendAll( EraseType(common::Span{&i, static_cast(1)})); auto rc = p_chan->Block(); - ASSERT_TRUE(rc.OK()); + ASSERT_TRUE(rc.OK()) << rc.Report(); } else { auto p_chan = worker.Comm().Chan(i - 1); std::int32_t r{-1}; p_chan->RecvAll(EraseType(common::Span{&r, static_cast(1)})); auto rc = p_chan->Block(); - ASSERT_TRUE(rc.OK()); + ASSERT_TRUE(rc.OK()) << rc.Report(); ASSERT_EQ(r, i - 1); } });