Skip to content

Commit

Permalink
[coll] Improve event loop. (#10199)
Browse files Browse the repository at this point in the history
- Add a test for blocking calls.
- Do not require the queue to be empty after waking up; this frees up the thread to answer blocking calls.
- Handle EOF in read.
- Improve the error message in the result. Allow concatenation of multiple results.
  • Loading branch information
trivialfis authored Apr 17, 2024
1 parent 7c0c967 commit 4b10200
Show file tree
Hide file tree
Showing 11 changed files with 312 additions and 111 deletions.
1 change: 1 addition & 0 deletions R-package/src/Makevars.in
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ OBJECTS= \
$(PKGROOT)/src/context.o \
$(PKGROOT)/src/logging.o \
$(PKGROOT)/src/global_config.o \
$(PKGROOT)/src/collective/result.o \
$(PKGROOT)/src/collective/allgather.o \
$(PKGROOT)/src/collective/allreduce.o \
$(PKGROOT)/src/collective/broadcast.o \
Expand Down
1 change: 1 addition & 0 deletions R-package/src/Makevars.win
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ OBJECTS= \
$(PKGROOT)/src/context.o \
$(PKGROOT)/src/logging.o \
$(PKGROOT)/src/global_config.o \
$(PKGROOT)/src/collective/result.o \
$(PKGROOT)/src/collective/allgather.o \
$(PKGROOT)/src/collective/allreduce.o \
$(PKGROOT)/src/collective/broadcast.o \
Expand Down
2 changes: 1 addition & 1 deletion demo/dask/cpu_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def main(client):
# you can pass output directly into `predict` too.
prediction = dxgb.predict(client, bst, dtrain)
print("Evaluation history:", history)
return prediction
print("Error:", da.sqrt((prediction - y) ** 2).mean().compute())


if __name__ == "__main__":
Expand Down
8 changes: 8 additions & 0 deletions doc/contrib/unit_tests.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,14 @@ which provides higher flexibility. For example:
ctest --verbose
If you need to debug errors on Windows using the debugger from VS, you can append the gtest flags in `test_main.cc`:

.. code-block::
::testing::GTEST_FLAG(filter) = "Suite.Test";
::testing::GTEST_FLAG(repeat) = 10;
***********************************************
Sanitizers: Detect memory errors and data races
***********************************************
Expand Down
107 changes: 47 additions & 60 deletions include/xgboost/collective/result.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@
*/
#pragma once

#include <xgboost/logging.h>

#include <memory> // for unique_ptr
#include <sstream> // for stringstream
#include <stack> // for stack
#include <string> // for string
#include <utility> // for move
#include <cstdint> // for int32_t
#include <memory> // for unique_ptr
#include <string> // for string
#include <system_error> // for error_code
#include <utility> // for move

namespace xgboost::collective {
namespace detail {
Expand Down Expand Up @@ -48,48 +46,19 @@ struct ResultImpl {
return cur_eq;
}

[[nodiscard]] std::string Report() {
std::stringstream ss;
ss << "\n- " << this->message;
if (this->errc != std::error_code{}) {
ss << " system error:" << this->errc.message();
}

auto ptr = prev.get();
while (ptr) {
ss << "\n- ";
ss << ptr->message;

if (ptr->errc != std::error_code{}) {
ss << " " << ptr->errc.message();
}
ptr = ptr->prev.get();
}
[[nodiscard]] std::string Report() const;
[[nodiscard]] std::error_code Code() const;

return ss.str();
}
[[nodiscard]] auto Code() const {
// Find the root error.
std::stack<ResultImpl const*> stack;
auto ptr = this;
while (ptr) {
stack.push(ptr);
if (ptr->prev) {
ptr = ptr->prev.get();
} else {
break;
}
}
while (!stack.empty()) {
auto frame = stack.top();
stack.pop();
if (frame->errc != std::error_code{}) {
return frame->errc;
}
}
return std::error_code{};
}
void Concat(std::unique_ptr<ResultImpl> rhs);
};

#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__)
#define __builtin_FILE() nullptr
#define __builtin_LINE() (-1)
std::string MakeMsg(std::string&& msg, char const*, std::int32_t);
#else
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line);
#endif
} // namespace detail

/**
Expand Down Expand Up @@ -131,47 +100,65 @@ struct Result {
}
return *impl_ == *that.impl_;
}

friend Result operator+(Result&& lhs, Result&& rhs);
};

[[nodiscard]] inline Result operator+(Result&& lhs, Result&& rhs) {
if (lhs.OK()) {
return std::forward<Result>(rhs);
}
if (rhs.OK()) {
return std::forward<Result>(lhs);
}
lhs.impl_->Concat(std::move(rhs.impl_));
return std::forward<Result>(lhs);
}

/**
* @brief Return success.
*/
[[nodiscard]] inline auto Success() noexcept(true) { return Result{}; }
/**
* @brief Return failure.
*/
[[nodiscard]] inline auto Fail(std::string msg) { return Result{std::move(msg)}; }
[[nodiscard]] inline auto Fail(std::string msg, char const* file = __builtin_FILE(),
std::int32_t line = __builtin_LINE()) {
return Result{detail::MakeMsg(std::move(msg), file, line)};
}
/**
* @brief Return failure with `errno`.
*/
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc) {
return Result{std::move(msg), std::move(errc)};
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc,
char const* file = __builtin_FILE(),
std::int32_t line = __builtin_LINE()) {
return Result{detail::MakeMsg(std::move(msg), file, line), std::move(errc)};
}
/**
* @brief Return failure with a previous error.
*/
[[nodiscard]] inline auto Fail(std::string msg, Result&& prev) {
return Result{std::move(msg), std::forward<Result>(prev)};
[[nodiscard]] inline auto Fail(std::string msg, Result&& prev, char const* file = __builtin_FILE(),
std::int32_t line = __builtin_LINE()) {
return Result{detail::MakeMsg(std::move(msg), file, line), std::forward<Result>(prev)};
}
/**
* @brief Return failure with a previous error and a new `errno`.
*/
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev) {
return Result{std::move(msg), std::move(errc), std::forward<Result>(prev)};
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev,
char const* file = __builtin_FILE(),
std::int32_t line = __builtin_LINE()) {
return Result{detail::MakeMsg(std::move(msg), file, line), std::move(errc),
std::forward<Result>(prev)};
}

// We don't have monad, a simple helper would do.
template <typename Fn>
[[nodiscard]] Result operator<<(Result&& r, Fn&& fn) {
[[nodiscard]] std::enable_if_t<std::is_invocable_v<Fn>, Result> operator<<(Result&& r, Fn&& fn) {
if (!r.OK()) {
return std::forward<Result>(r);
}
return fn();
}

inline void SafeColl(Result const& rc) {
if (!rc.OK()) {
LOG(FATAL) << rc.Report();
}
}
void SafeColl(Result const& rc);
} // namespace xgboost::collective
32 changes: 23 additions & 9 deletions include/xgboost/collective/socket.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, XGBoost Contributors
* Copyright (c) 2022-2024, XGBoost Contributors
*/
#pragma once

Expand All @@ -12,7 +12,6 @@
#include <cstddef> // std::size_t
#include <cstdint> // std::int32_t, std::uint16_t
#include <cstring> // memset
#include <limits> // std::numeric_limits
#include <string> // std::string
#include <system_error> // std::error_code, std::system_category
#include <utility> // std::swap
Expand Down Expand Up @@ -468,19 +467,30 @@ class TCPSocket {
*addr = SockAddress{SockAddrV6{caddr}};
*out = TCPSocket{newfd};
}
// On MacOS, this is automatically set to async socket if the parent socket is async
// We make sure all socket are blocking by default.
//
// On Windows, a closed socket is returned during shutdown. We guard against it when
// setting non-blocking.
if (!out->IsClosed()) {
return out->NonBlocking(false);
}
return Success();
}

~TCPSocket() {
if (!IsClosed()) {
Close();
auto rc = this->Close();
if (!rc.OK()) {
LOG(WARNING) << rc.Report();
}
}
}

TCPSocket(TCPSocket const &that) = delete;
TCPSocket(TCPSocket &&that) noexcept(true) { std::swap(this->handle_, that.handle_); }
TCPSocket &operator=(TCPSocket const &that) = delete;
TCPSocket &operator=(TCPSocket &&that) {
TCPSocket &operator=(TCPSocket &&that) noexcept(true) {
std::swap(this->handle_, that.handle_);
return *this;
}
Expand Down Expand Up @@ -635,22 +645,26 @@ class TCPSocket {
*/
std::size_t Recv(std::string *p_str);
/**
* \brief Close the socket, called automatically in destructor if the socket is not closed.
* @brief Close the socket, called automatically in destructor if the socket is not closed.
*/
void Close() {
Result Close() {
if (InvalidSocket() != handle_) {
#if defined(_WIN32)
auto rc = system::CloseSocket(handle_);
#if defined(_WIN32)
// it's possible that we close TCP sockets after finalizing WSA due to detached thread.
if (rc != 0 && system::LastError() != WSANOTINITIALISED) {
system::ThrowAtError("close", rc);
return system::FailWithCode("Failed to close the socket.");
}
#else
xgboost_CHECK_SYS_CALL(system::CloseSocket(handle_), 0);
if (rc != 0) {
return system::FailWithCode("Failed to close the socket.");
}
#endif
handle_ = InvalidSocket();
}
return Success();
}

/**
* \brief Create a TCP socket on specified domain.
*/
Expand Down
Loading

0 comments on commit 4b10200

Please sign in to comment.