Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement secure boost scheme phase 1 - vertical pipeline with hist sync #10037

Merged
merged 27 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
8570ba5
Add additional data split mode to cover the secure vertical pipeline
ZiyueXu77 Jan 31, 2024
2d00db6
Add IsSecure info and update corresponding functions
ZiyueXu77 Jan 31, 2024
ab17f5a
Modify evaluate_splits to block non-label owners to perform hist comp…
ZiyueXu77 Jan 31, 2024
fb1787c
Continue using Allgather for best split sync for secure vertical, equ…
ZiyueXu77 Feb 2, 2024
7a2a2b8
Modify histogram sync scheme for secure vertical case, can identify g…
ZiyueXu77 Feb 6, 2024
3ca3142
Sync cut information across clients, full pipeline works for testing …
ZiyueXu77 Feb 7, 2024
22dd522
Code cleanup, phase 1 of alternative vertical pipeline finished
ZiyueXu77 Feb 8, 2024
52e8951
Code clean
ZiyueXu77 Feb 8, 2024
e9eef15
change kColS to kColSecure to avoid confusion with kCols
ZiyueXu77 Feb 12, 2024
70e6ca6
Add additional data split mode to cover the secure vertical pipeline
ZiyueXu77 Jan 31, 2024
a54ea6a
Add IsSecure info and update corresponding functions
ZiyueXu77 Jan 31, 2024
6fe61dd
Modify evaluate_splits to block non-label owners to perform hist comp…
ZiyueXu77 Jan 31, 2024
1c2b7ed
Continue using Allgather for best split sync for secure vertical, equ…
ZiyueXu77 Feb 2, 2024
b36ff2b
Modify histogram sync scheme for secure vertical case, can identify g…
ZiyueXu77 Feb 6, 2024
0707731
Sync cut information across clients, full pipeline works for testing …
ZiyueXu77 Feb 7, 2024
dce7609
Code cleanup, phase 1 of alternative vertical pipeline finished
ZiyueXu77 Feb 8, 2024
6cebc31
Code clean
ZiyueXu77 Feb 8, 2024
1562f52
change kColS to kColSecure to avoid confusion with kCols
ZiyueXu77 Feb 12, 2024
f31c824
Add one unit test
YuanTingHsieh Feb 17, 2024
6fcbe02
Merge branch 'SecureBoost' into add_alternate_vertical_splits
ZiyueXu77 Feb 20, 2024
967e307
Merge pull request #1 from YuanTingHsieh/add_alternate_vertical_splits
ZiyueXu77 Feb 20, 2024
04cd1cb
Merge branch 'dmlc:master' into SecureBoost
ZiyueXu77 Feb 20, 2024
087a8dd
Merge branch 'dmlc:master' into SecureBoost
ZiyueXu77 Feb 23, 2024
616f68e
Merge branch 'vertical-federated-learning' into SecureBoost
ZiyueXu77 Feb 28, 2024
7e407a8
remove redundant print
ZiyueXu77 Feb 28, 2024
add5dcd
updates according to comments
ZiyueXu77 Feb 28, 2024
7d4b99d
fix linting issues
ZiyueXu77 Mar 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions include/xgboost/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ enum class DataType : uint8_t {

enum class FeatureType : uint8_t { kNumerical = 0, kCategorical = 1 };

/**
 * @brief How the training data is partitioned across workers.
 *
 * kRow       - horizontal split: each worker holds a subset of rows.
 * kCol       - vertical split: each worker holds a subset of columns.
 * kColSecure - vertical split with secure computation (secure federated learning).
 */
enum class DataSplitMode : int { kRow = 0, kCol = 1, kColSecure = 2 };

/*!
* \brief Meta information about dataset, always sit in memory.
Expand Down Expand Up @@ -186,7 +186,11 @@ class MetaInfo {
}

/** @brief Whether the data is split column-wise (covers both the plain and the
 *         secure vertical modes). */
bool IsColumnSplit() const {
  return (data_split_mode == DataSplitMode::kCol) ||
         (data_split_mode == DataSplitMode::kColSecure);
}

/** @brief Whether the data is split column-wise with secure computation. */
bool IsSecure() const { return data_split_mode == DataSplitMode::kColSecure; }

/** @brief Whether this is a learning-to-rank dataset (non-empty query groups). */
bool IsRanking() const { return !group_ptr_.empty(); }

Expand Down
52 changes: 51 additions & 1 deletion src/common/quantile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,39 @@ void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_b
}
}

/**
 * @brief Variant of AddCutPoint for the secure vertical (federated) pipeline.
 *
 * Every worker must emit the same number of cut values per feature so that the
 * cut vectors stay aligned for the later element-wise Allreduce sync.  Workers
 * that do not own the feature (empty summary) pad with zeros; zero is the
 * identity of the sum-based sync, so after the Allreduce only the feature
 * owner's values survive.
 *
 * @param summary Per-feature quantile summary (empty on non-owner workers).
 * @param max_bin Maximum number of bins allowed for this feature.
 * @param cuts    Output cut container; values are appended to its host vector.
 */
template <typename SketchType>
void AddCutPointSecure(typename SketchType::SummaryContainer const &summary, int max_bin,
                       HistogramCuts *cuts) {
  size_t required_cuts = std::min(summary.size, static_cast<size_t>(max_bin));
  // Keep the local count so we can tell below whether this worker owns the feature.
  size_t required_cuts_original = required_cuts;
  // Sync the required number of cuts across all workers; the max equals the
  // owner's count since non-owners contribute 0.
  collective::Allreduce<collective::Operation::kMax>(&required_cuts, 1);

  auto &cut_values = cuts->cut_values_.HostVector();
  if (required_cuts_original > 0) {
    // Feature owner: fill in the actual cut values.
    // The min_value is used as the first (0th) element, hence starting from 1.
    // NOTE(review): duplicate cut points are skipped here, so the owner may push
    // fewer than (required_cuts - 1) values while non-owners always push exactly
    // that many, which would misalign the later Allreduce sync — confirm the
    // summaries are deduplicated upstream.
    for (size_t i = 1; i < required_cuts; ++i) {
      bst_float cpt = summary.data[i].value;
      if (i == 1 || cpt > cut_values.back()) {
        cut_values.push_back(cpt);
      }
    }
  } else {
    // Non-owner: pad with zeros so the vector sizes match across workers.
    for (size_t i = 1; i < required_cuts; ++i) {
      cut_values.push_back(0.0);
    }
  }
}


auto AddCategories(std::set<float> const &categories, HistogramCuts *cuts) {
if (std::any_of(categories.cbegin(), categories.cend(), InvalidCat)) {
InvalidCategory();
Expand Down Expand Up @@ -415,11 +448,21 @@ void SketchContainerImpl<WQSketch>::MakeCuts(Context const *ctx, MetaInfo const
float max_cat{-1.f};
for (size_t fid = 0; fid < reduced.size(); ++fid) {
size_t max_num_bins = std::min(num_cuts[fid], max_bins_);
// If vertical and secure mode, we need to sync the max_num_bins across workers
if (info.IsVerticalFederated() && info.IsSecure()) {
collective::Allreduce<collective::Operation::kMax>(&max_num_bins, 1);
}
typename WQSketch::SummaryContainer const &a = final_summaries[fid];
if (IsCat(feature_types_, fid)) {
max_cat = std::max(max_cat, AddCategories(categories_.at(fid), p_cuts));
} else {
AddCutPoint<WQSketch>(a, max_num_bins, p_cuts);
// use special AddCutPoint scheme for secure vertical federated learning
if (info.IsVerticalFederated() && info.IsSecure()) {
AddCutPointSecure<WQSketch>(a, max_num_bins, p_cuts);
}
else {
AddCutPoint<WQSketch>(a, max_num_bins, p_cuts);
}
// push a value that is greater than anything
const bst_float cpt =
(a.size > 0) ? a.data[a.size - 1].value : p_cuts->min_vals_.HostVector()[fid];
Expand All @@ -435,6 +478,13 @@ void SketchContainerImpl<WQSketch>::MakeCuts(Context const *ctx, MetaInfo const
p_cuts->cut_ptrs_.HostVector().push_back(cut_size);
}

if (info.IsVerticalFederated() && info.IsSecure()) {
// cut values need to be synced across all workers via Allreduce
ZiyueXu77 marked this conversation as resolved.
Show resolved Hide resolved
auto cut_val = p_cuts->cut_values_.HostVector().data();
std::size_t n = p_cuts->cut_values_.HostVector().size();
collective::Allreduce<collective::Operation::kSum>(cut_val, n);
}

p_cuts->SetCategorical(this->has_categorical_, max_cat);
monitor_.Stop(__func__);
}
Expand Down
108 changes: 65 additions & 43 deletions src/tree/hist/evaluate_splits.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class HistEvaluator {
std::shared_ptr<common::ColumnSampler> column_sampler_;
TreeEvaluator tree_evaluator_;
bool is_col_split_{false};
bool is_secure_{false};
FeatureInteractionConstraintHost interaction_constraints_;
std::vector<NodeEntry> snode_;

Expand Down Expand Up @@ -280,7 +281,6 @@ class HistEvaluator {
}
}
}

p_best->Update(best);
return left_sum;
}
Expand Down Expand Up @@ -346,57 +346,76 @@ class HistEvaluator {
auto evaluator = tree_evaluator_.GetEvaluator();
auto const& cut_ptrs = cut.Ptrs();

common::ParallelFor2d(space, n_threads, [&](size_t nidx_in_set, common::Range1d r) {
auto tidx = omp_get_thread_num();
auto entry = &tloc_candidates[n_threads * nidx_in_set + tidx];
auto best = &entry->split;
auto nidx = entry->nid;
auto histogram = hist[nidx];
auto features_set = features[nidx_in_set]->ConstHostSpan();
for (auto fidx_in_set = r.begin(); fidx_in_set < r.end(); fidx_in_set++) {
auto fidx = features_set[fidx_in_set];
bool is_cat = common::IsCat(feature_types, fidx);
if (!interaction_constraints_.Query(nidx, fidx)) {
continue;
}
if (is_cat) {
auto n_bins = cut_ptrs.at(fidx + 1) - cut_ptrs[fidx];
if (common::UseOneHot(n_bins, param_->max_cat_to_onehot)) {
EnumerateOneHot(cut, histogram, fidx, nidx, evaluator, best);
} else {
std::vector<size_t> sorted_idx(n_bins);
std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
auto feat_hist = histogram.subspan(cut_ptrs[fidx], n_bins);
// Sort the histogram to get contiguous partitions.
std::stable_sort(sorted_idx.begin(), sorted_idx.end(), [&](size_t l, size_t r) {
auto ret = evaluator.CalcWeightCat(*param_, feat_hist[l]) <
evaluator.CalcWeightCat(*param_, feat_hist[r]);
return ret;
});
EnumeratePart<+1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
EnumeratePart<-1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
// Under secure vertical setting, only the label owner is able to evaluate the split
// based on the global histogram. The other parties will only receive the final best split information
// Hence the below computation is not performed by the non-label owners under secure vertical setting
if ((!is_secure_) || (collective::GetRank() == 0)) {
ZiyueXu77 marked this conversation as resolved.
Show resolved Hide resolved
// Evaluate the splits for each feature
common::ParallelFor2d(space, n_threads, [&](size_t nidx_in_set, common::Range1d r) {
auto tidx = omp_get_thread_num();
auto entry = &tloc_candidates[n_threads * nidx_in_set + tidx];
auto best = &entry->split;
auto nidx = entry->nid;
auto histogram = hist[nidx];
auto features_set = features[nidx_in_set]->ConstHostSpan();
for (auto fidx_in_set = r.begin(); fidx_in_set < r.end(); fidx_in_set++) {
auto fidx = features_set[fidx_in_set];
bool is_cat = common::IsCat(feature_types, fidx);
if (!interaction_constraints_.Query(nidx, fidx)) {
continue;
}
} else {
auto grad_stats = EnumerateSplit<+1>(cut, histogram, fidx, nidx, evaluator, best);
if (SplitContainsMissingValues(grad_stats, snode_[nidx])) {
EnumerateSplit<-1>(cut, histogram, fidx, nidx, evaluator, best);
if (is_cat) {
auto n_bins = cut_ptrs.at(fidx + 1) - cut_ptrs[fidx];
if (common::UseOneHot(n_bins, param_->max_cat_to_onehot)) {
EnumerateOneHot(cut, histogram, fidx, nidx, evaluator, best);
}
else {
std::vector<size_t> sorted_idx(n_bins);
std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
auto feat_hist = histogram.subspan(cut_ptrs[fidx], n_bins);
// Sort the histogram to get contiguous partitions.
std::stable_sort(sorted_idx.begin(), sorted_idx.end(), [&](size_t l, size_t r) {
auto ret = evaluator.CalcWeightCat(*param_, feat_hist[l]) <
evaluator.CalcWeightCat(*param_, feat_hist[r]);
return ret;
});
EnumeratePart<+1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
EnumeratePart<-1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
}
}
else {
auto grad_stats = EnumerateSplit<+1>(cut, histogram, fidx, nidx, evaluator, best);

// print the best split for each feature
// std::cout << "Best split for feature " << fidx << " is " << best->split_value << " with gain " << best->loss_chg << std::endl;

ZiyueXu77 marked this conversation as resolved.
Show resolved Hide resolved

if (SplitContainsMissingValues(grad_stats, snode_[nidx])) {
EnumerateSplit<-1>(cut, histogram, fidx, nidx, evaluator, best);
}
}
}
}
});
});

for (unsigned nidx_in_set = 0; nidx_in_set < entries.size();
++nidx_in_set) {
for (auto tidx = 0; tidx < n_threads; ++tidx) {
entries[nidx_in_set].split.Update(
tloc_candidates[n_threads * nidx_in_set + tidx].split);
for (unsigned nidx_in_set = 0; nidx_in_set < entries.size();
++nidx_in_set) {
for (auto tidx = 0; tidx < n_threads; ++tidx) {
entries[nidx_in_set].split.Update(
tloc_candidates[n_threads * nidx_in_set + tidx].split);
}
}
}

if (is_col_split_) {
// With column-wise data split, we gather the best splits from all the workers and update the
// expand entries accordingly.
// Note that under secure vertical setting, only the label owner is able to evaluate the split
// based on the global histogram. The other parties will receive the final best splits
// allgather is capable of performing this (0-gain entries for non-label owners),
// but can be replaced with a broadcast in the future

auto all_entries = Allgather(entries);

for (auto worker = 0; worker < collective::GetWorldSize(); ++worker) {
for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
entries[nidx_in_set].split.Update(
Expand Down Expand Up @@ -477,7 +496,8 @@ class HistEvaluator {
param_{param},
column_sampler_{std::move(sampler)},
tree_evaluator_{*param, static_cast<bst_feature_t>(info.num_col_), DeviceOrd::CPU()},
is_col_split_{info.IsColumnSplit()} {
is_col_split_{info.IsColumnSplit()},
is_secure_{info.IsSecure()}{
interaction_constraints_.Configure(*param, info.num_col_);
column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
param_->colsample_bynode, param_->colsample_bylevel,
Expand All @@ -493,6 +513,7 @@ class HistMultiEvaluator {
std::shared_ptr<common::ColumnSampler> column_sampler_;
Context const *ctx_;
bool is_col_split_{false};
bool is_secure_{false};

private:
static double MultiCalcSplitGain(TrainParam const &param,
Expand Down Expand Up @@ -753,7 +774,8 @@ class HistMultiEvaluator {
: param_{param},
column_sampler_{std::move(sampler)},
ctx_{ctx},
is_col_split_{info.IsColumnSplit()} {
is_col_split_{info.IsColumnSplit()},
is_secure_{info.IsSecure()} {
interaction_constraints_.Configure(*param, info.num_col_);
column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
param_->colsample_bynode, param_->colsample_bylevel,
Expand Down
20 changes: 17 additions & 3 deletions src/tree/hist/histogram.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class HistogramBuilder {
// Whether XGBoost is running in distributed environment.
bool is_distributed_{false};
bool is_col_split_{false};
bool is_secure_{false};

public:
/**
Expand All @@ -60,13 +61,14 @@ class HistogramBuilder {
* of using global rabit variable.
*/
/**
 * @brief Reset the histogram builder before building a new set of histograms.
 *
 * @param ctx            Runtime context (supplies the thread count).
 * @param total_bins     Total number of histogram bins across all features.
 * @param p              Batch parameter used when fetching gradient index batches.
 * @param is_distributed Whether XGBoost is running in a distributed environment.
 * @param is_col_split   Whether the data is split column-wise across workers.
 * @param is_secure      Whether the secure vertical (federated) pipeline is active.
 * @param param          Histogram-maker training parameters (e.g. hist-node cache size).
 */
void Reset(Context const *ctx, bst_bin_t total_bins, BatchParam const &p, bool is_distributed,
           bool is_col_split, bool is_secure, HistMakerTrainParam const *param) {
  n_threads_ = ctx->Threads();
  param_ = p;
  hist_.Reset(total_bins, param->max_cached_hist_node);
  buffer_.Init(total_bins);
  is_distributed_ = is_distributed;
  is_col_split_ = is_col_split;
  is_secure_ = is_secure;
}

template <bool any_missing>
Expand Down Expand Up @@ -175,6 +177,7 @@ class HistogramBuilder {
std::vector<bst_node_t> const &nodes_to_build,
std::vector<bst_node_t> const &nodes_to_trick) {
auto n_total_bins = buffer_.TotalBins();

common::BlockedSpace2d space(
nodes_to_build.size(), [&](std::size_t) { return n_total_bins; }, 1024);
common::ParallelFor2d(space, this->n_threads_, [&](size_t node, common::Range1d r) {
Expand All @@ -190,6 +193,17 @@ class HistogramBuilder {
reinterpret_cast<double *>(this->hist_[first_nidx].data()), n);
}

if (is_distributed_ && is_col_split_ && is_secure_) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to allgather the histogram across workers? I thought we only need to send it to the active worker?

Copy link
Author

@ZiyueXu77 ZiyueXu77 Feb 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes we only need to collect histograms to the active party, but my understanding is we currently do not have a "gather" function to do that? it will be great if we have it, similar to broadcast(..., rank), just reverse

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for sharing, I can look into a gather function in the future.

// Under secure vertical mode, we perform allgather for all nodes
CHECK(!nodes_to_build.empty());
// in theory the operation is AllGather, but with current system functionality,
ZiyueXu77 marked this conversation as resolved.
Show resolved Hide resolved
// we use AllReduce to simulate the AllGather operation
auto first_nidx = nodes_to_build.front();
std::size_t n = n_total_bins * nodes_to_build.size() * 2;
collective::Allreduce<collective::Operation::kSum>(
reinterpret_cast<double *>(this->hist_[first_nidx].data()), n);
}

common::BlockedSpace2d const &subspace =
nodes_to_trick.size() == nodes_to_build.size()
? space
Expand Down Expand Up @@ -329,12 +343,12 @@ class MultiHistogramBuilder {
[[nodiscard]] auto &Histogram(bst_target_t t) { return target_builders_[t].Histogram(); }

/**
 * @brief Reset every per-target histogram builder before a new iteration.
 *
 * @param ctx            Runtime context.
 * @param total_bins     Total number of histogram bins across all features.
 * @param n_targets      Number of output targets; one builder is kept per target.
 * @param p              Batch parameter used when fetching gradient index batches.
 * @param is_distributed Whether XGBoost is running in a distributed environment.
 * @param is_col_split   Whether the data is split column-wise across workers.
 * @param is_secure      Whether the secure vertical (federated) pipeline is active.
 * @param param          Histogram-maker training parameters.
 */
void Reset(Context const *ctx, bst_bin_t total_bins, bst_target_t n_targets, BatchParam const &p,
           bool is_distributed, bool is_col_split, bool is_secure,
           HistMakerTrainParam const *param) {
  ctx_ = ctx;
  target_builders_.resize(n_targets);
  CHECK_GE(n_targets, 1);
  for (auto &v : target_builders_) {
    v.Reset(ctx, total_bins, p, is_distributed, is_col_split, is_secure, param);
  }
}
};
Expand Down
2 changes: 1 addition & 1 deletion src/tree/updater_approx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class GloablApproxBuilder {
}

histogram_builder_.Reset(ctx_, n_total_bins, p_tree->NumTargets(), BatchSpec(*param_, hess),
collective::IsDistributed(), p_fmat->Info().IsColumnSplit(),
collective::IsDistributed(), p_fmat->Info().IsColumnSplit(), p_fmat->Info().IsSecure(),
hist_param_);
monitor_->Stop(__func__);
}
Expand Down
4 changes: 2 additions & 2 deletions src/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ class MultiTargetHistBuilder {
bst_target_t n_targets = p_tree->NumTargets();
histogram_builder_ = std::make_unique<MultiHistogramBuilder>();
histogram_builder_->Reset(ctx_, n_total_bins, n_targets, HistBatch(param_),
collective::IsDistributed(), p_fmat->Info().IsColumnSplit(),
collective::IsDistributed(), p_fmat->Info().IsColumnSplit(), p_fmat->Info().IsSecure(),
hist_param_);

evaluator_ = std::make_unique<HistMultiEvaluator>(ctx_, p_fmat->Info(), param_, col_sampler_);
Expand Down Expand Up @@ -358,7 +358,7 @@ class HistUpdater {
fmat->Info().IsColumnSplit());
}
histogram_builder_->Reset(ctx_, n_total_bins, 1, HistBatch(param_), collective::IsDistributed(),
fmat->Info().IsColumnSplit(), hist_param_);
fmat->Info().IsColumnSplit(), fmat->Info().IsSecure(), hist_param_);
evaluator_ = std::make_unique<HistEvaluator>(ctx_, this->param_, fmat->Info(), col_sampler_);
p_last_tree_ = p_tree;
monitor_->Stop(__func__);
Expand Down
Loading