Skip to content

Commit

Permalink
[fed] Replace secure split with query in coll. (dmlc#10542)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Jul 6, 2024
1 parent d83c22b commit 733af8d
Show file tree
Hide file tree
Showing 18 changed files with 89 additions and 50 deletions.
14 changes: 4 additions & 10 deletions include/xgboost/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ enum class DataType : uint8_t {

enum class FeatureType : uint8_t { kNumerical = 0, kCategorical = 1 };

enum class DataSplitMode : int { kRow = 0, kCol = 1, kColSecure = 2 };
enum class DataSplitMode : int { kRow = 0, kCol = 1 };

/*!
* \brief Meta information about dataset, always sit in memory.
Expand Down Expand Up @@ -174,17 +174,11 @@ class MetaInfo {
*/
void SynchronizeNumberOfColumns(Context const* ctx);

/*! \brief Whether the data is split row-wise. */
bool IsRowSplit() const {
return data_split_mode == DataSplitMode::kRow;
}
/** @brief Whether the data is split row-wise. */
[[nodiscard]] bool IsRowSplit() const { return data_split_mode == DataSplitMode::kRow; }

/** @brief Whether the data is split column-wise. */
bool IsColumnSplit() const { return (data_split_mode == DataSplitMode::kCol)
|| (data_split_mode == DataSplitMode::kColSecure); }

/** @brief Whether the data is split column-wise with secure computation. */
bool IsSecure() const { return data_split_mode == DataSplitMode::kColSecure; }
[[nodiscard]] bool IsColumnSplit() const { return !this->IsRowSplit(); }

/** @brief Whether this is a learning to rank data. */
bool IsRanking() const { return !group_ptr_.empty(); }
Expand Down
10 changes: 9 additions & 1 deletion plugin/federated/federated_comm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <cstdint> // for int32_t
#include <cstdlib> // for getenv
#include <limits> // for numeric_limits
#include <memory> // for make_shared
#include <string> // for string, stoi

#include "../../src/common/common.h" // for Split
Expand All @@ -32,7 +33,9 @@ void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_
CHECK_LT(rank, world) << "Invalid worker rank.";

auto certs = {server_cert, client_cert, client_cert};
auto is_empty = [](auto const& s) { return s.empty(); };
auto is_empty = [](auto const& s) {
return s.empty();
};
bool valid = std::all_of(certs.begin(), certs.end(), is_empty) ||
std::none_of(certs.begin(), certs.end(), is_empty);
CHECK(valid) << "Invalid arguments for certificates.";
Expand Down Expand Up @@ -123,6 +126,11 @@ FederatedComm::FederatedComm(std::int32_t retry, std::chrono::seconds timeout, s
client_key = OptionalArg<String>(config, "federated_client_key_path", client_key);
client_cert = OptionalArg<String>(config, "federated_client_cert_path", client_cert);

/**
* Hist encryption plugin.
*/
this->plugin_.reset(CreateFederatedPlugin(config));

this->Init(parsed[0], std::stoi(parsed[1]), world_size, rank, server_cert, client_key,
client_cert);
}
Expand Down
8 changes: 7 additions & 1 deletion plugin/federated/federated_comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@
#include <memory> // for shared_ptr
#include <string> // for string

#include "../../src/collective/comm.h" // for HostComm
#include "../../src/collective/comm.h" // for HostComm
#include "federated_plugin.h" // for FederatedPlugin
#include "xgboost/json.h"

namespace xgboost::collective {
class FederatedComm : public HostComm {
std::shared_ptr<federated::Federated::Stub> stub_;
// Plugin for encryption
std::shared_ptr<FederatedPluginBase> plugin_{nullptr};

void Init(std::string const& host, std::int32_t port, std::int32_t world, std::int32_t rank,
std::string const& server_cert, std::string const& client_key,
Expand Down Expand Up @@ -62,6 +65,7 @@ class FederatedComm : public HostComm {
return Success();
}
[[nodiscard]] bool IsFederated() const override { return true; }
/** @brief Whether an encryption plugin is attached, i.e. `plugin_` is non-null. */
[[nodiscard]] bool IsEncrypted() const override {
  return plugin_ != nullptr;
}
[[nodiscard]] federated::Federated::Stub* Handle() const { return stub_.get(); }

[[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
Expand All @@ -73,5 +77,7 @@ class FederatedComm : public HostComm {
*out = "rank:" + std::to_string(rank);
return Success();
};

/**
 * @brief Get the encryption plugin held by this communicator.
 *
 * @return A copy of the shared pointer stored in `plugin_`; it may be null.
 */
auto EncryptionPlugin() const { return plugin_; }
};
} // namespace xgboost::collective
18 changes: 11 additions & 7 deletions src/collective/aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@
#include "communicator-inl.h"
#include "xgboost/collective/result.h" // for Result
#include "xgboost/data.h" // for MetaInfo
#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_comm.h"
#endif // defined(XGBOOST_USE_FEDERATED)

namespace xgboost::collective {
namespace detail {
// Apply function fn, and handle potential errors.
template <typename Fn>
[[nodiscard]] Result TryApplyWithLabels(Context const* ctx, Fn&& fn) {
std::string msg;
Expand All @@ -29,10 +33,10 @@ template <typename Fn>
msg = e.what();
}
}
// Error handling
std::size_t msg_size{msg.size()};
auto rc = Success() << [&] {
auto rc = collective::Broadcast(ctx, linalg::MakeVec(&msg_size, 1), 0);
return rc;
return collective::Broadcast(ctx, linalg::MakeVec(&msg_size, 1), 0);
} << [&] {
if (msg_size > 0) {
msg.resize(msg_size);
Expand Down Expand Up @@ -95,10 +99,10 @@ template <typename T, typename Fn>
void ApplyWithLabels(Context const* ctx, MetaInfo const& info, HostDeviceVector<T>* result,
Fn&& fn) {
if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the calculation is done there and result
// broadcast to other workers.
auto rc = detail::TryApplyWithLabels(ctx, fn);

// We assume labels are only available on worker 0, so the calculation is done there
// and the result is broadcast to other workers.
auto rc = detail::TryApplyWithLabels(ctx, std::forward<Fn>(fn));
// Broadcast the result
std::size_t size{result->Size()};
rc = std::move(rc) << [&] {
return collective::Broadcast(ctx, linalg::MakeVec(&size, 1), 0);
Expand All @@ -108,7 +112,7 @@ void ApplyWithLabels(Context const* ctx, MetaInfo const& info, HostDeviceVector<
};
SafeColl(rc);
} else {
std::forward<Fn>(fn)();
fn();
}
}

Expand Down
1 change: 0 additions & 1 deletion src/collective/comm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include <vector> // for vector

#include "../common/cuda_context.cuh" // for CUDAContext
#include "../common/device_helpers.cuh" // for DefaultStream
#include "../common/type.h" // for EraseType
#include "comm.cuh" // for NCCLComm
#include "comm.h" // for Comm
Expand Down
1 change: 1 addition & 0 deletions src/collective/comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class Comm : public std::enable_shared_from_this<Comm> {
return channels_.at(rank);
}
[[nodiscard]] virtual bool IsFederated() const = 0;
/**
 * @brief Whether the communicator has an encryption plugin.
 *
 * The base implementation always returns false; derived communicators can override it.
 */
[[nodiscard]] virtual bool IsEncrypted() const { return false; }
[[nodiscard]] virtual Result LogTracker(std::string msg) const = 0;

[[nodiscard]] virtual Result SignalError(Result const&) { return Success(); }
Expand Down
12 changes: 8 additions & 4 deletions src/collective/comm_group.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,20 @@ void Init(Json const& config) { GlobalCommGroupInit(config); }

void Finalize() { GlobalCommGroupFinalize(); }

std::int32_t GetRank() noexcept { return GlobalCommGroup()->Rank(); }
[[nodiscard]] std::int32_t GetRank() noexcept { return GlobalCommGroup()->Rank(); }

std::int32_t GetWorldSize() noexcept { return GlobalCommGroup()->World(); }
[[nodiscard]] std::int32_t GetWorldSize() noexcept { return GlobalCommGroup()->World(); }

bool IsDistributed() noexcept { return GlobalCommGroup()->IsDistributed(); }
[[nodiscard]] bool IsDistributed() noexcept { return GlobalCommGroup()->IsDistributed(); }

[[nodiscard]] bool IsFederated() {
[[nodiscard]] bool IsFederated() noexcept {
return GlobalCommGroup()->Ctx(nullptr, DeviceOrd::CPU()).IsFederated();
}

// Report whether the global communicator is federated and has encryption enabled.
[[nodiscard]] bool IsEncrypted() noexcept {
  // Only consult the communicator's encryption state when it is federated.
  if (!IsFederated()) {
    return false;
  }
  auto const& host_comm = GlobalCommGroup()->Ctx(nullptr, DeviceOrd::CPU());
  return host_comm.IsEncrypted();
}

void Print(std::string const& message) {
auto rc = GlobalCommGroup()->Ctx(nullptr, DeviceOrd::CPU()).LogTracker(message);
SafeColl(rc);
Expand Down
2 changes: 1 addition & 1 deletion src/collective/comm_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace xgboost::collective {
*/
class CommGroup {
std::shared_ptr<HostComm> comm_;
mutable std::shared_ptr<Comm> gpu_comm_;
mutable std::shared_ptr<Comm> gpu_comm_; // lazy initialization

std::shared_ptr<Coll> backend_;
mutable std::shared_ptr<Coll> gpu_coll_; // lazy initialization
Expand Down
5 changes: 5 additions & 0 deletions src/collective/communicator-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ void Finalize();
*/
[[nodiscard]] bool IsFederated();

/**
* @brief Get if the communicator has an encryption plugin.
*/
[[nodiscard]] bool IsEncrypted();

/**
* @brief Print the message to the communicator.
*
Expand Down
16 changes: 8 additions & 8 deletions src/common/quantile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ bool AddCutPoint(Context const *ctx, typename SketchType::SummaryContainer const
} else {
// we use the min_value as the first (0th) element, hence starting from 1.
for (size_t i = 1; i < required_cuts; ++i) {
bst_float cpt = summary.data[i].value;
auto cpt = summary.data[i].value;
if (i == 1 || cpt > cut_values.back()) {
cut_values.push_back(cpt);
}
Expand Down Expand Up @@ -449,26 +449,26 @@ void SketchContainerImpl<WQSketch>::MakeCuts(Context const *ctx, MetaInfo const
std::int32_t max_num_bins = std::min(num_cuts[fid], max_bins_);
// In vertical and secure mode, we need to sync the max_num_bins across workers
// to create the same global number of cut point bins for easier future processing
if (info.IsVerticalFederated() && info.IsSecure()) {
if (info.IsVerticalFederated() && collective::IsEncrypted()) {
collective::SafeColl(collective::Allreduce(ctx, &max_num_bins, collective::Op::kMax));
}
typename WQSketch::SummaryContainer const &a = final_summaries[fid];
if (IsCat(feature_types_, fid)) {
max_cat = std::max(max_cat, AddCategories(categories_.at(fid), p_cuts));
} else {
// use special AddCutPoint scheme for secure vertical federated learning
bool is_nan = AddCutPoint<WQSketch>(ctx, a, max_num_bins, p_cuts, info.IsSecure());
bool is_nan = AddCutPoint<WQSketch>(ctx, a, max_num_bins, p_cuts, collective::IsEncrypted());
// push a value that is greater than anything if the feature is not empty
// i.e. if the last value is not NaN
if (!is_nan) {
const bst_float cpt =
(a.size > 0) ? a.data[a.size - 1].value : p_cuts->min_vals_.HostVector()[fid];
const float cpt =
(a.size > 0) ? a.data[a.size - 1].value : p_cuts->min_vals_.HostVector()[fid];
// this must be bigger than last value in a scale
const bst_float last = cpt + (fabs(cpt) + 1e-5f);
const float last = cpt + (fabs(cpt) + 1e-5f);
p_cuts->cut_values_.HostVector().push_back(last);
} else {
// if the feature is empty, push a NaN value
p_cuts->cut_values_.HostVector().push_back(std::numeric_limits<double>::quiet_NaN());
// if the feature is empty, push a NaN value
p_cuts->cut_values_.HostVector().push_back(std::numeric_limits<double>::quiet_NaN());
}
}
// Ensure that every feature gets at least one quantile point
Expand Down
4 changes: 1 addition & 3 deletions src/data/data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -803,9 +803,7 @@ void MetaInfo::Validate(DeviceOrd device) const {
void MetaInfo::SetInfoFromCUDA(Context const&, StringView, Json) { common::AssertGPUSupport(); }
#endif // !defined(XGBOOST_USE_CUDA)

bool MetaInfo::IsVerticalFederated() const {
return collective::IsFederated() && IsColumnSplit();
}
bool MetaInfo::IsVerticalFederated() const { return collective::IsFederated() && IsColumnSplit(); }

bool MetaInfo::ShouldHaveLabels() const {
return !IsVerticalFederated() || collective::GetRank() == 0;
Expand Down
17 changes: 10 additions & 7 deletions src/tree/hist/evaluate_splits.h
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ class HistEvaluator {
ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
iend = static_cast<bst_bin_t>(cut_ptr[fidx]) - 1;
}
bool enc_vertical = is_secure_ && is_col_split_;

for (bst_bin_t i = ibegin; i != iend; i += d_step) {
// start working
Expand All @@ -305,7 +306,7 @@ class HistEvaluator {
loss_chg =
static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
GradStats{right_sum}) - parent.root_gain);
if (!is_secure_) {
if (!enc_vertical) {
split_pt = cut_val[i]; // not used for partition based
best.Update(loss_chg, fidx, split_pt, d_step == -1, false, left_sum, right_sum);
} else {
Expand All @@ -318,7 +319,7 @@ class HistEvaluator {
loss_chg =
static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{right_sum},
GradStats{left_sum}) - parent.root_gain);
if (!is_secure_) {
if (!enc_vertical) {
if (i == imin) {
split_pt = cut.MinValues()[fidx];
} else {
Expand Down Expand Up @@ -369,9 +370,11 @@ class HistEvaluator {
// Under secure vertical setting, only the active party is able to evaluate the split
// based on global histogram. Other parties will receive the final best split information
// Hence the below computation is not performed by the passive parties
if ((!is_secure_) || (collective::GetRank() == 0)) {
bool is_passive_party = is_col_split_ && is_secure_ && collective::GetRank() != 0;
bool is_active_party = !is_passive_party;
if (is_active_party) {
// Evaluate the splits for each feature
common::ParallelFor2d(space, n_threads, [&](size_t nidx_in_set, common::Range1d r) {
common::ParallelFor2d(space, n_threads, [&](std::size_t nidx_in_set, common::Range1d r) {
auto tidx = omp_get_thread_num();
auto entry = &tloc_candidates[n_threads * nidx_in_set + tidx];
auto best = &entry->split;
Expand Down Expand Up @@ -410,7 +413,7 @@ class HistEvaluator {
}
});

for (unsigned nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
for (auto tidx = 0; tidx < n_threads; ++tidx) {
entries[nidx_in_set].split.Update(tloc_candidates[n_threads * nidx_in_set + tidx].split);
}
Expand Down Expand Up @@ -513,7 +516,7 @@ class HistEvaluator {
column_sampler_{std::move(sampler)},
tree_evaluator_{*param, static_cast<bst_feature_t>(info.num_col_), DeviceOrd::CPU()},
is_col_split_{info.IsColumnSplit()},
is_secure_{info.IsSecure()}{
is_secure_{collective::IsEncrypted()} {
interaction_constraints_.Configure(*param, info.num_col_);
column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
param_->colsample_bynode, param_->colsample_bylevel,
Expand Down Expand Up @@ -744,7 +747,7 @@ class HistMultiEvaluator {
column_sampler_{std::move(sampler)},
ctx_{ctx},
is_col_split_{info.IsColumnSplit()},
is_secure_{info.IsSecure()} {
is_secure_{collective::IsEncrypted()} {
interaction_constraints_.Configure(*param, info.num_col_);
column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
param_->colsample_bynode, param_->colsample_bylevel,
Expand Down
2 changes: 1 addition & 1 deletion src/tree/updater_approx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class GlobalApproxBuilder {

histogram_builder_.Reset(ctx_, n_total_bins, p_tree->NumTargets(), BatchSpec(*param_, hess),
collective::IsDistributed(), p_fmat->Info().IsColumnSplit(),
p_fmat->Info().IsSecure(), hist_param_);
collective::IsEncrypted(), hist_param_);
monitor_->Stop(__func__);
}

Expand Down
4 changes: 2 additions & 2 deletions src/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ class MultiTargetHistBuilder {
histogram_builder_ = std::make_unique<MultiHistogramBuilder>();
histogram_builder_->Reset(ctx_, n_total_bins, n_targets, HistBatch(param_),
collective::IsDistributed(), p_fmat->Info().IsColumnSplit(),
p_fmat->Info().IsSecure(), hist_param_);
collective::IsEncrypted(), hist_param_);

evaluator_ = std::make_unique<HistMultiEvaluator>(ctx_, p_fmat->Info(), param_, col_sampler_);
p_last_tree_ = p_tree;
Expand Down Expand Up @@ -355,7 +355,7 @@ class HistUpdater {
fmat->Info().IsColumnSplit());
}
histogram_builder_->Reset(ctx_, n_total_bins, 1, HistBatch(param_), collective::IsDistributed(),
fmat->Info().IsColumnSplit(), fmat->Info().IsSecure(), hist_param_);
fmat->Info().IsColumnSplit(), collective::IsEncrypted(), hist_param_);
evaluator_ = std::make_unique<HistEvaluator>(ctx_, this->param_, fmat->Info(), col_sampler_);
p_last_tree_ = p_tree;
monitor_->Stop(__func__);
Expand Down
9 changes: 7 additions & 2 deletions tests/cpp/common/test_quantile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
#include "../../../src/data/adapter.h"
#include "../../../src/data/simple_dmatrix.h" // for SimpleDMatrix
#include "../collective/test_worker.h" // for TestDistributedGlobal

#if defined(XGBOOST_USE_FEDERATED)
#include "../plugin/federated/test_worker.h" // for TestEncryptedGlobal
#endif // defined(XGBOOST_USE_FEDERATED)
#include "xgboost/context.h"

namespace xgboost::common {
Expand Down Expand Up @@ -310,6 +314,7 @@ void DoTestColSplitQuantileSecure() {
Context ctx;
auto const world = collective::GetWorldSize();
auto const rank = collective::GetRank();
ASSERT_TRUE(collective::IsEncrypted());
size_t cols = 2;
size_t rows = 3;

Expand All @@ -336,7 +341,7 @@ void DoTestColSplitQuantileSecure() {

auto const n_bins = 64;

m->Info().data_split_mode = DataSplitMode::kColSecure;
m->Info().data_split_mode = DataSplitMode::kCol;
// Generate cuts for distributed environment.
HistogramCuts distributed_cuts;
{
Expand Down Expand Up @@ -392,7 +397,7 @@ void DoTestColSplitQuantileSecure() {
template <bool use_column>
void TestColSplitQuantileSecure() {
auto constexpr kWorkers = 2;
collective::TestFederatedGlobal(kWorkers, [] { DoTestColSplitQuantileSecure<use_column>(); });
collective::TestEncryptedGlobal(kWorkers, [&] { DoTestColSplitQuantileSecure<use_column>(); });
}
#endif // defined(XGBOOST_USE_FEDERATED)
} // anonymous namespace
Expand Down
Loading

0 comments on commit 733af8d

Please sign in to comment.