diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index d28b5098be9e..ffa3a6c79f42 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -1508,6 +1508,83 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *config, * @{ */ +/** + * @brief Handle to tracker. + * + * There are currently two types of tracker in XGBoost: the first one is `rabit`, while the + * other one is `federated`. + * + * This is still under development. + */ +typedef void *TrackerHandle; /* NOLINT */ + +/** + * @brief Create a new tracker. + * + * @param config JSON encoded parameters. + * + * - dmlc_communicator: String, the type of tracker to create. Available options are `rabit` + * and `federated`. + * - n_workers: Integer, the number of workers. + * - port: (Optional) Integer, the port this tracker should listen to. + * - timeout: (Optional) Integer, timeout in seconds for various networking operations. + * + * Some configurations are `rabit` specific: + * - host: (Optional) String, used by the `rabit` tracker to specify the address of the host. + * + * Some `federated` specific configurations: + * - federated_secure: Boolean, whether this is a secure server. + * - server_key_path: Path to the server key. Used only if this is a secure server. + * - server_cert_path: Path to the server certificate. Used only if this is a secure server. + * - client_cert_path: Path to the client certificate. Used only if this is a secure server. + * + * @param handle The handle to the created tracker. + * + * @return 0 for success, -1 for failure. + */ +XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle); + +/** + * @brief Get the arguments needed for running workers. This should be called after + * XGTrackerRun(). + * + * @param handle The handle to the tracker. + * @param args The arguments returned as a JSON document. + * + * @return 0 for success, -1 for failure.
+ */ +XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args); + +/** + * @brief Run the tracker. + * + * @param handle The handle to the tracker. + * + * @return 0 for success, -1 for failure. + */ +XGB_DLL int XGTrackerRun(TrackerHandle handle); + +/** + * @brief Wait for the tracker to finish, should be called after XGTrackerRun(). + * + * @param handle The handle to the tracker. + * @param config JSON encoded configuration. No argument is required yet, reserved for + * future use, but a JSON document such as `{}` must still be passed. + * + * @return 0 for success, -1 for failure. + */ +XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config); + +/** + * @brief Free a tracker instance. XGTrackerWait() is called internally. If the tracker + * cannot close properly, manual interruption is required. + * + * @param handle The handle to the tracker. + * + * @return 0 for success, -1 for failure. + */ +XGB_DLL int XGTrackerFree(TrackerHandle handle); + /*! * \brief Initialize the collective communicator. * diff --git a/src/c_api/coll_c_api.cc b/src/c_api/coll_c_api.cc new file mode 100644 index 000000000000..01713dbad419 --- /dev/null +++ b/src/c_api/coll_c_api.cc @@ -0,0 +1,119 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#include // for seconds +#include // for size_t +#include // for future +#include // for unique_ptr +#include // for string +#include // for is_same_v, remove_pointer_t +#include // for pair + +#include "../collective/tracker.h" // for RabitTracker +#include "c_api_error.h" // for API_BEGIN +#include "xgboost/c_api.h" +#include "xgboost/collective/result.h" // for Result +#include "xgboost/json.h" // for Json +#include "xgboost/string_view.h" // for StringView + +#if defined(XGBOOST_USE_FEDERATED) +#include "../../plugin/federated/federated_tracker.h" // for FederatedTracker +#else +#include "../common/error_msg.h" // for NoFederated +#endif + +using namespace xgboost; // NOLINT + +namespace { +using TrackerHandleT =  // {tracker object, shared future assigned from its Run()} + std::pair, std::shared_future>; +
+TrackerHandleT *GetTrackerHandle(TrackerHandle handle) { + xgboost_CHECK_C_ARG_PTR(handle); + auto *ptr = static_cast(handle); + CHECK(ptr); + return ptr; +} + +struct CollAPIEntry { + std::string ret_str; +}; +using CollAPIThreadLocalStore = dmlc::ThreadLocalStore; + +void WaitImpl(TrackerHandleT *ptr) { + std::chrono::seconds wait_for{100}; + auto fut = ptr->second; + while (fut.valid()) { + auto res = fut.wait_for(wait_for); + CHECK(res != std::future_status::deferred); + if (res == std::future_status::ready) { + auto const &rc = ptr->second.get(); + CHECK(rc.OK()) << rc.Report(); + break; + } + } +} +} // namespace + +XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle) { + API_BEGIN(); + xgboost_CHECK_C_ARG_PTR(config); + + Json jconfig = Json::Load(config); + + auto type = RequiredArg(jconfig, "dmlc_communicator", __func__); + std::unique_ptr tptr; + if (type == "federated") { +#if defined(XGBOOST_USE_FEDERATED) + tptr = std::make_unique(jconfig); +#else + LOG(FATAL) << error::NoFederated(); +#endif // defined(XGBOOST_USE_FEDERATED) + } else if (type == "rabit") { + tptr = std::make_unique(jconfig); + } else { + LOG(FATAL) << "Unknown communicator:" << type; + } + + auto ptr = new TrackerHandleT{std::move(tptr), std::future{}}; + static_assert(std::is_same_v, TrackerHandleT>); + + xgboost_CHECK_C_ARG_PTR(handle); + *handle = ptr; + API_END(); +} + +XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args) { + API_BEGIN(); + auto *ptr = GetTrackerHandle(handle); + auto &local = *CollAPIThreadLocalStore::Get(); + local.ret_str = Json::Dump(ptr->first->WorkerArgs()); + xgboost_CHECK_C_ARG_PTR(args); + *args = local.ret_str.c_str(); + API_END(); +} + +XGB_DLL int XGTrackerRun(TrackerHandle handle) { + API_BEGIN(); + auto *ptr = GetTrackerHandle(handle); + CHECK(!ptr->second.valid()) << "Tracker is already running."; + ptr->second = ptr->first->Run(); + API_END(); +} + +XGB_DLL int XGTrackerWait(TrackerHandle handle, char const 
*config) { + API_BEGIN(); + auto *ptr = GetTrackerHandle(handle); + xgboost_CHECK_C_ARG_PTR(config); + auto jconfig = Json::Load(StringView{config}); + WaitImpl(ptr); + API_END(); +} + +XGB_DLL int XGTrackerFree(TrackerHandle handle) { + API_BEGIN(); + auto *ptr = GetTrackerHandle(handle); + WaitImpl(ptr); + delete ptr; + API_END(); +} diff --git a/src/collective/tracker.h b/src/collective/tracker.h index 24e47bb4e776..f336a82f9ee5 100644 --- a/src/collective/tracker.h +++ b/src/collective/tracker.h @@ -114,6 +114,9 @@ class RabitTracker : public Tracker { // record for how to reach out to workers if error happens. std::vector> worker_error_handles_; // listening socket for incoming workers. + // + // At the moment, the listener calls accept without first polling. We can add an + // additional unix domain socket to allow cancelling the accept. TCPSocket listener_; Result Bootstrap(std::vector* p_workers); diff --git a/src/common/error_msg.h b/src/common/error_msg.h index 94703fd15c83..995fe11d5191 100644 --- a/src/common/error_msg.h +++ b/src/common/error_msg.h @@ -97,5 +97,7 @@ constexpr StringView InvalidCUDAOrdinal() { } void MismatchedDevices(Context const* booster, Context const* data); + +inline auto NoFederated() { return "XGBoost is not compiled with federated learning support."; } } // namespace xgboost::error #endif // XGBOOST_COMMON_ERROR_MSG_H_ diff --git a/tests/cpp/collective/test_coll_c_api.cc b/tests/cpp/collective/test_coll_c_api.cc new file mode 100644 index 000000000000..d80fbc14073d --- /dev/null +++ b/tests/cpp/collective/test_coll_c_api.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#include +#include + +#include // for ""s +#include // for thread + +#include "../../../src/collective/tracker.h" +#include "test_worker.h" // for SocketTest +#include "xgboost/json.h" // for Json + +namespace xgboost::collective { +namespace { +class TrackerAPITest : public SocketTest {}; +} // namespace + +TEST_F(TrackerAPITest, CAPI) { + 
TrackerHandle handle; + Json config{Object{}}; + config["dmlc_communicator"] = String{"rabit"}; + config["n_workers"] = 2; + config["timeout"] = 1; + auto config_str = Json::Dump(config); + auto rc = XGTrackerCreate(config_str.c_str(), &handle); + ASSERT_EQ(rc, 0); + rc = XGTrackerRun(handle); + ASSERT_EQ(rc, 0); + + std::thread bg_wait{[&] { + Json config{Object{}}; + auto config_str = Json::Dump(config); + auto rc = XGTrackerWait(handle, config_str.c_str()); + ASSERT_EQ(rc, 0); + }}; + + char const* cargs; + rc = XGTrackerWorkerArgs(handle, &cargs); + ASSERT_EQ(rc, 0); + auto args = Json::Load(StringView{cargs}); + + std::string host; + ASSERT_TRUE(GetHostAddress(&host).OK()); + ASSERT_EQ(host, get(args["DMLC_TRACKER_URI"])); + auto port = get(args["DMLC_TRACKER_PORT"]); + ASSERT_NE(port, 0); + + std::vector workers; + using namespace std::chrono_literals; // NOLINT + for (std::int32_t r = 0; r < 2; ++r) { + workers.emplace_back([=] { WorkerForTest w{host, static_cast(port), 1s, 2, r}; }); + } + for (auto& w : workers) { + w.join(); + } + + rc = XGTrackerFree(handle); + ASSERT_EQ(rc, 0); + + bg_wait.join(); +} +} // namespace xgboost::collective