
Secure vertical federated scheme for GPU computation #10652

Merged · 31 commits · Aug 12, 2024
Commits (31)
aad29df
secure horizontal for GPU
ZiyueXu77 Jul 17, 2024
b448dff
more structured refinement
ZiyueXu77 Jul 17, 2024
7d6d592
add the missing aggr piece to secure hori GPU
ZiyueXu77 Jul 18, 2024
7421cfa
code lint corrections
ZiyueXu77 Jul 19, 2024
32b704c
Merge branch 'dmlc:federated-secure' into federated-secure
ZiyueXu77 Jul 23, 2024
1efa656
Merge branch 'dmlc:federated-secure' into federated-secure
ZiyueXu77 Jul 24, 2024
361e17a
remove redundant plugin init
ZiyueXu77 Jul 25, 2024
306fc1a
update to avoid pipeline stuck for secure vertical
ZiyueXu77 Jul 25, 2024
906b0fa
update to avoid pipeline stuck for secure vertical
ZiyueXu77 Jul 25, 2024
35d23f8
Update src/tree/updater_gpu_hist.cu
ZiyueXu77 Jul 26, 2024
9ab2d8d
Update src/tree/updater_gpu_hist.cu
ZiyueXu77 Jul 26, 2024
96ab113
wrap cuda API calls in safe_cuda
ZiyueXu77 Jul 26, 2024
778d6be
add include for federatedcomm
ZiyueXu77 Jul 26, 2024
32c3014
add include conditions
ZiyueXu77 Jul 26, 2024
7df5955
add include conditions
ZiyueXu77 Jul 26, 2024
3cc863a
correct import error
ZiyueXu77 Jul 29, 2024
f2b876d
implement alternative vertical pipeline in GPU
ZiyueXu77 Jul 30, 2024
d14ab8a
Merge branch 'dmlc:federated-secure' into federated-secure
ZiyueXu77 Jul 30, 2024
2572519
transmit necessary info to plugin - align GPU with CPU
ZiyueXu77 Jul 31, 2024
e42faaa
marked calls to plugin - align GPU with CPU
ZiyueXu77 Jul 31, 2024
26aaded
secure vertical GPU fully functional
ZiyueXu77 Jul 31, 2024
aa5b51b
fix code linting and test scripts
ZiyueXu77 Aug 1, 2024
7480ed3
wrap plugin calls into federated
ZiyueXu77 Aug 1, 2024
4587b2e
only rank 0 need histogram sync result
ZiyueXu77 Aug 1, 2024
ad21314
Update histogram.cu
ZiyueXu77 Aug 1, 2024
397ade3
Added check for passive when sync histo for vertical and removed some…
nvidianz Aug 1, 2024
8c6d459
Merge pull request #11 from nvidianz/fix-passive-sync-error
ZiyueXu77 Aug 1, 2024
61d0821
Code clean
ZiyueXu77 Aug 1, 2024
563ed7b
Merge branch 'federated-secure' into federated-secure
ZiyueXu77 Aug 5, 2024
ba82521
updates for PR checks
ZiyueXu77 Aug 9, 2024
45b4e2e
updates for PR checks
ZiyueXu77 Aug 9, 2024
7 changes: 7 additions & 0 deletions src/collective/aggregator.h
@@ -227,6 +227,13 @@ void BroadcastGradient(Context const* ctx, MetaInfo const& info, GradFn&& grad_f
SafeColl(rc);
// Pass the gradient to the plugin
fed.EncryptionPlugin()->SyncEncryptedGradient(encrypted);

// !!! Temporary solution
// This step is needed for memory allocation in the case of secure vertical GPU training
// set all out_gpair values to zero to avoid information leakage
auto gpair_data = out_gpair->Data();
gpair_data->Fill(GradientPair{0.0f, 0.0f});
ApplyWithLabels(ctx, info, gpair_data, [&] { grad_fn(out_gpair); });
#else
LOG(FATAL) << error::NoFederated();
#endif
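The intent of the change above is that every worker allocates and zero-fills the gradient buffer, while only the label owner fills in real values inside ApplyWithLabels. The standalone sketch below illustrates that idea outside of XGBoost; GradientPair and GetRank here are simplified stand-ins for the real types and collective API, not the project's code.

// Standalone sketch of the zero-fill-then-populate pattern (illustrative only;
// GradientPair and GetRank are simplified stand-ins, not XGBoost's API).
#include <cstdio>
#include <vector>

struct GradientPair {
  float grad{0.0f};
  float hess{0.0f};
};

int GetRank() { return 1; }  // pretend this process is a passive party

int main() {
  // The buffer may hold stale values from a previous allocation.
  std::vector<GradientPair> out_gpair(4, GradientPair{1.5f, 2.0f});
  // Zero-fill on every worker so passive parties never see label-derived values.
  for (auto& gp : out_gpair) gp = GradientPair{0.0f, 0.0f};
  if (GetRank() == 0) {
    // Only the label owner would compute real gradients here (elided).
  }
  for (auto const& gp : out_gpair) std::printf("(%g, %g) ", gp.grad, gp.hess);
  std::printf("\n");
  return 0;
}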
23 changes: 23 additions & 0 deletions src/common/quantile.cu
@@ -673,6 +673,7 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i
}
}

auto secure_vertical = is_column_split && collective::IsEncrypted();
// Set up output cuts
for (bst_feature_t i = 0; i < num_columns_; ++i) {
size_t column_size = std::max(static_cast<size_t>(1ul), this->Column(i).size());
@@ -681,6 +682,11 @@
CheckMaxCat(max_values[i].value, column_size);
h_out_columns_ptr.push_back(max_values[i].value + 1); // includes both max_cat and 0.
} else {
// In secure vertical mode, we need to sync the max_num_bins across workers
// to create the same global number of cut-point bins for easier downstream processing
if (secure_vertical) {
collective::SafeColl(collective::Allreduce(ctx, &column_size, collective::Op::kMax));
}
h_out_columns_ptr.push_back(
std::min(static_cast<size_t>(column_size), static_cast<size_t>(num_bins_)));
}
@@ -711,6 +717,10 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i
out_column[0] = kRtEps;
assert(out_column.size() == 1);
}
// For a secure vertical split, fill all cut values with a dummy value
if (secure_vertical) {
out_column[idx] = kRtEps;
}
return;
}

@@ -736,6 +746,19 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i
out_column[idx] = in_column[idx+1].value;
});

if (secure_vertical) {
// Cut values need to be synced across all workers via Allreduce
// TODO: apply the same inference indexing as the CPU path; skipped for now
auto cut_values_device = p_cuts->cut_values_.DeviceSpan();
std::vector<float> cut_values_host(cut_values_device.size());
dh::CopyDeviceSpanToVector(&cut_values_host, cut_values_device);
auto rc = collective::Allreduce(ctx, &cut_values_host, collective::Op::kSum);
SafeColl(rc);
dh::safe_cuda(cudaMemcpyAsync(cut_values_device.data(), cut_values_host.data(),
cut_values_device.size() * sizeof(float),
cudaMemcpyHostToDevice));
}

p_cuts->SetCategorical(this->has_categorical_, max_cat);
timer_.Stop(__func__);
}
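One way to read the Allreduce(kSum) above: each worker holds real cut values only for the feature columns it owns and neutral values elsewhere, so an element-wise sum reconstructs a globally consistent cut vector. The host-side simulation below sketches that reading with plain vector arithmetic standing in for collective::Allreduce; the ownership layout is an assumption for illustration, not taken from the diff.

// Host-side simulation of the sum-based merge (illustrative only; the loop
// stands in for collective::Allreduce with Op::kSum).
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  // Three workers, four global cut slots; each worker owns disjoint slots.
  std::vector<std::vector<float>> per_worker = {
      {0.5f, 1.5f, 0.0f, 0.0f},   // worker 0 owns features 0 and 1
      {0.0f, 0.0f, 2.5f, 0.0f},   // worker 1 owns feature 2
      {0.0f, 0.0f, 0.0f, 3.5f}};  // worker 2 owns feature 3

  std::vector<float> global_cuts(4, 0.0f);
  for (auto const& local : per_worker) {          // element-wise sum across workers
    for (std::size_t i = 0; i < global_cuts.size(); ++i) {
      global_cuts[i] += local[i];
    }
  }
  for (float v : global_cuts) std::printf("%g ", v);  // prints: 0.5 1.5 2.5 3.5
  std::printf("\n");
  return 0;
}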
67 changes: 42 additions & 25 deletions src/tree/gpu_hist/evaluate_splits.cu
@@ -6,6 +6,7 @@
#include <limits>

#include "../../collective/allgather.h"
#include "../../collective/broadcast.h"
#include "../../common/categorical.h"
#include "../../data/ellpack_page.cuh"
#include "evaluate_splits.cuh"
@@ -404,34 +405,50 @@ void GPUHistEvaluator::EvaluateSplits(Context const *ctx, const std::vector<bst_

dh::TemporaryArray<DeviceSplitCandidate> splits_out_storage(d_inputs.size());
auto out_splits = dh::ToSpan(splits_out_storage);
this->LaunchEvaluateSplits(max_active_features, d_inputs, shared_inputs,
evaluator, out_splits);

if (is_column_split_) {
// With column-wise data split, we gather the split candidates from all the workers and find the
// global best candidates.
auto const world_size = collective::GetWorldSize();
dh::TemporaryArray<DeviceSplitCandidate> all_candidate_storage(out_splits.size() * world_size);
auto all_candidates = dh::ToSpan(all_candidate_storage);
auto current_rank =
all_candidates.subspan(collective::GetRank() * out_splits.size(), out_splits.size());
dh::safe_cuda(cudaMemcpyAsync(current_rank.data(), out_splits.data(),
out_splits.size() * sizeof(DeviceSplitCandidate),
cudaMemcpyDeviceToDevice));
auto rc = collective::Allgather(
ctx, linalg::MakeVec(all_candidates.data(), all_candidates.size(), ctx->Device()));
collective::SafeColl(rc);

// Reduce to get the best candidate from all workers.
dh::LaunchN(out_splits.size(), ctx->CUDACtx()->Stream(),
[world_size, all_candidates, out_splits] __device__(size_t i) {
out_splits[i] = all_candidates[i];
for (auto rank = 1; rank < world_size; rank++) {
out_splits[i] = out_splits[i] + all_candidates[rank * out_splits.size() + i];
}
});
bool is_passive_party = is_column_split_ && collective::IsEncrypted()
&& collective::GetRank() != 0;
bool is_active_party = !is_passive_party;
// Under the secure vertical setting, only the active party is able to evaluate the split
// based on the global histogram; the other parties receive the final best split information.
// Hence the computation below is not performed by the passive parties.
if (is_active_party) {
this->LaunchEvaluateSplits(max_active_features, d_inputs, shared_inputs,
evaluator, out_splits);
}

if (is_column_split_) {
if (!collective::IsEncrypted()) {
// With regular column-wise data split, we gather the split candidates from
// all the workers and find the global best candidates.
auto const world_size = collective::GetWorldSize();
dh::TemporaryArray<DeviceSplitCandidate> all_candidate_storage(
out_splits.size() * world_size);
auto all_candidates = dh::ToSpan(all_candidate_storage);
auto current_rank =
all_candidates.subspan(collective::GetRank() * out_splits.size(), out_splits.size());
dh::safe_cuda(cudaMemcpyAsync(current_rank.data(), out_splits.data(),
out_splits.size() * sizeof(DeviceSplitCandidate),
cudaMemcpyDeviceToDevice));
auto rc = collective::Allgather(
ctx, linalg::MakeVec(all_candidates.data(), all_candidates.size(), ctx->Device()));
collective::SafeColl(rc);
// Reduce to get the best candidate from all workers.
dh::LaunchN(out_splits.size(), ctx->CUDACtx()->Stream(),
[world_size, all_candidates, out_splits] __device__(size_t i) {
out_splits[i] = all_candidates[i];
for (auto rank = 1; rank < world_size; rank++) {
out_splits[i] = out_splits[i] + all_candidates[rank * out_splits.size() + i];
}
});
} else {
// With an encrypted column-wise data split, the best split candidates are broadcast
// from rank 0 to all other workers
auto rc = collective::Broadcast(
ctx, linalg::MakeVec(out_splits.data(), out_splits.size(), ctx->Device()), 0);
collective::SafeColl(rc);
}
}
auto d_sorted_idx = this->SortedIdx(d_inputs.size(), shared_inputs.feature_values.size());
auto d_entries = out_entries;
auto device_cats_accessor = this->DeviceCatStorage(nidx);
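To summarize the control flow added above: in the encrypted column split only rank 0 runs LaunchEvaluateSplits, and the resulting candidates are then broadcast so every worker applies the same split. The sketch below mimics that flow in a single process; SplitCandidate and Broadcast are hypothetical stand-ins for DeviceSplitCandidate and collective::Broadcast.

// Single-process sketch of "active party evaluates, passive parties receive"
// (illustrative only; SplitCandidate and Broadcast are stand-ins).
#include <cstdio>
#include <vector>

struct SplitCandidate {
  int feature;
  float gain;
};

// Stand-in for a broadcast with root = 0: non-root ranks copy the root's data.
void Broadcast(std::vector<SplitCandidate>* buf, int rank,
               std::vector<SplitCandidate> const& root_copy) {
  if (rank != 0) *buf = root_copy;
}

int main() {
  int const world_size = 3;
  std::vector<SplitCandidate> rank0_splits;
  for (int rank = 0; rank < world_size; ++rank) {
    std::vector<SplitCandidate> out_splits(2);       // allocated on every rank
    if (rank == 0) {                                 // only the active party evaluates
      out_splits = {{3, 0.42f}, {7, 0.17f}};
      rank0_splits = out_splits;
    }
    Broadcast(&out_splits, rank, rank0_splits);      // all ranks end up identical
    std::printf("rank %d best feature: %d\n", rank, out_splits[0].feature);
  }
  return 0;
}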
129 changes: 126 additions & 3 deletions src/tree/gpu_hist/histogram.cu
@@ -14,6 +14,15 @@
#include "row_partitioner.cuh"
#include "xgboost/base.h"

#include "../../collective/allgather.h" // for AllgatherV

#include "../../common/device_helpers.cuh"
#if defined(XGBOOST_USE_FEDERATED)
#include "../../../plugin/federated/federated_hist.h" // for FederataedHistPolicy
#else
#include "../../common/error_msg.h" // for NoFederated
#endif

namespace xgboost::tree {
namespace {
struct Pair {
@@ -354,13 +363,127 @@ void DeviceHistogramBuilder::Reset(Context const* ctx, FeatureGroupsAccessor con
this->p_impl_->Reset(ctx, feature_groups, force_global_memory);
}

void DeviceHistogramBuilder::BuildHistogram(CUDAContext const* ctx,
struct ReadMatrixFunction {
EllpackDeviceAccessor matrix;
int k_cols;
bst_float* matrix_data_d;
ReadMatrixFunction(EllpackDeviceAccessor matrix, int k_cols, bst_float* matrix_data_d)
: matrix(std::move(matrix)), k_cols(k_cols), matrix_data_d(matrix_data_d) {}

__device__ void operator()(size_t global_idx) {
auto row = global_idx / k_cols;
auto col = global_idx % k_cols;
auto value = matrix.GetBinIndex(row, col);
if (isnan(static_cast<float>(value))) {
value = -1;
}
matrix_data_d[global_idx] = value;
}
};

void DeviceHistogramBuilder::BuildHistogram(Context const* ctx,
EllpackDeviceAccessor const& matrix,
FeatureGroupsAccessor const& feature_groups,
common::Span<GradientPair const> gpair,
common::Span<const std::uint32_t> ridx,
common::Span<GradientPairInt64> histogram,
GradientQuantiser rounding) {
this->p_impl_->BuildHistogram(ctx, matrix, feature_groups, gpair, ridx, histogram, rounding);
GradientQuantiser rounding, MetaInfo const& info) {
auto IsSecureVertical = !info.IsRowSplit() && collective::IsDistributed()
&& collective::IsEncrypted();
if (!IsSecureVertical) {
// Regular training, build histogram locally
this->p_impl_->BuildHistogram(ctx->CUDACtx(), matrix, feature_groups,
gpair, ridx, histogram, rounding);
return;
}
#if defined(XGBOOST_USE_FEDERATED)
// Encrypted vertical, build histogram using federated plugin
auto const &comm = collective::GlobalCommGroup()->Ctx(ctx, DeviceOrd::CPU());
auto const &fed = dynamic_cast<collective::FederatedComm const &>(comm);
auto plugin = fed.EncryptionPlugin();

// Transmit matrix to plugin
if (!is_aggr_context_initialized) {
// Get cutptrs
std::vector<uint32_t> h_cuts_ptr(matrix.feature_segments.size());
dh::CopyDeviceSpanToVector(&h_cuts_ptr, matrix.feature_segments);
common::Span<std::uint32_t const> cutptrs =
common::Span<std::uint32_t const>(h_cuts_ptr.data(), h_cuts_ptr.size());

// Get bin_idx matrix
auto kRows = matrix.n_rows;
auto kCols = matrix.NumFeatures();
std::vector<int32_t> h_bin_idx(kRows * kCols);
// Access GPU matrix data
thrust::device_vector<bst_float> matrix_d(kRows * kCols);
dh::LaunchN(kRows * kCols, ReadMatrixFunction(matrix, kCols, matrix_d.data().get()));
thrust::copy(matrix_d.begin(), matrix_d.end(), h_bin_idx.begin());
common::Span<std::int32_t const> bin_idx =
common::Span<std::int32_t const>(h_bin_idx.data(), h_bin_idx.size());

// Initialize plugin context
plugin->Reset(h_cuts_ptr, h_bin_idx);
is_aggr_context_initialized = true;
}

// get row indices from device
std::vector<uint32_t> h_ridx(ridx.size());
dh::CopyDeviceSpanToVector(&h_ridx, ridx);
// necessary conversions to fit plugin expectations
std::vector<uint64_t> h_ridx_64(ridx.size());
for (int i = 0; i < ridx.size(); i++) {
h_ridx_64[i] = h_ridx[i];
}
std::vector<std::uint64_t const *> ptrs(1);
std::vector<std::size_t> sizes(1);
std::vector<bst_node_t> nodes(1);
ptrs[0] = reinterpret_cast<std::uint64_t const *>(h_ridx_64.data());
sizes[0] = h_ridx_64.size();
nodes[0] = 0;

// Transmit row indices to plugin and get encrypted histogram
auto hist_data = plugin->BuildEncryptedHistVert(ptrs, sizes, nodes);

// Perform AllGather
HostDeviceVector<std::int8_t> hist_entries;
std::vector<std::int64_t> recv_segments;
collective::SafeColl(
collective::AllgatherV(ctx, linalg::MakeVec(hist_data),
&recv_segments, &hist_entries));

if (collective::GetRank() != 0) {
// The steps below are only needed by the label owner
return;
}

// Call the plugin to get the resulting histogram. Histograms from all workers are
// gathered to the label owner.
common::Span<double> hist_aggr =
plugin->SyncEncryptedHistVert(
common::RestoreType<std::uint8_t>(hist_entries.HostSpan()));

// Post process the AllGathered data
auto world_size = collective::GetWorldSize();
std::vector<GradientPairInt64> host_histogram(histogram.size());
for (auto i = 0; i < histogram.size(); i++) {
double grad = 0.0;
double hess = 0.0;
for (auto rank = 0; rank < world_size; ++rank) {
auto idx = rank * histogram.size() + i;
grad += hist_aggr[idx * 2];
hess += hist_aggr[idx * 2 + 1];
}
GradientPairPrecise hist_item(grad, hess);
host_histogram[i] = rounding.ToFixedPoint(hist_item);
}

// Copy the aggregated histogram back to GPU memory.
// At this point, the histogram contains the full information from all parties.
dh::safe_cuda(cudaMemcpyAsync(histogram.data(), host_histogram.data(),
histogram.size() * sizeof(GradientPairInt64),
cudaMemcpyHostToDevice));
#else
LOG(FATAL) << error::NoFederated();
#endif
}
} // namespace xgboost::tree
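The post-processing loop above assumes a flat layout of the decrypted buffer: for each rank, one histogram of interleaved (grad, hess) doubles, concatenated rank after rank. The standalone sketch below reproduces that indexing and the fixed-point conversion step; ToFixedPoint is a hypothetical simplification of GradientQuantiser::ToFixedPoint, which in reality uses data-dependent scaling.

// Standalone sketch of the histogram post-processing (illustrative only;
// ToFixedPoint is a simplified stand-in for the real quantiser).
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

std::int64_t ToFixedPoint(double v) { return static_cast<std::int64_t>(v * (1 << 16)); }

int main() {
  int const world_size = 2;
  std::size_t const n_bins = 3;
  // Layout: rank-major, then bin, then (grad, hess) interleaved as doubles.
  std::vector<double> hist_aggr = {
      1.0, 2.0,  0.5, 0.5,  0.0, 1.0,   // rank 0's histogram
      0.5, 1.0,  0.5, 0.5,  1.0, 0.0};  // rank 1's histogram

  for (std::size_t i = 0; i < n_bins; ++i) {
    double grad = 0.0, hess = 0.0;
    for (int rank = 0; rank < world_size; ++rank) {
      std::size_t idx = rank * n_bins + i;   // same indexing as the loop in the diff
      grad += hist_aggr[idx * 2];
      hess += hist_aggr[idx * 2 + 1];
    }
    std::printf("bin %zu: grad=%lld hess=%lld\n", i,
                static_cast<long long>(ToFixedPoint(grad)),
                static_cast<long long>(ToFixedPoint(hess)));
  }
  return 0;
}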
8 changes: 5 additions & 3 deletions src/tree/gpu_hist/histogram.cuh
@@ -175,14 +175,16 @@ class DeviceHistogramBuilder {
public:
DeviceHistogramBuilder();
~DeviceHistogramBuilder();

// Whether the secure aggregation context has been initialized
bool is_aggr_context_initialized{false};
void Reset(Context const* ctx, FeatureGroupsAccessor const& feature_groups,
bool force_global_memory);
void BuildHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
void BuildHistogram(Context const* ctx, EllpackDeviceAccessor const& matrix,
FeatureGroupsAccessor const& feature_groups,
common::Span<GradientPair const> gpair,
common::Span<const std::uint32_t> ridx,
common::Span<GradientPairInt64> histogram, GradientQuantiser rounding);
common::Span<GradientPairInt64> histogram, GradientQuantiser rounding,
MetaInfo const& info);
};
} // namespace xgboost::tree
#endif // HISTOGRAM_CUH_
4 changes: 2 additions & 2 deletions src/tree/updater_gpu_hist.cu
@@ -249,9 +249,9 @@ struct GPUHistMakerDevice {
void BuildHist(int nidx) {
auto d_node_hist = hist.GetNodeHistogram(nidx);
auto d_ridx = row_partitioner->GetRows(nidx);
this->histogram_.BuildHistogram(ctx_->CUDACtx(), page->GetDeviceAccessor(ctx_->Device()),
this->histogram_.BuildHistogram(ctx_, page->GetDeviceAccessor(ctx_->Device()),
feature_groups->DeviceAccessor(ctx_->Device()), gpair, d_ridx,
d_node_hist, *quantiser);
d_node_hist, *quantiser, info_);
}

// Attempt to do subtraction trick