-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[coll] Add C API for the tracker. (#9773)
- Loading branch information
1 parent
06bdc15
commit 44099f5
Showing
5 changed files
with
264 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
/** | ||
* Copyright 2023, XGBoost Contributors | ||
*/ | ||
#include <chrono> // for seconds | ||
#include <cstddef> // for size_t | ||
#include <future> // for future | ||
#include <memory> // for unique_ptr | ||
#include <string> // for string | ||
#include <type_traits> // for is_same_v, remove_pointer_t | ||
#include <utility> // for pair | ||
|
||
#include "../collective/tracker.h" // for RabitTracker | ||
#include "c_api_error.h" // for API_BEGIN | ||
#include "xgboost/c_api.h" | ||
#include "xgboost/collective/result.h" // for Result | ||
#include "xgboost/json.h" // for Json | ||
#include "xgboost/string_view.h" // for StringView | ||
|
||
#if defined(XGBOOST_USE_FEDERATED) | ||
#include "../../plugin/federated/federated_tracker.h" // for FederatedTracker | ||
#else | ||
#include "../common/error_msg.h" // for NoFederated | ||
#endif | ||
|
||
using namespace xgboost; // NOLINT | ||
|
||
namespace { | ||
using TrackerHandleT = | ||
std::pair<std::unique_ptr<collective::Tracker>, std::shared_future<collective::Result>>; | ||
|
||
/**
 * @brief Convert an opaque C handle back into the internal tracker state.
 *
 * @param handle Handle previously written out by XGTrackerCreate.
 *
 * @return `handle` cast to `TrackerHandleT *`. Both the raw handle and the cast
 *         result are checked, so the returned pointer is non-null.
 */
TrackerHandleT *GetTrackerHandle(TrackerHandle handle) {
  xgboost_CHECK_C_ARG_PTR(handle);
  auto *tracker_state = static_cast<TrackerHandleT *>(handle);
  CHECK(tracker_state);
  return tracker_state;
}
|
||
struct CollAPIEntry { | ||
std::string ret_str; | ||
}; | ||
using CollAPIThreadLocalStore = dmlc::ThreadLocalStore<CollAPIEntry>; | ||
|
||
/**
 * @brief Block until the tracker run associated with the handle has finished.
 *
 * Returns immediately when the tracker was never started (the stored future is
 * not valid). Otherwise polls the future in 100-second slices until it is ready
 * and then fails (via CHECK) if the tracker reported an error.
 *
 * @param ptr Tracker state; only dereferenced once, to copy the shared future.
 */
void WaitImpl(TrackerHandleT *ptr) {
  std::chrono::seconds wait_for{100};
  // Take a private copy of the shared future and use only that copy from here on.
  // Another thread may free the handle while we are blocked (XGTrackerFree waits on
  // the same future and then deletes the state), so `ptr` must not be touched
  // again after this line. The copy keeps the shared result alive on its own.
  auto fut = ptr->second;
  while (fut.valid()) {
    auto res = fut.wait_for(wait_for);
    CHECK(res != std::future_status::deferred);
    if (res == std::future_status::ready) {
      auto const &rc = fut.get();
      CHECK(rc.OK()) << rc.Report();
      break;
    }
  }
}
} // namespace | ||
|
||
/**
 * @brief Create a tracker and hand it back as an opaque handle.
 *
 * @param config JSON document. The string field "dmlc_communicator" is required and
 *               selects the implementation: "rabit" or "federated" (the latter only
 *               when built with XGBOOST_USE_FEDERATED). The whole document is
 *               forwarded to the tracker's constructor.
 * @param handle Output location for the new handle. Release it with XGTrackerFree.
 *
 * @return Status code produced by API_BEGIN/API_END; 0 indicates success.
 */
XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle) {
  API_BEGIN();
  // Validate both arguments before constructing anything: checking `handle` only
  // after the allocation would leak the tracker when the caller passes null.
  xgboost_CHECK_C_ARG_PTR(config);
  xgboost_CHECK_C_ARG_PTR(handle);

  Json jconfig = Json::Load(config);

  auto type = RequiredArg<String>(jconfig, "dmlc_communicator", __func__);
  std::unique_ptr<collective::Tracker> tptr;
  if (type == "federated") {
#if defined(XGBOOST_USE_FEDERATED)
    tptr = std::make_unique<collective::FederatedTracker>(jconfig);
#else
    LOG(FATAL) << error::NoFederated();
#endif  // defined(XGBOOST_USE_FEDERATED)
  } else if (type == "rabit") {
    tptr = std::make_unique<collective::RabitTracker>(jconfig);
  } else {
    LOG(FATAL) << "Unknown communicator:" << type;
  }

  // The future starts out invalid; XGTrackerRun fills it in.
  auto ptr = new TrackerHandleT{std::move(tptr), std::future<collective::Result>{}};
  // Guard against the stored type drifting from what GetTrackerHandle casts back to.
  static_assert(std::is_same_v<std::remove_pointer_t<decltype(ptr)>, TrackerHandleT>);

  *handle = ptr;
  API_END();
}
|
||
/**
 * @brief Fetch the arguments workers need to reach the tracker, as a JSON string.
 *
 * @param handle Tracker handle from XGTrackerCreate.
 * @param args   Output: pointer to the JSON dump of the tracker's WorkerArgs(). The
 *               string lives in the CollAPIThreadLocalStore entry, so it is
 *               overwritten by the next call that stores into the same entry.
 *
 * @return Status code produced by API_BEGIN/API_END; 0 indicates success.
 */
XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args) {
  API_BEGIN();
  auto *tracker_state = GetTrackerHandle(handle);
  auto &tls_entry = *CollAPIThreadLocalStore::Get();
  tls_entry.ret_str = Json::Dump(tracker_state->first->WorkerArgs());
  xgboost_CHECK_C_ARG_PTR(args);
  *args = tls_entry.ret_str.c_str();
  API_END();
}
|
||
/**
 * @brief Start the tracker and remember the future for its result.
 *
 * Fails if a future is already stored, i.e. the tracker has been started before.
 *
 * @param handle Tracker handle from XGTrackerCreate.
 *
 * @return Status code produced by API_BEGIN/API_END; 0 indicates success.
 */
XGB_DLL int XGTrackerRun(TrackerHandle handle) {
  API_BEGIN();
  auto *tracker_state = GetTrackerHandle(handle);
  auto &[tracker, result] = *tracker_state;
  CHECK(!result.valid()) << "Tracker is already running.";
  result = tracker->Run();
  API_END();
}
|
||
/**
 * @brief Block until the tracker has finished running.
 *
 * @param handle Tracker handle from XGTrackerCreate.
 * @param config JSON document. It must be non-null and parse successfully, but no
 *               field of it is read here.
 *
 * @return Status code produced by API_BEGIN/API_END; 0 indicates success.
 */
XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config) {
  API_BEGIN();
  auto *tracker_state = GetTrackerHandle(handle);
  xgboost_CHECK_C_ARG_PTR(config);
  // Parsed only so that malformed input is rejected; the result is not consulted.
  [[maybe_unused]] auto parsed_config = Json::Load(StringView{config});
  WaitImpl(tracker_state);
  API_END();
}
|
||
/**
 * @brief Wait for the tracker to finish, then destroy it.
 *
 * The handle is released even when waiting fails (for example because the tracker
 * finished with an error). Previously a failed wait skipped the `delete`, and since
 * the stored result never changes, every retry failed the same way and the state
 * was leaked for good. The handle must therefore not be used after this call,
 * whatever the return value.
 *
 * @param handle Tracker handle from XGTrackerCreate.
 *
 * @return Status code produced by API_BEGIN/API_END; 0 indicates success.
 */
XGB_DLL int XGTrackerFree(TrackerHandle handle) {
  API_BEGIN();
  // Take ownership first so the state is destroyed on every exit path, including
  // the exception thrown by a failed CHECK inside WaitImpl.
  std::unique_ptr<TrackerHandleT> owner{GetTrackerHandle(handle)};
  WaitImpl(owner.get());
  API_END();
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
/** | ||
* Copyright 2023, XGBoost Contributors | ||
*/ | ||
#include <gtest/gtest.h> | ||
#include <xgboost/c_api.h> | ||
|
||
#include <chrono> // for ""s | ||
#include <thread> // for thread | ||
|
||
#include "../../../src/collective/tracker.h" | ||
#include "test_worker.h" // for SocketTest | ||
#include "xgboost/json.h" // for Json | ||
|
||
namespace xgboost::collective { | ||
namespace { | ||
class TrackerAPITest : public SocketTest {}; | ||
} // namespace | ||
|
||
// End-to-end exercise of the tracker C API: create a rabit tracker for two workers,
// run it, wait on it from a background thread, connect two test workers using the
// advertised address, and finally free the tracker.
TEST_F(TrackerAPITest, CAPI) {
  TrackerHandle handle{nullptr};
  Json config{Object{}};
  config["dmlc_communicator"] = String{"rabit"};
  config["n_workers"] = 2;
  config["timeout"] = 1;
  auto config_str = Json::Dump(config);
  auto rc = XGTrackerCreate(config_str.c_str(), &handle);
  ASSERT_EQ(rc, 0);
  rc = XGTrackerRun(handle);
  ASSERT_EQ(rc, 0);

  // Waits on the running tracker concurrently with the rest of the test.
  std::thread bg_wait{[&] {
    Json config{Object{}};
    auto config_str = Json::Dump(config);
    auto rc = XGTrackerWait(handle, config_str.c_str());
    ASSERT_EQ(rc, 0);
  }};

  char const* cargs;
  rc = XGTrackerWorkerArgs(handle, &cargs);
  ASSERT_EQ(rc, 0);
  auto args = Json::Load(StringView{cargs});

  // The tracker must advertise this host's address and a non-zero port.
  std::string host;
  ASSERT_TRUE(GetHostAddress(&host).OK());
  ASSERT_EQ(host, get<String const>(args["DMLC_TRACKER_URI"]));
  auto port = get<Integer const>(args["DMLC_TRACKER_PORT"]);
  ASSERT_NE(port, 0);

  std::vector<std::thread> workers;
  using namespace std::chrono_literals;  // NOLINT
  for (std::int32_t r = 0; r < 2; ++r) {
    workers.emplace_back([=] { WorkerForTest w{host, static_cast<std::int32_t>(port), 1s, 2, r}; });
  }
  for (auto& w : workers) {
    w.join();
  }

  // Join the waiter before releasing the handle. XGTrackerFree deletes the state
  // behind `handle`, and the background thread may not even have entered
  // XGTrackerWait yet, so freeing first lets it touch freed memory. Both calls wait
  // on the same tracker result, so this ordering cannot block any longer than
  // XGTrackerFree would have.
  bg_wait.join();

  rc = XGTrackerFree(handle);
  ASSERT_EQ(rc, 0);
}
} // namespace xgboost::collective |