Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[coll] Federated comm. #9732

Merged
merged 3 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion plugin/federated/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,6 @@ target_sources(federated_client INTERFACE federated_client.h)
target_link_libraries(federated_client INTERFACE federated_proto)

# Rabit engine for Federated Learning.
target_sources(objxgboost PRIVATE federated_server.cc)
target_sources(objxgboost PRIVATE federated_tracker.cc federated_server.cc federated_comm.cc)
target_link_libraries(objxgboost PRIVATE federated_client "-Wl,--exclude-libs,ALL")
target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1)
114 changes: 114 additions & 0 deletions plugin/federated/federated_comm.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/**
* Copyright 2023, XGBoost contributors
*/
#include "federated_comm.h"

#include <grpcpp/grpcpp.h>

#include <cstdint> // for int32_t
#include <cstdlib> // for getenv
#include <limits> // for numeric_limits
#include <string> // for string, stoi

#include "../../src/common/common.h" // for Split
#include "../../src/common/io.h" // for ReadAll
#include "../../src/common/json_utils.h" // for OptionalArg
#include "xgboost/json.h" // for Json
#include "xgboost/logging.h"

namespace xgboost::collective {
void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_t world,
std::int32_t rank, std::string const& server_cert,
std::string const& client_key, std::string const& client_cert) {
this->rank_ = rank;
this->world_ = world;

this->tracker_.host = host;
this->tracker_.port = port;
this->tracker_.rank = rank;

CHECK_GE(world, 1) << "Invalid world size.";
CHECK_GE(rank, 0) << "Invalid worker rank.";
CHECK_LT(rank, world) << "Invalid worker rank.";

if (server_cert.empty()) {
stub_ = [&] {
grpc::ChannelArguments args;
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
return federated::Federated::NewStub(
grpc::CreateCustomChannel(host, grpc::InsecureChannelCredentials(), args));
}();
} else {
stub_ = [&] {
grpc::SslCredentialsOptions options;
options.pem_root_certs = server_cert;
options.pem_private_key = client_key;
options.pem_cert_chain = client_cert;
grpc::ChannelArguments args;
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
auto channel = grpc::CreateCustomChannel(host, grpc::SslCredentials(options), args);
channel->WaitForConnected(
gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN)));
return federated::Federated::NewStub(channel);
}();
}
}

/**
 * @brief Construct the communicator from environment variables and a JSON configuration.
 *
 * Every setting is first read from its `FEDERATED_*` environment variable and then
 * overridden by the matching `federated_*` key in @p config when that key is present.
 *
 * @param config JSON object with the optional keys `federated_server_address`
 *               (`host:port`), `federated_world_size`, `federated_rank`,
 *               `federated_server_cert_path`, `federated_client_key_path` and
 *               `federated_client_cert_path`.
 */
FederatedComm::FederatedComm(Json const& config) {
  // Value of an environment variable, or an empty string when it is not set.
  auto env_or_empty = [](char const* name) -> std::string {
    auto const* value = std::getenv(name);
    return value == nullptr ? std::string{} : std::string{value};
  };

  /**
   * Topology
   */
  std::string server_address = env_or_empty("FEDERATED_SERVER_ADDRESS");
  std::int32_t world_size{0};
  std::int32_t rank{-1};
  // An empty value is treated the same as an unset variable.
  if (auto const text = env_or_empty("FEDERATED_WORLD_SIZE"); !text.empty()) {
    world_size = std::stoi(text);
  }
  if (auto const text = env_or_empty("FEDERATED_RANK"); !text.empty()) {
    rank = std::stoi(text);
  }

  // The JSON configuration takes precedence over the environment.
  server_address = OptionalArg<String>(config, "federated_server_address", server_address);
  world_size = static_cast<std::int32_t>(
      OptionalArg<Integer>(config, "federated_world_size", static_cast<Integer::Int>(world_size)));
  rank = static_cast<std::int32_t>(
      OptionalArg<Integer>(config, "federated_rank", static_cast<Integer::Int>(rank)));

  // Report missing parameters before trying to parse them, so that an absent address is
  // reported as "required" instead of as a malformed address.
  CHECK(!server_address.empty()) << "Parameter `federated_server_address` is required.";
  CHECK_NE(world_size, 0) << "Parameter `federated_world_size` is required.";
  CHECK_NE(rank, -1) << "Parameter `federated_rank` is required";

  auto parsed = common::Split(server_address, ':');
  CHECK_EQ(parsed.size(), 2) << "invalid server address:" << server_address;

  /**
   * Certificates
   */
  std::string server_cert = env_or_empty("FEDERATED_SERVER_CERT_PATH");
  std::string client_key = env_or_empty("FEDERATED_CLIENT_KEY_PATH");
  std::string client_cert = env_or_empty("FEDERATED_CLIENT_CERT_PATH");

  server_cert = OptionalArg<String>(config, "federated_server_cert_path", server_cert);
  client_key = OptionalArg<String>(config, "federated_client_key_path", client_key);
  client_cert = OptionalArg<String>(config, "federated_client_cert_path", client_cert);

  this->Init(parsed[0], std::stoi(parsed[1]), world_size, rank, server_cert, client_key,
             client_cert);
}
} // namespace xgboost::collective
53 changes: 53 additions & 0 deletions plugin/federated/federated_comm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/**
* Copyright 2023, XGBoost contributors
*/
#pragma once

#include <federated.grpc.pb.h>
#include <federated.pb.h>

#include <cstdint> // for int32_t
#include <memory> // for unique_ptr
#include <string> // for string

#include "../../src/collective/comm.h" // for Comm
#include "../../src/common/json_utils.h" // for OptionalArg
#include "xgboost/json.h"

namespace xgboost::collective {
/**
 * @brief Communicator used by federated workers; requests go through a gRPC stub.
 */
class FederatedComm : public Comm {
 private:
  // gRPC stub for the federated service, created by Init.
  std::unique_ptr<federated::Federated::Stub> stub_;

  // Common initialization shared by all constructors.
  void Init(std::string const& host, std::int32_t port, std::int32_t world, std::int32_t rank,
            std::string const& server_cert, std::string const& client_key,
            std::string const& client_cert);

 public:
  /**
   * @brief Construct from a JSON configuration.
   *
   * @param config Accepted keys:
   *
   * - federated_server_address: Tracker address
   * - federated_world_size: The number of workers
   * - federated_rank: Rank of federated worker
   * - federated_server_cert_path
   * - federated_client_key_path
   * - federated_client_cert_path
   */
  explicit FederatedComm(Json const& config);

  /**
   * @brief Construct from explicit topology values, without any certificates.
   */
  explicit FederatedComm(std::string const& host, std::int32_t port, std::int32_t world,
                         std::int32_t rank) {
    std::string const no_cert{};
    this->Init(host, port, world, rank, no_cert, no_cert, no_cert);
  }

  ~FederatedComm() override { stub_.reset(); }

  /**
   * @brief Not supported: reports a fatal error.
   */
  [[nodiscard]] std::shared_ptr<Channel> Chan(std::int32_t) const override {
    LOG(FATAL) << "peer to peer communication is not allowed for federated learning.";
    return nullptr;
  }

  /**
   * @brief Print the message on the console and report success.
   */
  [[nodiscard]] Result LogTracker(std::string msg) const override {
    LOG(CONSOLE) << msg;
    return Success();
  }

  /**
   * @brief This communicator is always federated.
   */
  [[nodiscard]] bool IsFederated() const override { return true; }
};
} // namespace xgboost::collective
7 changes: 5 additions & 2 deletions plugin/federated/federated_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
#include "federated_server.h"

#include <grpcpp/grpcpp.h>
#include <grpcpp/server.h> // for Server
#include <grpcpp/server_builder.h>
#include <xgboost/logging.h>

#include <sstream>

#include "../../src/collective/comm.h"
#include "../../src/common/io.h"
#include "../../src/common/json_utils.h"

namespace xgboost::federated {
grpc::Status FederatedService::Allgather(grpc::ServerContext*, AllgatherRequest const* request,
Expand Down Expand Up @@ -46,7 +49,7 @@ grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest
void RunServer(int port, std::size_t world_size, char const* server_key_file,
char const* server_cert_file, char const* client_cert_file) {
std::string const server_address = "0.0.0.0:" + std::to_string(port);
FederatedService service{world_size};
FederatedService service{static_cast<std::int32_t>(world_size)};

grpc::ServerBuilder builder;
auto options =
Expand All @@ -68,7 +71,7 @@ void RunServer(int port, std::size_t world_size, char const* server_key_file,

void RunInsecureServer(int port, std::size_t world_size) {
std::string const server_address = "0.0.0.0:" + std::to_string(port);
FederatedService service{world_size};
FederatedService service{static_cast<std::int32_t>(world_size)};

grpc::ServerBuilder builder;
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
Expand Down
20 changes: 11 additions & 9 deletions plugin/federated/federated_server.h
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
/*!
* Copyright 2022 XGBoost contributors
/**
* Copyright 2022-2023, XGBoost contributors
*/
#pragma once

#include <federated.grpc.pb.h>

#include "../../src/collective/in_memory_handler.h"
#include <cstdint> // for int32_t
#include <future> // for future

namespace xgboost {
namespace federated {
#include "../../src/collective/in_memory_handler.h"
#include "../../src/collective/tracker.h" // for Tracker
#include "xgboost/collective/result.h" // for Result

namespace xgboost::federated {
class FederatedService final : public Federated::Service {
public:
explicit FederatedService(std::size_t const world_size) : handler_{world_size} {}
explicit FederatedService(std::int32_t world_size)
: handler_{static_cast<std::size_t>(world_size)} {}

grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
AllgatherReply* reply) override;
Expand All @@ -34,6 +38,4 @@ void RunServer(int port, std::size_t world_size, char const* server_key_file,
char const* server_cert_file, char const* client_cert_file);

void RunInsecureServer(int port, std::size_t world_size);

} // namespace federated
} // namespace xgboost
} // namespace xgboost::federated
101 changes: 101 additions & 0 deletions plugin/federated/federated_tracker.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/**
* Copyright 2022-2023, XGBoost contributors
*/
#include "federated_tracker.h"

#include <grpcpp/security/server_credentials.h> // for InsecureServerCredentials, ...
#include <grpcpp/server_builder.h> // for ServerBuilder

#include <chrono> // for ms
#include <cstdint> // for int32_t
#include <exception> // for exception
#include <limits> // for numeric_limits
#include <string> // for string
#include <thread> // for sleep_for

#include "../../src/common/io.h" // for ReadAll
#include "../../src/common/json_utils.h" // for RequiredArg
#include "../../src/common/timer.h" // for Timer
#include "federated_server.h" // for FederatedService

namespace xgboost::collective {
/**
 * @brief Build the tracker from its JSON configuration.
 *
 * `federated_secure` is mandatory. The key and certificate paths are only looked up, and
 * only required, when it is true.
 */
FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} {
  auto const secure = RequiredArg<Boolean const>(config, "federated_secure", __func__);
  if (!secure) {
    // Insecure server: the path members stay empty.
    return;
  }

  server_key_path_ = RequiredArg<String const>(config, "server_key_path", __func__);
  server_cert_file_ = RequiredArg<String const>(config, "server_cert_path", __func__);
  client_cert_file_ = RequiredArg<String const>(config, "client_cert_path", __func__);
}

std::future<Result> FederatedTracker::Run() {
return std::async([this]() {
std::string const server_address = "0.0.0.0:" + std::to_string(this->port_);
federated::FederatedService service{static_cast<std::int32_t>(this->n_workers_)};
grpc::ServerBuilder builder;

if (this->server_cert_file_.empty()) {
builder.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
if (this->port_ == 0) {
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_);
} else {
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
}
builder.RegisterService(&service);
server_ = builder.BuildAndStart();
LOG(CONSOLE) << "Insecure federated server listening on " << server_address << ", world size "
<< this->n_workers_;
} else {
auto options = grpc::SslServerCredentialsOptions(
GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY);
options.pem_root_certs = xgboost::common::ReadAll(client_cert_file_);
auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair();
key.private_key = xgboost::common::ReadAll(server_key_path_);
key.cert_chain = xgboost::common::ReadAll(server_cert_file_);
options.pem_key_cert_pairs.push_back(key);
builder.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
if (this->port_ == 0) {
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options), &port_);
} else {
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options));
}
builder.RegisterService(&service);
server_ = builder.BuildAndStart();
LOG(CONSOLE) << "Federated server listening on " << server_address << ", world size "
<< n_workers_;
}

try {
server_->Wait();
} catch (std::exception const& e) {
return collective::Fail(std::string{e.what()});
}
return collective::Success();
});
}

// Defaulted out of line: members, including the gRPC server handle, are released by the
// compiler-generated destructor.
FederatedTracker::~FederatedTracker() = default;

/**
 * @brief Shut down the server started by Run().
 *
 * Polls every 10ms until the server object exists, giving up once the tracker timeout has
 * elapsed, then asks the server to shut down.
 *
 * @return Success, or a failure describing the timeout or the error raised during shutdown.
 */
Result FederatedTracker::Shutdown() {
  using namespace std::chrono_literals;  // NOLINT
  using ClockT = std::chrono::steady_clock;

  // Measure the wait against one fixed start point on a monotonic clock, so the elapsed time
  // is exactly the time spent in this loop.
  // NOTE(review): the timeout is assumed to be expressed in seconds, as the message states.
  auto const timeout_sec = this->Timeout().count();
  auto const begin = ClockT::now();
  while (!server_) {
    std::chrono::duration<double> const waited = ClockT::now() - begin;
    if (waited.count() > static_cast<double>(timeout_sec)) {
      return Fail("Failed to shutdown, timeout:" + std::to_string(timeout_sec) + " seconds.");
    }
    std::this_thread::sleep_for(10ms);
  }

  try {
    server_->Shutdown();
  } catch (std::exception const& e) {
    return Fail("Failed to shutdown:" + std::string{e.what()});
  }

  return Success();
}
} // namespace xgboost::collective
41 changes: 41 additions & 0 deletions plugin/federated/federated_tracker.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/**
* Copyright 2022-2023, XGBoost contributors
*/
#pragma once
#include <federated.grpc.pb.h> // for Server

#include <future> // for future
#include <memory> // for unique_ptr
#include <string> // for string

#include "../../src/collective/tracker.h" // for Tracker
#include "xgboost/collective/result.h" // for Result
#include "xgboost/json.h" // for Json

namespace xgboost::collective {
class FederatedTracker : public collective::Tracker {
std::unique_ptr<grpc::Server> server_;
std::string server_key_path_;
std::string server_cert_file_;
std::string client_cert_file_;

public:
/**
* @brief CTOR
*
* @param config Configuration, other than the base configuration from Tracker, we have:
*
* - federated_secure: bool whether this is a secure server.
* - server_key_path: path to the key.
* - server_cert_path: certificate path.
* - client_cert_path: certificate path for client.
*/
explicit FederatedTracker(Json const& config);
~FederatedTracker() override;
std::future<Result> Run() override;
// federated tracker do not provide initialization parameters, users have to provide it
// themseleves.
[[nodiscard]] Json WorkerArgs() const override { return Json{Null{}}; }
[[nodiscard]] Result Shutdown();
};
} // namespace xgboost::collective
Loading
Loading