diff --git a/include/xgboost/data.h b/include/xgboost/data.h
index f827e4758932..332f49264a64 100644
--- a/include/xgboost/data.h
+++ b/include/xgboost/data.h
@@ -39,7 +39,7 @@ enum class DataType : uint8_t {
 
 enum class FeatureType : uint8_t { kNumerical = 0, kCategorical = 1 };
 
-enum class DataSplitMode : int { kRow = 0, kCol = 1, kColSecure = 2 };
+enum class DataSplitMode : int { kRow = 0, kCol = 1 };
 
 /*!
  * \brief Meta information about dataset, always sit in memory.
@@ -174,17 +174,11 @@ class MetaInfo {
    */
   void SynchronizeNumberOfColumns(Context const* ctx);
 
-  /*! \brief Whether the data is split row-wise. */
-  bool IsRowSplit() const {
-    return data_split_mode == DataSplitMode::kRow;
-  }
+  /** @brief Whether the data is split row-wise. */
+  [[nodiscard]] bool IsRowSplit() const { return data_split_mode == DataSplitMode::kRow; }
 
   /** @brief Whether the data is split column-wise. */
-  bool IsColumnSplit() const { return (data_split_mode == DataSplitMode::kCol)
-                                      || (data_split_mode == DataSplitMode::kColSecure); }
-
-  /** @brief Whether the data is split column-wise with secure computation. */
-  bool IsSecure() const { return data_split_mode == DataSplitMode::kColSecure; }
+  [[nodiscard]] bool IsColumnSplit() const { return !this->IsRowSplit(); }
 
   /** @brief Whether this is a learning to rank data. */
   bool IsRanking() const { return !group_ptr_.empty(); }
diff --git a/plugin/federated/federated_comm.cc b/plugin/federated/federated_comm.cc
index 6fa19112878a..c52822836f43 100644
--- a/plugin/federated/federated_comm.cc
+++ b/plugin/federated/federated_comm.cc
@@ -8,6 +8,7 @@
 #include <cstdint>  // for int32_t
 #include <cstdlib>  // for getenv
 #include <limits>   // for numeric_limits
+#include <memory>   // for make_shared
 #include <string>   // for string, stoi
 
 #include "../../src/common/common.h"  // for Split
@@ -32,7 +33,9 @@ void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_
   CHECK_LT(rank, world) << "Invalid worker rank.";
 
   auto certs = {server_cert, client_cert, client_cert};
-  auto is_empty = [](auto const& s) { return s.empty(); };
+  auto is_empty = [](auto const& s) {
+    return s.empty();
+  };
   bool valid = std::all_of(certs.begin(), certs.end(), is_empty) ||
                std::none_of(certs.begin(), certs.end(), is_empty);
   CHECK(valid) << "Invalid arguments for certificates.";
@@ -123,6 +126,11 @@ FederatedComm::FederatedComm(std::int32_t retry, std::chrono::seconds timeout, s
   client_key = OptionalArg(config, "federated_client_key_path", client_key);
   client_cert = OptionalArg(config, "federated_client_cert_path", client_cert);
 
+  /**
+   * Hist encryption plugin.
+   */
+  this->plugin_.reset(CreateFederatedPlugin(config));
+
   this->Init(parsed[0], std::stoi(parsed[1]), world_size, rank, server_cert, client_key,
              client_cert);
 }
diff --git a/plugin/federated/federated_comm.h b/plugin/federated/federated_comm.h
index 0909509e07bc..8004ca4d66fd 100644
--- a/plugin/federated/federated_comm.h
+++ b/plugin/federated/federated_comm.h
@@ -11,12 +11,15 @@
 #include <memory>  // for shared_ptr
 #include <string>  // for string
 
-#include "../../src/collective/comm.h"  // for HostComm
+#include "../../src/collective/comm.h"  // for HostComm
+#include "federated_plugin.h"           // for FederatedPlugin
 #include "xgboost/json.h"
 
 namespace xgboost::collective {
 class FederatedComm : public HostComm {
   std::shared_ptr<federated::Federated::Stub> stub_;
+  // Plugin for encryption
+  std::shared_ptr<FederatedPlugin> plugin_{nullptr};
 
   void Init(std::string const& host, std::int32_t port, std::int32_t world, std::int32_t rank,
             std::string const& server_cert, std::string const& client_key,
             std::string const& client_cert);
@@ -62,6 +65,7 @@ class FederatedComm : public HostComm {
     return Success();
   }
   [[nodiscard]] bool IsFederated() const override { return true; }
+  [[nodiscard]] bool IsEncrypted() const override { return static_cast<bool>(plugin_); }
   [[nodiscard]] federated::Federated::Stub* Handle() const { return stub_.get(); }
 
   [[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
@@ -73,5 +77,7 @@ class FederatedComm : public HostComm {
     *out = "rank:" + std::to_string(rank);
     return Success();
   };
+
+  auto EncryptionPlugin() const { return plugin_; }
 };
 }  // namespace xgboost::collective
diff --git a/src/collective/aggregator.h b/src/collective/aggregator.h
index 2ceb35821cb9..809bbadd5a0f 100644
--- a/src/collective/aggregator.h
+++ b/src/collective/aggregator.h
@@ -16,9 +16,13 @@
 #include "communicator-inl.h"
 #include "xgboost/collective/result.h"  // for Result
 #include "xgboost/data.h"               // for MetaINfo
+#if defined(XGBOOST_USE_FEDERATED)
+#include "../../plugin/federated/federated_comm.h"
+#endif  // defined(XGBOOST_USE_FEDERATED)
 
 namespace xgboost::collective {
 namespace detail {
+// Apply function fn, and handle potential errors.
 template <typename Fn>
 [[nodiscard]] Result TryApplyWithLabels(Context const* ctx, Fn&& fn) {
   std::string msg;
@@ -29,10 +33,10 @@ template <typename Fn>
       msg = e.what();
     }
   }
+  // Error handling
   std::size_t msg_size{msg.size()};
   auto rc = Success() << [&] {
-    auto rc = collective::Broadcast(ctx, linalg::MakeVec(&msg_size, 1), 0);
-    return rc;
+    return collective::Broadcast(ctx, linalg::MakeVec(&msg_size, 1), 0);
   } << [&] {
     if (msg_size > 0) {
       msg.resize(msg_size);
@@ -95,10 +99,10 @@ template <typename T, typename Fn>
 void ApplyWithLabels(Context const* ctx, MetaInfo const& info, HostDeviceVector<T>* result,
                      Fn&& fn) {
   if (info.IsVerticalFederated()) {
-    // We assume labels are only available on worker 0, so the calculation is done there and result
-    // broadcast to other workers.
-    auto rc = detail::TryApplyWithLabels(ctx, fn);
-
+    // We assume labels are only available on worker 0, so the calculation is done there
+    // and the result is broadcast to other workers.
+    auto rc = detail::TryApplyWithLabels(ctx, std::forward<Fn>(fn));
+    // Broadcast the result
     std::size_t size{result->Size()};
     rc = std::move(rc) << [&] {
       return collective::Broadcast(ctx, linalg::MakeVec(&size, 1), 0);
     };
@@ -108,7 +112,7 @@ void ApplyWithLabels(Context const* ctx, MetaInfo const& info, HostDeviceVector<
     };
     SafeColl(rc);
   } else {
-    std::forward<Fn>(fn)();
+    fn();
   }
 }
diff --git a/src/collective/comm.cu b/src/collective/comm.cu
index 6566f28fad91..2ca7b76898fd 100644
--- a/src/collective/comm.cu
+++ b/src/collective/comm.cu
@@ -11,7 +11,6 @@
 #include <vector>  // for vector
 
 #include "../common/cuda_context.cuh"    // for CUDAContext
-#include "../common/device_helpers.cuh"  // for DefaultStream
 #include "../common/type.h"              // for EraseType
 #include "comm.cuh"                      // for NCCLComm
 #include "comm.h"                        // for Comm
diff --git a/src/collective/comm.h b/src/collective/comm.h
index 72fec2e816e9..4f436ac3ac9c 100644
--- a/src/collective/comm.h
+++ b/src/collective/comm.h
@@ -102,6 +102,7 @@ class Comm : public std::enable_shared_from_this<Comm> {
     return channels_.at(rank);
   }
   [[nodiscard]] virtual bool IsFederated() const = 0;
+  [[nodiscard]] virtual bool IsEncrypted() const { return false; }
   [[nodiscard]] virtual Result LogTracker(std::string msg) const = 0;
   [[nodiscard]] virtual Result SignalError(Result const&) { return Success(); }
diff --git a/src/collective/comm_group.cc b/src/collective/comm_group.cc
index a9b58ecb5505..5321342c7453 100644
--- a/src/collective/comm_group.cc
+++ b/src/collective/comm_group.cc
@@ -128,16 +128,20 @@
 void Init(Json const& config) { GlobalCommGroupInit(config); }
 
 void Finalize() { GlobalCommGroupFinalize(); }
 
-std::int32_t GetRank() noexcept { return GlobalCommGroup()->Rank(); }
+[[nodiscard]] std::int32_t GetRank() noexcept { return GlobalCommGroup()->Rank(); }
 
-std::int32_t GetWorldSize() noexcept { return GlobalCommGroup()->World(); }
+[[nodiscard]] std::int32_t GetWorldSize() noexcept { return GlobalCommGroup()->World(); }
 
-bool IsDistributed() noexcept { return GlobalCommGroup()->IsDistributed(); }
+[[nodiscard]] bool IsDistributed() noexcept { return GlobalCommGroup()->IsDistributed(); }
 
-[[nodiscard]] bool IsFederated() {
+[[nodiscard]] bool IsFederated() noexcept {
   return GlobalCommGroup()->Ctx(nullptr, DeviceOrd::CPU()).IsFederated();
 }
 
+[[nodiscard]] bool IsEncrypted() noexcept {
+  return IsFederated() && GlobalCommGroup()->Ctx(nullptr, DeviceOrd::CPU()).IsEncrypted();
+}
+
 void Print(std::string const& message) {
   auto rc = GlobalCommGroup()->Ctx(nullptr, DeviceOrd::CPU()).LogTracker(message);
   SafeColl(rc);
diff --git a/src/collective/comm_group.h b/src/collective/comm_group.h
index a98de0c16e51..80c4c348c947 100644
--- a/src/collective/comm_group.h
+++ b/src/collective/comm_group.h
@@ -17,7 +17,7 @@ namespace xgboost::collective {
  */
 class CommGroup {
   std::shared_ptr<Comm> comm_;
-  mutable std::shared_ptr<Comm> gpu_comm_;
+  mutable std::shared_ptr<Comm> gpu_comm_;  // lazy initialization
   std::shared_ptr<Coll> backend_;
 
   mutable std::shared_ptr<Coll> gpu_coll_;  // lazy initialization
diff --git a/src/collective/communicator-inl.h b/src/collective/communicator-inl.h
index 2632007009ed..2883a7607ff0 100644
--- a/src/collective/communicator-inl.h
+++ b/src/collective/communicator-inl.h
@@ -47,6 +47,11 @@ void Finalize();
  */
 [[nodiscard]] bool IsFederated();
 
+/**
+ * @brief Get if the communicator has an encryption plugin.
+ */
+[[nodiscard]] bool IsEncrypted();
+
 /**
  * @brief Print the message to the communicator.
  *
diff --git a/src/common/quantile.cc b/src/common/quantile.cc
index 0573d1f5f526..abfe3c8412a2 100644
--- a/src/common/quantile.cc
+++ b/src/common/quantile.cc
@@ -389,7 +389,7 @@ bool AddCutPoint(Context const *ctx, typename SketchType::SummaryContainer const
   } else {
     // we use the min_value as the first (0th) element, hence starting from 1.
     for (size_t i = 1; i < required_cuts; ++i) {
-      bst_float cpt = summary.data[i].value;
+      auto cpt = summary.data[i].value;
       if (i == 1 || cpt > cut_values.back()) {
         cut_values.push_back(cpt);
       }
@@ -449,7 +449,7 @@ void SketchContainerImpl<WQSketch>::MakeCuts(Context const *ctx, MetaInfo const
     std::int32_t max_num_bins = std::min(num_cuts[fid], max_bins_);
     // If vertical and secure mode, we need to sync the max_num_bins aross workers
     // to create the same global number of cut point bins for easier future processing
-    if (info.IsVerticalFederated() && info.IsSecure()) {
+    if (info.IsVerticalFederated() && collective::IsEncrypted()) {
      collective::SafeColl(collective::Allreduce(ctx, &max_num_bins, collective::Op::kMax));
     }
     typename WQSketch::SummaryContainer const &a = final_summaries[fid];
@@ -457,18 +457,18 @@ void SketchContainerImpl<WQSketch>::MakeCuts(Context const *ctx, MetaInfo const
       max_cat = std::max(max_cat, AddCategories(categories_.at(fid), p_cuts));
     } else {
       // use special AddCutPoint scheme for secure vertical federated learning
-      bool is_nan = AddCutPoint(ctx, a, max_num_bins, p_cuts, info.IsSecure());
+      bool is_nan = AddCutPoint(ctx, a, max_num_bins, p_cuts, collective::IsEncrypted());
       // push a value that is greater than anything if the feature is not empty
       // i.e. if the last value is not NaN
       if (!is_nan) {
-        const bst_float cpt =
-            (a.size > 0) ? a.data[a.size - 1].value : p_cuts->min_vals_.HostVector()[fid];
+        const float cpt =
+            (a.size > 0) ? a.data[a.size - 1].value : p_cuts->min_vals_.HostVector()[fid];
         // this must be bigger than last value in a scale
-        const bst_float last = cpt + (fabs(cpt) + 1e-5f);
+        const float last = cpt + (fabs(cpt) + 1e-5f);
         p_cuts->cut_values_.HostVector().push_back(last);
       } else {
-          // if the feature is empty, push a NaN value
-          p_cuts->cut_values_.HostVector().push_back(std::numeric_limits<bst_float>::quiet_NaN());
+        // if the feature is empty, push a NaN value
+        p_cuts->cut_values_.HostVector().push_back(std::numeric_limits<float>::quiet_NaN());
       }
     }
     // Ensure that every feature gets at least one quantile point
diff --git a/src/data/data.cc b/src/data/data.cc
index 3f9c13fa5053..33b33e6931bf 100644
--- a/src/data/data.cc
+++ b/src/data/data.cc
@@ -803,9 +803,7 @@ void MetaInfo::Validate(DeviceOrd device) const {
 void MetaInfo::SetInfoFromCUDA(Context const&, StringView, Json) { common::AssertGPUSupport(); }
 #endif  // !defined(XGBOOST_USE_CUDA)
 
-bool MetaInfo::IsVerticalFederated() const {
-  return collective::IsFederated() && IsColumnSplit();
-}
+bool MetaInfo::IsVerticalFederated() const { return collective::IsFederated() && IsColumnSplit(); }
 
 bool MetaInfo::ShouldHaveLabels() const {
   return !IsVerticalFederated() || collective::GetRank() == 0;
diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h
index 654c3c6627f3..cc312d31ce35 100644
--- a/src/tree/hist/evaluate_splits.h
+++ b/src/tree/hist/evaluate_splits.h
@@ -291,6 +291,7 @@ class HistEvaluator {
         ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
         iend = static_cast<bst_bin_t>(cut_ptr[fidx]) - 1;
       }
+      bool enc_vertical = is_secure_ && is_col_split_;
 
       for (bst_bin_t i = ibegin; i != iend; i += d_step) {
         // start working
@@ -305,7 +306,7 @@ class HistEvaluator {
           loss_chg = static_cast<float>(
               evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
                                       GradStats{right_sum}) - parent.root_gain);
-          if (!is_secure_) {
+          if (!enc_vertical) {
            split_pt = cut_val[i];  // not used for partition based
            best.Update(loss_chg, fidx, split_pt, d_step == -1, false, left_sum, right_sum);
          } else {
@@ -318,7 +319,7 @@ class HistEvaluator {
           loss_chg = static_cast<float>(
               evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{right_sum},
                                       GradStats{left_sum}) - parent.root_gain);
-          if (!is_secure_) {
+          if (!enc_vertical) {
            if (i == imin) {
              split_pt = cut.MinValues()[fidx];
            } else {
@@ -369,9 +370,11 @@ class HistEvaluator {
     // Under secure vertical setting, only the active party is able to evaluate the split
     // based on global histogram. Other parties will receive the final best split information
     // Hence the below computation is not performed by the passive parties
-    if ((!is_secure_) || (collective::GetRank() == 0)) {
+    bool is_passive_party = is_col_split_ && is_secure_ && collective::GetRank() != 0;
+    bool is_active_party = !is_passive_party;
+    if (is_active_party) {
       // Evaluate the splits for each feature
-      common::ParallelFor2d(space, n_threads, [&](size_t nidx_in_set, common::Range1d r) {
+      common::ParallelFor2d(space, n_threads, [&](std::size_t nidx_in_set, common::Range1d r) {
         auto tidx = omp_get_thread_num();
         auto entry = &tloc_candidates[n_threads * nidx_in_set + tidx];
         auto best = &entry->split;
@@ -410,7 +413,7 @@ class HistEvaluator {
         }
       });
 
-      for (unsigned nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
+      for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
         for (auto tidx = 0; tidx < n_threads; ++tidx) {
           entries[nidx_in_set].split.Update(tloc_candidates[n_threads * nidx_in_set + tidx].split);
         }
@@ -513,7 +516,7 @@ class HistEvaluator {
         column_sampler_{std::move(sampler)},
         tree_evaluator_{*param, static_cast<bst_feature_t>(info.num_col_), DeviceOrd::CPU()},
         is_col_split_{info.IsColumnSplit()},
-        is_secure_{info.IsSecure()}{
+        is_secure_{collective::IsEncrypted()} {
     interaction_constraints_.Configure(*param, info.num_col_);
     column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
                           param_->colsample_bynode, param_->colsample_bylevel,
@@ -744,7 +747,7 @@ class HistMultiEvaluator {
         column_sampler_{std::move(sampler)},
         ctx_{ctx},
         is_col_split_{info.IsColumnSplit()},
-        is_secure_{info.IsSecure()} {
+        is_secure_{collective::IsEncrypted()} {
     interaction_constraints_.Configure(*param, info.num_col_);
     column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
                           param_->colsample_bynode, param_->colsample_bylevel,
diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc
index c79df741d3aa..967dea358e84 100644
--- a/src/tree/updater_approx.cc
+++ b/src/tree/updater_approx.cc
@@ -93,7 +93,7 @@ class GlobalApproxBuilder {
     histogram_builder_.Reset(ctx_, n_total_bins, p_tree->NumTargets(), BatchSpec(*param_, hess),
                              collective::IsDistributed(), p_fmat->Info().IsColumnSplit(),
-                             p_fmat->Info().IsSecure(), hist_param_);
+                             collective::IsEncrypted(), hist_param_);
     monitor_->Stop(__func__);
   }
diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc
index 3731ded677d7..bcef0326a5f6 100644
--- a/src/tree/updater_quantile_hist.cc
+++ b/src/tree/updater_quantile_hist.cc
@@ -165,7 +165,7 @@ class MultiTargetHistBuilder {
     histogram_builder_ = std::make_unique<MultiHistogramBuilder>();
     histogram_builder_->Reset(ctx_, n_total_bins, n_targets, HistBatch(param_),
                               collective::IsDistributed(), p_fmat->Info().IsColumnSplit(),
-                              p_fmat->Info().IsSecure(), hist_param_);
+                              collective::IsEncrypted(), hist_param_);
     evaluator_ = std::make_unique<HistMultiEvaluator>(ctx_, p_fmat->Info(), param_, col_sampler_);
     p_last_tree_ = p_tree;
@@ -355,7 +355,7 @@ class HistUpdater {
                                 fmat->Info().IsColumnSplit());
     }
     histogram_builder_->Reset(ctx_, n_total_bins, 1, HistBatch(param_), collective::IsDistributed(),
-                              fmat->Info().IsColumnSplit(), fmat->Info().IsSecure(), hist_param_);
+                              fmat->Info().IsColumnSplit(), collective::IsEncrypted(), hist_param_);
     evaluator_ = std::make_unique<HistEvaluator>(ctx_, this->param_, fmat->Info(), col_sampler_);
     p_last_tree_ = p_tree;
     monitor_->Stop(__func__);
diff --git a/tests/cpp/common/test_quantile.cc b/tests/cpp/common/test_quantile.cc
index 402eeda9b8c4..e08e3b7d5f1c 100644
--- a/tests/cpp/common/test_quantile.cc
+++ b/tests/cpp/common/test_quantile.cc
@@ -12,6 +12,10 @@
 #include "../../../src/data/adapter.h"
 #include "../../../src/data/simple_dmatrix.h"  // for SimpleDMatrix
 #include "../collective/test_worker.h"         // for TestDistributedGlobal
+
+#if defined(XGBOOST_USE_FEDERATED)
+#include "../plugin/federated/test_worker.h"  // for TestEncryptedGlobal
+#endif  // defined(XGBOOST_USE_FEDERATED)
 #include "xgboost/context.h"
 
 namespace xgboost::common {
@@ -310,6 +314,7 @@ void DoTestColSplitQuantileSecure() {
   Context ctx;
   auto const world = collective::GetWorldSize();
   auto const rank = collective::GetRank();
+  ASSERT_TRUE(collective::IsEncrypted());
 
   size_t cols = 2;
   size_t rows = 3;
@@ -336,7 +341,7 @@
   auto const n_bins = 64;
 
-  m->Info().data_split_mode = DataSplitMode::kColSecure;
+  m->Info().data_split_mode = DataSplitMode::kCol;
   // Generate cuts for distributed environment.
   HistogramCuts distributed_cuts;
   {
@@ -392,7 +397,7 @@
 template <bool use_column>
 void TestColSplitQuantileSecure() {
   auto constexpr kWorkers = 2;
-  collective::TestFederatedGlobal(kWorkers, [] { DoTestColSplitQuantileSecure<use_column>(); });
+  collective::TestEncryptedGlobal(kWorkers, [&] { DoTestColSplitQuantileSecure<use_column>(); });
 }
 #endif  // defined(XGBOOST_USE_FEDERATED)
 }  // anonymous namespace
diff --git a/tests/cpp/plugin/federated/test_worker.h b/tests/cpp/plugin/federated/test_worker.h
index 5f53965da5dd..f9b030294b99 100644
--- a/tests/cpp/plugin/federated/test_worker.h
+++ b/tests/cpp/plugin/federated/test_worker.h
@@ -83,4 +83,16 @@ void TestFederatedGlobal(std::int32_t n_workers, WorkerFn&& fn) {
     collective::Finalize();
   });
 }
+
+template <typename WorkerFn>
+void TestEncryptedGlobal(std::int32_t n_workers, WorkerFn&& fn) {
+  TestFederatedImpl(n_workers, [&](std::int32_t port, std::int32_t i) {
+    auto config = FederatedTestConfig(n_workers, port, i);
+    config["federated_plugin"] = Object{};
+    config["federated_plugin"]["name"] = String{"mock"};
+    collective::Init(config);
+    fn();
+    collective::Finalize();
+  });
+}
 }  // namespace xgboost::collective
diff --git a/tests/cpp/tree/hist/test_evaluate_splits.cc b/tests/cpp/tree/hist/test_evaluate_splits.cc
index fffbba4d7efc..1ac60bbe8268 100644
--- a/tests/cpp/tree/hist/test_evaluate_splits.cc
+++ b/tests/cpp/tree/hist/test_evaluate_splits.cc
@@ -307,7 +307,7 @@ void DoTestEvaluateSplitsSecure(bool force_read_by_column) {
   auto dmat = RandomDataGenerator(kRows, kCols, 0).Seed(3).GenerateDMatrix();
   auto m = dmat->SliceCol(world, rank);
-  m->Info().data_split_mode = DataSplitMode::kColSecure;
+  m->Info().data_split_mode = DataSplitMode::kCol;
 
   auto evaluator = HistEvaluator{&ctx, &param, m->Info(), sampler};
   BoundedHistCollection hist;
diff --git a/tests/cpp/tree/hist/test_histogram.cc b/tests/cpp/tree/hist/test_histogram.cc
index 87ea1007f348..9d13c47aaa02 100644
--- a/tests/cpp/tree/hist/test_histogram.cc
+++ b/tests/cpp/tree/hist/test_histogram.cc
@@ -565,7 +565,7 @@ class OverflowTest : public ::testing::TestWithParam<std::tuple<bool, bool>> {
     CHECK_EQ(Xy->Info().IsColumnSplit(), is_col_split);
 
     hist_builder.Reset(&ctx, n_total_bins, tree.NumTargets(), batch, is_distributed,
-                       Xy->Info().IsColumnSplit(), Xy->Info().IsSecure(), &hist_param);
+                       Xy->Info().IsColumnSplit(), collective::IsEncrypted(), &hist_param);
 
     std::vector<CommonRowPartitioner> partitioners;
     partitioners.emplace_back(&ctx, Xy->Info().num_row_, /*base_rowid=*/0,