Commit

Added check for passive when sync histo for vertical and removed some nested blocks
nvidianz committed Aug 1, 2024
1 parent ad21314 commit 397ade3
Showing 1 changed file with 92 additions and 89 deletions: src/tree/gpu_hist/histogram.cu

@@ -394,99 +394,102 @@ void DeviceHistogramBuilder::BuildHistogram(Context const* ctx,
    // Regular training, build histogram locally
    this->p_impl_->BuildHistogram(ctx->CUDACtx(), matrix, feature_groups,
                                  gpair, ridx, histogram, rounding);
    return;
  }
#if defined(XGBOOST_USE_FEDERATED)
  // Encrypted vertical, build histogram using federated plugin
  auto const &comm = collective::GlobalCommGroup()->Ctx(ctx, DeviceOrd::CPU());
  auto const &fed = dynamic_cast<collective::FederatedComm const &>(comm);
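  // Note: a reference dynamic_cast throws std::bad_cast if the communicator is
  // not actually a FederatedComm, so the cast doubles as a runtime check.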
  auto plugin = fed.EncryptionPlugin();

  // Transmit matrix to plugin
  if (!is_aggr_context_initialized_) {
    // Get cutptrs
    std::vector<uint32_t> h_cuts_ptr(matrix.feature_segments.size());
    dh::CopyDeviceSpanToVector(&h_cuts_ptr, matrix.feature_segments);
    common::Span<std::uint32_t const> cutptrs =
        common::Span<std::uint32_t const>(h_cuts_ptr.data(), h_cuts_ptr.size());

    // Get bin_idx matrix
    auto kRows = matrix.n_rows;
    auto kCols = matrix.NumFeatures();
    std::vector<int32_t> h_bin_idx(kRows * kCols);
    // Access GPU matrix data
    thrust::device_vector<bst_float> matrix_d(kRows * kCols);
    dh::LaunchN(kRows * kCols, ReadMatrixFunction(matrix, kCols, matrix_d.data().get()));
    thrust::copy(matrix_d.begin(), matrix_d.end(), h_bin_idx.begin());
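    // The bin indices are staged through a float buffer, so the copy above
    // implicitly truncates bst_float values to int32; this assumes every bin
    // index is exactly representable as a float (i.e. below 2^24).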
    common::Span<std::int32_t const> bin_idx =
        common::Span<std::int32_t const>(h_bin_idx.data(), h_bin_idx.size());

    // Initialize plugin context
    plugin->Reset(h_cuts_ptr, h_bin_idx);
    is_aggr_context_initialized_ = true;
  }

  // Get row indices from device
  std::vector<uint32_t> h_ridx(ridx.size());
  dh::CopyDeviceSpanToVector(&h_ridx, ridx);
  // Widen the 32-bit indices to 64-bit, as the plugin expects
  std::vector<uint64_t> h_ridx_64(ridx.size());
  for (std::size_t i = 0; i < ridx.size(); i++) {
    h_ridx_64[i] = h_ridx[i];
  }
  std::vector<std::uint64_t const *> ptrs(1);
  std::vector<std::size_t> sizes(1);
  std::vector<bst_node_t> nodes(1);
  ptrs[0] = reinterpret_cast<std::uint64_t const *>(h_ridx_64.data());
  sizes[0] = h_ridx_64.size();
  nodes[0] = 0;
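  // A single entry: the whole row set is attributed to one node (node id 0).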

  // Transmit row indices to plugin and get encrypted histogram
  auto hist_data = plugin->BuildEncryptedHistVert(ptrs, sizes, nodes);

  // Perform AllGather
  HostDeviceVector<std::int8_t> hist_entries;
  std::vector<std::int64_t> recv_segments;
  collective::SafeColl(
      collective::AllgatherV(ctx, linalg::MakeVec(hist_data),
                             &recv_segments, &hist_entries));
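  // AllgatherV is the variable-length gather: the encrypted buffers may differ
  // in size per rank, and recv_segments records each rank's segment boundaries.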

  if (collective::GetRank() != 0) {
    // The remaining steps are only needed by the label owner (rank 0);
    // passive parties are done once their shares have been gathered.
    return;
  }

  // Call the plugin here to get the resulting histogram. Histograms from all
  // workers are gathered to the label owner.
  common::Span<double> hist_aggr =
      plugin->SyncEncryptedHistVert(
          common::RestoreType<std::uint8_t>(hist_entries.HostSpan()));

  // Post-process the AllGathered data: for each bin, sum the (grad, hess)
  // contributions across ranks.
  auto world_size = collective::GetWorldSize();
  std::vector<GradientPairInt64> host_histogram(histogram.size());
  for (std::size_t i = 0; i < histogram.size(); i++) {
    double grad = 0.0;
    double hess = 0.0;
    for (auto rank = 0; rank < world_size; ++rank) {
      auto idx = rank * histogram.size() + i;
      grad += hist_aggr[idx * 2];
      hess += hist_aggr[idx * 2 + 1];
    }
    // Ranks other than 0 returned above, so the aggregate can be stored
    // unconditionally.
    GradientPairPrecise hist_item(grad, hess);
    host_histogram[i] = rounding.ToFixedPoint(hist_item);
  }

  // Copy the aggregated histogram back to GPU memory; at this point it
  // contains the full information from all parties.
  dh::safe_cuda(cudaMemcpyAsync(histogram.data(), host_histogram.data(),
                                histogram.size() * sizeof(GradientPairInt64),
                                cudaMemcpyHostToDevice));
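  // host_histogram is pageable host memory, so the driver stages this "async"
  // copy before returning; the transfer does not read the stack-local buffer
  // after this call, even though host_histogram goes out of scope below.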
#else
  LOG(FATAL) << error::NoFederated();
#endif
}
} // namespace xgboost::tree
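For reference, here is a minimal standalone sketch of the rank-0 reduction above. The `ReduceHist` helper and the example values are hypothetical, not part of XGBoost; the layout it models is the one the code assumes: the gathered buffer holds world_size per-rank histograms back to back, each bin stored as an interleaved (grad, hess) pair of doubles.

#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

// Hypothetical model of the reduction in BuildHistogram: rank r's bin i lives
// at aggr[(r * n_bins + i) * 2] (grad) and aggr[(r * n_bins + i) * 2 + 1] (hess).
std::vector<std::pair<double, double>> ReduceHist(std::vector<double> const &aggr,
                                                  std::size_t n_bins, int world_size) {
  std::vector<std::pair<double, double>> out(n_bins, {0.0, 0.0});
  for (std::size_t i = 0; i < n_bins; ++i) {
    for (int rank = 0; rank < world_size; ++rank) {
      std::size_t idx = (static_cast<std::size_t>(rank) * n_bins + i) * 2;
      out[i].first += aggr[idx];       // grad
      out[i].second += aggr[idx + 1];  // hess
    }
  }
  return out;
}

int main() {
  // Two ranks, two bins: rank 0 holds {(1,2), (3,4)}, rank 1 holds {(10,20), (30,40)}.
  std::vector<double> aggr{1, 2, 3, 4, 10, 20, 30, 40};
  for (auto const &[g, h] : ReduceHist(aggr, 2, 2)) {
    std::cout << g << " " << h << "\n";  // prints "11 22" then "33 44"
  }
}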
