Skip to content

Commit

Permalink
[coll] Add C API for the tracker. (#9773)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Nov 8, 2023
1 parent 06bdc15 commit 44099f5
Show file tree
Hide file tree
Showing 5 changed files with 264 additions and 0 deletions.
77 changes: 77 additions & 0 deletions include/xgboost/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1508,6 +1508,83 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *config,
* @{
*/

/**
* @brief Handle to tracker.
*
* There are currently two types of tracker in XGBoost, first one is `rabit`, while the
* other one is `federated`.
*
* This is still under development.
*/
typedef void *TrackerHandle; /* NOLINT */

/**
* @brief Create a new tracker.
*
* @param config JSON encoded parameters.
*
* - dmlc_communicator: String, the type of tracker to create. Available options are `rabit`
* and `federated`.
* - n_workers: Integer, the number of workers.
* - port: (Optional) Integer, the port this tracker should listen to.
* - timeout: (Optional) Integer, timeout in seconds for various networking operations.
*
* Some configurations are `rabit` specific:
* - host: (Optional) String, Used by the the `rabit` tracker to specify the address of the host.
*
* Some `federated` specific configurations:
* - federated_secure: Boolean, whether this is a secure server.
* - server_key_path: Path to the server key. Used only if this is a secure server.
* - server_cert_path: Path to the server certificate. Used only if this is a secure server.
* - client_cert_path: Path to the client certificate. Used only if this is a secure server.
*
* @param handle The handle to the created tracker.
*
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle);

/**
* @brief Get the arguments needed for running workers. This should be called after
* XGTrackerRun() and XGTrackerWait()
*
* @param handle The handle to the tracker.
* @param args The arguments returned as a JSON document.
*
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args);

/**
* @brief Run the tracker.
*
* @param handle The handle to the tracker.
*
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGTrackerRun(TrackerHandle handle);

/**
* @brief Wait for the tracker to finish, should be called after XGTrackerRun().
*
* @param handle The handle to the tracker.
* @param config JSON encoded configuration. No argument is required yet, preserved for
* the future.
*
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config);

/**
* @brief Free a tracker instance. XGTrackerWait() is called internally. If the tracker
* cannot close properly, manual interruption is required.
*
* @param handle The handle to the tracker.
*
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGTrackerFree(TrackerHandle handle);

/*!
* \brief Initialize the collective communicator.
*
Expand Down
119 changes: 119 additions & 0 deletions src/c_api/coll_c_api.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include <chrono> // for seconds
#include <cstddef> // for size_t
#include <future> // for future
#include <memory> // for unique_ptr
#include <string> // for string
#include <type_traits> // for is_same_v, remove_pointer_t
#include <utility> // for pair

#include "../collective/tracker.h" // for RabitTracker
#include "c_api_error.h" // for API_BEGIN
#include "xgboost/c_api.h"
#include "xgboost/collective/result.h" // for Result
#include "xgboost/json.h" // for Json
#include "xgboost/string_view.h" // for StringView

#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_tracker.h" // for FederatedTracker
#else
#include "../common/error_msg.h" // for NoFederated
#endif

using namespace xgboost; // NOLINT

namespace {
using TrackerHandleT =
std::pair<std::unique_ptr<collective::Tracker>, std::shared_future<collective::Result>>;

TrackerHandleT *GetTrackerHandle(TrackerHandle handle) {
xgboost_CHECK_C_ARG_PTR(handle);
auto *ptr = static_cast<TrackerHandleT *>(handle);
CHECK(ptr);
return ptr;
}

struct CollAPIEntry {
std::string ret_str;
};
using CollAPIThreadLocalStore = dmlc::ThreadLocalStore<CollAPIEntry>;

void WaitImpl(TrackerHandleT *ptr) {
std::chrono::seconds wait_for{100};
auto fut = ptr->second;
while (fut.valid()) {
auto res = fut.wait_for(wait_for);
CHECK(res != std::future_status::deferred);
if (res == std::future_status::ready) {
auto const &rc = ptr->second.get();
CHECK(rc.OK()) << rc.Report();
break;
}
}
}
} // namespace

XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle) {
API_BEGIN();
xgboost_CHECK_C_ARG_PTR(config);

Json jconfig = Json::Load(config);

auto type = RequiredArg<String>(jconfig, "dmlc_communicator", __func__);
std::unique_ptr<collective::Tracker> tptr;
if (type == "federated") {
#if defined(XGBOOST_USE_FEDERATED)
tptr = std::make_unique<collective::FederatedTracker>(jconfig);
#else
LOG(FATAL) << error::NoFederated();
#endif // defined(XGBOOST_USE_FEDERATED)
} else if (type == "rabit") {
tptr = std::make_unique<collective::RabitTracker>(jconfig);
} else {
LOG(FATAL) << "Unknown communicator:" << type;
}

auto ptr = new TrackerHandleT{std::move(tptr), std::future<collective::Result>{}};
static_assert(std::is_same_v<std::remove_pointer_t<decltype(ptr)>, TrackerHandleT>);

xgboost_CHECK_C_ARG_PTR(handle);
*handle = ptr;
API_END();
}

XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args) {
API_BEGIN();
auto *ptr = GetTrackerHandle(handle);
auto &local = *CollAPIThreadLocalStore::Get();
local.ret_str = Json::Dump(ptr->first->WorkerArgs());
xgboost_CHECK_C_ARG_PTR(args);
*args = local.ret_str.c_str();
API_END();
}

XGB_DLL int XGTrackerRun(TrackerHandle handle) {
API_BEGIN();
auto *ptr = GetTrackerHandle(handle);
CHECK(!ptr->second.valid()) << "Tracker is already running.";
ptr->second = ptr->first->Run();
API_END();
}

XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config) {
API_BEGIN();
auto *ptr = GetTrackerHandle(handle);
xgboost_CHECK_C_ARG_PTR(config);
auto jconfig = Json::Load(StringView{config});
WaitImpl(ptr);
API_END();
}

XGB_DLL int XGTrackerFree(TrackerHandle handle) {
API_BEGIN();
auto *ptr = GetTrackerHandle(handle);
WaitImpl(ptr);
delete ptr;
API_END();
}
3 changes: 3 additions & 0 deletions src/collective/tracker.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ class RabitTracker : public Tracker {
// record for how to reach out to workers if error happens.
std::vector<std::pair<std::string, std::int32_t>> worker_error_handles_;
// listening socket for incoming workers.
//
// At the moment, the listener calls accept without first polling. We can add an
// additional unix domain socket to allow cancelling the accept.
TCPSocket listener_;

Result Bootstrap(std::vector<WorkerProxy>* p_workers);
Expand Down
2 changes: 2 additions & 0 deletions src/common/error_msg.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,5 +97,7 @@ constexpr StringView InvalidCUDAOrdinal() {
}

void MismatchedDevices(Context const* booster, Context const* data);

inline auto NoFederated() { return "XGBoost is not compiled with federated learning support."; }
} // namespace xgboost::error
#endif // XGBOOST_COMMON_ERROR_MSG_H_
63 changes: 63 additions & 0 deletions tests/cpp/collective/test_coll_c_api.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/c_api.h>

#include <chrono> // for ""s
#include <thread> // for thread

#include "../../../src/collective/tracker.h"
#include "test_worker.h" // for SocketTest
#include "xgboost/json.h" // for Json

namespace xgboost::collective {
namespace {
class TrackerAPITest : public SocketTest {};
} // namespace

TEST_F(TrackerAPITest, CAPI) {
TrackerHandle handle;
Json config{Object{}};
config["dmlc_communicator"] = String{"rabit"};
config["n_workers"] = 2;
config["timeout"] = 1;
auto config_str = Json::Dump(config);
auto rc = XGTrackerCreate(config_str.c_str(), &handle);
ASSERT_EQ(rc, 0);
rc = XGTrackerRun(handle);
ASSERT_EQ(rc, 0);

std::thread bg_wait{[&] {
Json config{Object{}};
auto config_str = Json::Dump(config);
auto rc = XGTrackerWait(handle, config_str.c_str());
ASSERT_EQ(rc, 0);
}};

char const* cargs;
rc = XGTrackerWorkerArgs(handle, &cargs);
ASSERT_EQ(rc, 0);
auto args = Json::Load(StringView{cargs});

std::string host;
ASSERT_TRUE(GetHostAddress(&host).OK());
ASSERT_EQ(host, get<String const>(args["DMLC_TRACKER_URI"]));
auto port = get<Integer const>(args["DMLC_TRACKER_PORT"]);
ASSERT_NE(port, 0);

std::vector<std::thread> workers;
using namespace std::chrono_literals; // NOLINT
for (std::int32_t r = 0; r < 2; ++r) {
workers.emplace_back([=] { WorkerForTest w{host, static_cast<std::int32_t>(port), 1s, 2, r}; });
}
for (auto& w : workers) {
w.join();
}

rc = XGTrackerFree(handle);
ASSERT_EQ(rc, 0);

bg_wait.join();
}
} // namespace xgboost::collective

0 comments on commit 44099f5

Please sign in to comment.