Skip to content

Commit

Permalink
[secure boost] Vertical pipeline with hist sync (#10037) (#10528)
Browse files Browse the repository at this point in the history
The first phase is to implement an alternative vertical pipeline that syncs the histograms from clients to the label owner.

Co-authored-by: Ziyue Xu <71786575+ZiyueXu77@users.noreply.github.com>
  • Loading branch information
trivialfis and ZiyueXu77 authored Jul 2, 2024
1 parent a39fef2 commit a716334
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 87 deletions.
9 changes: 7 additions & 2 deletions include/xgboost/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ enum class DataType : uint8_t {

enum class FeatureType : uint8_t { kNumerical = 0, kCategorical = 1 };

enum class DataSplitMode : int { kRow = 0, kCol = 1 };
enum class DataSplitMode : int { kRow = 0, kCol = 1, kColSecure = 2 };

/*!
* \brief Meta information about dataset, always sit in memory.
Expand Down Expand Up @@ -180,7 +180,12 @@ class MetaInfo {
}

/** @brief Whether the data is split column-wise. */
bool IsColumnSplit() const { return data_split_mode == DataSplitMode::kCol; }
bool IsColumnSplit() const { return (data_split_mode == DataSplitMode::kCol)
|| (data_split_mode == DataSplitMode::kColSecure); }

/** @brief Whether the data is split column-wise with secure computation. */
bool IsSecure() const { return data_split_mode == DataSplitMode::kColSecure; }

/** @brief Whether this is a learning to rank data. */
bool IsRanking() const { return !group_ptr_.empty(); }

Expand Down
61 changes: 44 additions & 17 deletions src/common/quantile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -368,16 +368,32 @@ void SketchContainerImpl<WQSketch>::AllReduce(
}

template <typename SketchType>
void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_bin,
HistogramCuts *cuts) {
size_t required_cuts = std::min(summary.size, static_cast<size_t>(max_bin));
bool AddCutPoint(Context const *ctx, typename SketchType::SummaryContainer const &summary,
int max_bin, HistogramCuts *cuts, bool secure) {
bst_idx_t required_cuts = std::min(summary.size, static_cast<size_t>(max_bin));
// make a copy of required_cuts for mode selection
size_t required_cuts_original = required_cuts;
if (secure) {
// sync the required_cuts across all workers
collective::SafeColl(collective::Allreduce(ctx, &required_cuts, collective::Op::kMax));
}
// add the cut points
auto &cut_values = cuts->cut_values_.HostVector();
// we use the min_value as the first (0th) element, hence starting from 1.
for (size_t i = 1; i < required_cuts; ++i) {
bst_float cpt = summary.data[i].value;
if (i == 1 || cpt > cut_values.back()) {
cut_values.push_back(cpt);
// if secure and empty column, fill the cut values with NaN
if (secure && (required_cuts_original == 0)) {
for (size_t i = 1; i < required_cuts; ++i) {
cut_values.push_back(std::numeric_limits<double>::quiet_NaN());
}
return true;
} else {
// we use the min_value as the first (0th) element, hence starting from 1.
for (size_t i = 1; i < required_cuts; ++i) {
bst_float cpt = summary.data[i].value;
if (i == 1 || cpt > cut_values.back()) {
cut_values.push_back(cpt);
}
}
return false;
}
}

Expand Down Expand Up @@ -429,20 +445,31 @@ void SketchContainerImpl<WQSketch>::MakeCuts(Context const *ctx, MetaInfo const

float max_cat{-1.f};
for (size_t fid = 0; fid < reduced.size(); ++fid) {
size_t max_num_bins = std::min(num_cuts[fid], max_bins_);
std::int32_t max_num_bins = std::min(num_cuts[fid], max_bins_);
// If vertical and secure mode, we need to sync the max_num_bins aross workers
// to create the same global number of cut point bins for easier future processing
if (info.IsVerticalFederated() && info.IsSecure()) {
collective::SafeColl(collective::Allreduce(ctx, &max_num_bins, collective::Op::kMax));
}
typename WQSketch::SummaryContainer const &a = final_summaries[fid];
if (IsCat(feature_types_, fid)) {
max_cat = std::max(max_cat, AddCategories(categories_.at(fid), p_cuts));
} else {
AddCutPoint<WQSketch>(a, max_num_bins, p_cuts);
// push a value that is greater than anything
const bst_float cpt =
(a.size > 0) ? a.data[a.size - 1].value : p_cuts->min_vals_.HostVector()[fid];
// this must be bigger than last value in a scale
const bst_float last = cpt + (fabs(cpt) + 1e-5f);
p_cuts->cut_values_.HostVector().push_back(last);
// use special AddCutPoint scheme for secure vertical federated learning
bool is_nan = AddCutPoint<WQSketch>(ctx, a, max_num_bins, p_cuts, info.IsSecure());
// push a value that is greater than anything if the feature is not empty
// i.e. if the last value is not NaN
if (!is_nan) {
const bst_float cpt =
(a.size > 0) ? a.data[a.size - 1].value : p_cuts->min_vals_.HostVector()[fid];
// this must be bigger than last value in a scale
const bst_float last = cpt + (fabs(cpt) + 1e-5f);
p_cuts->cut_values_.HostVector().push_back(last);
} else {
// if the feature is empty, push a NaN value
p_cuts->cut_values_.HostVector().push_back(std::numeric_limits<double>::quiet_NaN());
}
}

// Ensure that every feature gets at least one quantile point
CHECK_LE(p_cuts->cut_values_.HostVector().size(), std::numeric_limits<uint32_t>::max());
auto cut_size = static_cast<uint32_t>(p_cuts->cut_values_.HostVector().size());
Expand Down
94 changes: 53 additions & 41 deletions src/tree/hist/evaluate_splits.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class HistEvaluator {
std::shared_ptr<common::ColumnSampler> column_sampler_;
TreeEvaluator tree_evaluator_;
bool is_col_split_{false};
bool is_secure_{false};
FeatureInteractionConstraintHost interaction_constraints_;
std::vector<NodeEntry> snode_;

Expand Down Expand Up @@ -322,7 +323,6 @@ class HistEvaluator {
}
}
}

p_best->Update(best);
return left_sum;
}
Expand Down Expand Up @@ -354,54 +354,63 @@ class HistEvaluator {
auto evaluator = tree_evaluator_.GetEvaluator();
auto const &cut_ptrs = cut.Ptrs();

common::ParallelFor2d(space, n_threads, [&](size_t nidx_in_set, common::Range1d r) {
auto tidx = omp_get_thread_num();
auto entry = &tloc_candidates[n_threads * nidx_in_set + tidx];
auto best = &entry->split;
auto nidx = entry->nid;
auto histogram = hist[nidx];
auto features_set = features[nidx_in_set]->ConstHostSpan();
for (auto fidx_in_set = r.begin(); fidx_in_set < r.end(); fidx_in_set++) {
auto fidx = features_set[fidx_in_set];
bool is_cat = common::IsCat(feature_types, fidx);
if (!interaction_constraints_.Query(nidx, fidx)) {
continue;
}
if (is_cat) {
auto n_bins = cut_ptrs.at(fidx + 1) - cut_ptrs[fidx];
if (common::UseOneHot(n_bins, param_->max_cat_to_onehot)) {
EnumerateOneHot(cut, histogram, fidx, nidx, evaluator, best);
} else {
std::vector<size_t> sorted_idx(n_bins);
std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
auto feat_hist = histogram.subspan(cut_ptrs[fidx], n_bins);
// Sort the histogram to get contiguous partitions.
std::stable_sort(sorted_idx.begin(), sorted_idx.end(), [&](size_t l, size_t r) {
auto ret = evaluator.CalcWeightCat(*param_, feat_hist[l]) <
evaluator.CalcWeightCat(*param_, feat_hist[r]);
return ret;
});
EnumeratePart<+1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
EnumeratePart<-1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
// Under secure vertical setting, only the active party is able to evaluate the split
// based on global histogram. Other parties will receive the final best split information
// Hence the below computation is not performed by the passive parties
if ((!is_secure_) || (collective::GetRank() == 0)) {
// Evaluate the splits for each feature
common::ParallelFor2d(space, n_threads, [&](size_t nidx_in_set, common::Range1d r) {
auto tidx = omp_get_thread_num();
auto entry = &tloc_candidates[n_threads * nidx_in_set + tidx];
auto best = &entry->split;
auto nidx = entry->nid;
auto histogram = hist[nidx];
auto features_set = features[nidx_in_set]->ConstHostSpan();
for (auto fidx_in_set = r.begin(); fidx_in_set < r.end(); fidx_in_set++) {
auto fidx = features_set[fidx_in_set];
bool is_cat = common::IsCat(feature_types, fidx);
if (!interaction_constraints_.Query(nidx, fidx)) {
continue;
}
} else {
auto grad_stats = EnumerateSplit<+1>(cut, histogram, fidx, nidx, evaluator, best);
if (SplitContainsMissingValues(grad_stats, snode_[nidx])) {
EnumerateSplit<-1>(cut, histogram, fidx, nidx, evaluator, best);
if (is_cat) {
auto n_bins = cut_ptrs.at(fidx + 1) - cut_ptrs[fidx];
if (common::UseOneHot(n_bins, param_->max_cat_to_onehot)) {
EnumerateOneHot(cut, histogram, fidx, nidx, evaluator, best);
} else {
std::vector<size_t> sorted_idx(n_bins);
std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
auto feat_hist = histogram.subspan(cut_ptrs[fidx], n_bins);
// Sort the histogram to get contiguous partitions.
std::stable_sort(sorted_idx.begin(), sorted_idx.end(), [&](size_t l, size_t r) {
auto ret = evaluator.CalcWeightCat(*param_, feat_hist[l]) <
evaluator.CalcWeightCat(*param_, feat_hist[r]);
return ret;
});
EnumeratePart<+1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
EnumeratePart<-1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
}
} else {
auto grad_stats = EnumerateSplit<+1>(cut, histogram, fidx, nidx, evaluator, best);
if (SplitContainsMissingValues(grad_stats, snode_[nidx])) {
EnumerateSplit<-1>(cut, histogram, fidx, nidx, evaluator, best);
}
}
}
}
});
});

for (unsigned nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
for (auto tidx = 0; tidx < n_threads; ++tidx) {
entries[nidx_in_set].split.Update(tloc_candidates[n_threads * nidx_in_set + tidx].split);
for (unsigned nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
for (auto tidx = 0; tidx < n_threads; ++tidx) {
entries[nidx_in_set].split.Update(tloc_candidates[n_threads * nidx_in_set + tidx].split);
}
}
}

if (is_col_split_) {
// With column-wise data split, we gather the best splits from all the workers and update the
// expand entries accordingly.
// Note that under secure vertical setting, only the label owner is able to evaluate the split
// based on the global histogram. The other parties will receive the final best splits
// allgather is capable of performing this (0-gain entries for non-label owners),
auto all_entries = AllgatherColumnSplit(ctx_, entries);
for (auto worker = 0; worker < collective::GetWorldSize(); ++worker) {
for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
Expand Down Expand Up @@ -481,7 +490,8 @@ class HistEvaluator {
param_{param},
column_sampler_{std::move(sampler)},
tree_evaluator_{*param, static_cast<bst_feature_t>(info.num_col_), DeviceOrd::CPU()},
is_col_split_{info.IsColumnSplit()} {
is_col_split_{info.IsColumnSplit()},
is_secure_{info.IsSecure()}{
interaction_constraints_.Configure(*param, info.num_col_);
column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
param_->colsample_bynode, param_->colsample_bylevel,
Expand All @@ -497,6 +507,7 @@ class HistMultiEvaluator {
std::shared_ptr<common::ColumnSampler> column_sampler_;
Context const *ctx_;
bool is_col_split_{false};
bool is_secure_{false};

private:
static double MultiCalcSplitGain(TrainParam const &param,
Expand Down Expand Up @@ -710,7 +721,8 @@ class HistMultiEvaluator {
: param_{param},
column_sampler_{std::move(sampler)},
ctx_{ctx},
is_col_split_{info.IsColumnSplit()} {
is_col_split_{info.IsColumnSplit()},
is_secure_{info.IsSecure()} {
interaction_constraints_.Configure(*param, info.num_col_);
column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
param_->colsample_bynode, param_->colsample_bylevel,
Expand Down
25 changes: 21 additions & 4 deletions src/tree/hist/histogram.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class HistogramBuilder {
// Whether XGBoost is running in distributed environment.
bool is_distributed_{false};
bool is_col_split_{false};
bool is_secure_{false};

public:
/**
Expand All @@ -58,13 +59,14 @@ class HistogramBuilder {
* of using global rabit variable.
*/
void Reset(Context const *ctx, bst_bin_t total_bins, BatchParam const &p, bool is_distributed,
bool is_col_split, HistMakerTrainParam const *param) {
bool is_col_split, bool is_secure, HistMakerTrainParam const *param) {
n_threads_ = ctx->Threads();
param_ = p;
hist_.Reset(total_bins, param->max_cached_hist_node);
buffer_.Init(total_bins);
is_distributed_ = is_distributed;
is_col_split_ = is_col_split;
is_secure_ = is_secure;
}

template <bool any_missing>
Expand Down Expand Up @@ -169,10 +171,11 @@ class HistogramBuilder {
}
}

void SyncHistogram(Context const *ctx, RegTree const *p_tree,
void SyncHistogram(Context const *ctx, RegTree const *p_tree,
std::vector<bst_node_t> const &nodes_to_build,
std::vector<bst_node_t> const &nodes_to_trick) {
auto n_total_bins = buffer_.TotalBins();

common::BlockedSpace2d space(
nodes_to_build.size(), [&](std::size_t) { return n_total_bins; }, 1024);
common::ParallelFor2d(space, this->n_threads_, [&](size_t node, common::Range1d r) {
Expand All @@ -190,6 +193,19 @@ class HistogramBuilder {
SafeColl(rc);
}

if (is_distributed_ && is_col_split_ && is_secure_) {
// Under secure vertical mode, we perform allgather for all nodes
CHECK(!nodes_to_build.empty());
// in theory the operation is AllGather, under current histogram setting of
// same length with 0s for empty slots,
// AllReduce is the most efficient way of achieving the global histogram
auto first_nidx = nodes_to_build.front();
std::size_t n = n_total_bins * nodes_to_build.size() * 2;
collective::SafeColl(collective::Allreduce(
ctx, linalg::MakeVec(reinterpret_cast<double *>(this->hist_[first_nidx].data()), n),
collective::Op::kSum));
}

common::BlockedSpace2d const &subspace =
nodes_to_trick.size() == nodes_to_build.size()
? space
Expand Down Expand Up @@ -329,12 +345,13 @@ class MultiHistogramBuilder {
[[nodiscard]] auto &Histogram(bst_target_t t) { return target_builders_[t].Histogram(); }

void Reset(Context const *ctx, bst_bin_t total_bins, bst_target_t n_targets, BatchParam const &p,
bool is_distributed, bool is_col_split, HistMakerTrainParam const *param) {
bool is_distributed, bool is_col_split, bool is_secure,
HistMakerTrainParam const *param) {
ctx_ = ctx;
target_builders_.resize(n_targets);
CHECK_GE(n_targets, 1);
for (auto &v : target_builders_) {
v.Reset(ctx, total_bins, p, is_distributed, is_col_split, param);
v.Reset(ctx, total_bins, p, is_distributed, is_col_split, is_secure, param);
}
}
};
Expand Down
2 changes: 1 addition & 1 deletion src/tree/updater_approx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class GlobalApproxBuilder {

histogram_builder_.Reset(ctx_, n_total_bins, p_tree->NumTargets(), BatchSpec(*param_, hess),
collective::IsDistributed(), p_fmat->Info().IsColumnSplit(),
hist_param_);
p_fmat->Info().IsSecure(), hist_param_);
monitor_->Stop(__func__);
}

Expand Down
4 changes: 2 additions & 2 deletions src/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ class MultiTargetHistBuilder {
histogram_builder_ = std::make_unique<MultiHistogramBuilder>();
histogram_builder_->Reset(ctx_, n_total_bins, n_targets, HistBatch(param_),
collective::IsDistributed(), p_fmat->Info().IsColumnSplit(),
hist_param_);
p_fmat->Info().IsSecure(), hist_param_);

evaluator_ = std::make_unique<HistMultiEvaluator>(ctx_, p_fmat->Info(), param_, col_sampler_);
p_last_tree_ = p_tree;
Expand Down Expand Up @@ -355,7 +355,7 @@ class HistUpdater {
fmat->Info().IsColumnSplit());
}
histogram_builder_->Reset(ctx_, n_total_bins, 1, HistBatch(param_), collective::IsDistributed(),
fmat->Info().IsColumnSplit(), hist_param_);
fmat->Info().IsColumnSplit(), fmat->Info().IsSecure(), hist_param_);
evaluator_ = std::make_unique<HistEvaluator>(ctx_, this->param_, fmat->Info(), col_sampler_);
p_last_tree_ = p_tree;
monitor_->Stop(__func__);
Expand Down
Loading

0 comments on commit a716334

Please sign in to comment.