Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed the passive sync error #11

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 92 additions & 89 deletions src/tree/gpu_hist/histogram.cu
Original file line number Diff line number Diff line change
Expand Up @@ -394,99 +394,102 @@ void DeviceHistogramBuilder::BuildHistogram(Context const* ctx,
// Regular training, build histogram locally
this->p_impl_->BuildHistogram(ctx->CUDACtx(), matrix, feature_groups,
gpair, ridx, histogram, rounding);
} else {
return;
}
#if defined(XGBOOST_USE_FEDERATED)
// Encrypted vertical, build histogram using federated plugin
auto const &comm = collective::GlobalCommGroup()->Ctx(ctx, DeviceOrd::CPU());
auto const &fed = dynamic_cast<collective::FederatedComm const &>(comm);
auto plugin = fed.EncryptionPlugin();

// Transmit matrix to plugin
if (!is_aggr_context_initialized_) {
// Get cutptrs
std::vector<uint32_t> h_cuts_ptr(matrix.feature_segments.size());
dh::CopyDeviceSpanToVector(&h_cuts_ptr, matrix.feature_segments);
common::Span<std::uint32_t const> cutptrs =
common::Span<std::uint32_t const>(h_cuts_ptr.data(), h_cuts_ptr.size());

// Get bin_idx matrix
auto kRows = matrix.n_rows;
auto kCols = matrix.NumFeatures();
std::vector<int32_t> h_bin_idx(kRows * kCols);
// Access GPU matrix data
thrust::device_vector<bst_float> matrix_d(kRows * kCols);
dh::LaunchN(kRows * kCols, ReadMatrixFunction(matrix, kCols, matrix_d.data().get()));
thrust::copy(matrix_d.begin(), matrix_d.end(), h_bin_idx.begin());
common::Span<std::int32_t const> bin_idx =
common::Span<std::int32_t const>(h_bin_idx.data(), h_bin_idx.size());

// Initialize plugin context
plugin->Reset(h_cuts_ptr, h_bin_idx);
is_aggr_context_initialized_ = true;
}
// Encrypted vertical, build histogram using federated plugin
auto const &comm = collective::GlobalCommGroup()->Ctx(ctx, DeviceOrd::CPU());
auto const &fed = dynamic_cast<collective::FederatedComm const &>(comm);
auto plugin = fed.EncryptionPlugin();

// Transmit matrix to plugin
if (!is_aggr_context_initialized_) {
// Get cutptrs
std::vector<uint32_t> h_cuts_ptr(matrix.feature_segments.size());
dh::CopyDeviceSpanToVector(&h_cuts_ptr, matrix.feature_segments);
common::Span<std::uint32_t const> cutptrs =
common::Span<std::uint32_t const>(h_cuts_ptr.data(), h_cuts_ptr.size());

// Get bin_idx matrix
auto kRows = matrix.n_rows;
auto kCols = matrix.NumFeatures();
std::vector<int32_t> h_bin_idx(kRows * kCols);
// Access GPU matrix data
thrust::device_vector<bst_float> matrix_d(kRows * kCols);
dh::LaunchN(kRows * kCols, ReadMatrixFunction(matrix, kCols, matrix_d.data().get()));
thrust::copy(matrix_d.begin(), matrix_d.end(), h_bin_idx.begin());
common::Span<std::int32_t const> bin_idx =
common::Span<std::int32_t const>(h_bin_idx.data(), h_bin_idx.size());

// Initialize plugin context
plugin->Reset(h_cuts_ptr, h_bin_idx);
is_aggr_context_initialized_ = true;
}

// get row indices from device
std::vector<uint32_t> h_ridx(ridx.size());
dh::CopyDeviceSpanToVector(&h_ridx, ridx);
// necessary conversions to fit plugin expectations
std::vector<uint64_t> h_ridx_64(ridx.size());
for (int i = 0; i < ridx.size(); i++) {
h_ridx_64[i] = h_ridx[i];
}
std::vector<std::uint64_t const *> ptrs(1);
std::vector<std::size_t> sizes(1);
std::vector<bst_node_t> nodes(1);
ptrs[0] = reinterpret_cast<std::uint64_t const *>(h_ridx_64.data());
sizes[0] = h_ridx_64.size();
nodes[0] = 0;

// Transmit row indices to plugin and get encrypted histogram
auto hist_data = plugin->BuildEncryptedHistVert(ptrs, sizes, nodes);

// Perform AllGather
HostDeviceVector<std::int8_t> hist_entries;
std::vector<std::int64_t> recv_segments;
collective::SafeColl(
collective::AllgatherV(ctx, linalg::MakeVec(hist_data),
&recv_segments, &hist_entries));

// Call the plugin here to get the resulting histogram. Histogram from all workers are
// gathered to the label owner
common::Span<double> hist_aggr =
plugin->SyncEncryptedHistVert(
common::RestoreType<std::uint8_t>(hist_entries.HostSpan()));

// Post process the AllGathered data
// This is only needed by Rank 0
if (collective::GetRank() == 0) {
auto world_size = collective::GetWorldSize();
std::vector<GradientPairInt64> host_histogram(histogram.size());
for (auto i = 0; i < histogram.size(); i++) {
double grad = 0.0;
double hess = 0.0;
for (auto rank = 0; rank < world_size; ++rank) {
auto idx = rank * histogram.size() + i;
grad += hist_aggr[idx * 2];
hess += hist_aggr[idx * 2 + 1];
}
GradientPairPrecise hist_item(grad, hess);
GradientPairPrecise hist_item_empty(0.0, 0.0);
if (collective::GetRank() != 0) {
host_histogram[i] = rounding.ToFixedPoint(hist_item_empty);
} else {
host_histogram[i] = rounding.ToFixedPoint(hist_item);
}
}
// get row indices from device
std::vector<uint32_t> h_ridx(ridx.size());
dh::CopyDeviceSpanToVector(&h_ridx, ridx);
// necessary conversions to fit plugin expectations
std::vector<uint64_t> h_ridx_64(ridx.size());
for (int i = 0; i < ridx.size(); i++) {
h_ridx_64[i] = h_ridx[i];
}
std::vector<std::uint64_t const *> ptrs(1);
std::vector<std::size_t> sizes(1);
std::vector<bst_node_t> nodes(1);
ptrs[0] = reinterpret_cast<std::uint64_t const *>(h_ridx_64.data());
sizes[0] = h_ridx_64.size();
nodes[0] = 0;

// Transmit row indices to plugin and get encrypted histogram
auto hist_data = plugin->BuildEncryptedHistVert(ptrs, sizes, nodes);

// Perform AllGather
HostDeviceVector<std::int8_t> hist_entries;
std::vector<std::int64_t> recv_segments;
collective::SafeColl(
collective::AllgatherV(ctx, linalg::MakeVec(hist_data),
&recv_segments, &hist_entries));

if (collective::GetRank() != 0) {
// This is only needed for lable owner
return;
}

// copy the aggregated histogram back to GPU memory
// at this point, the histogram contains full information from all parties
dh::safe_cuda(cudaMemcpyAsync(histogram.data(), host_histogram.data(),
histogram.size() * sizeof(GradientPairInt64),
cudaMemcpyHostToDevice));
// Call the plugin here to get the resulting histogram. Histogram from all workers are
// gathered to the label owner.
common::Span<double> hist_aggr =
plugin->SyncEncryptedHistVert(
common::RestoreType<std::uint8_t>(hist_entries.HostSpan()));

// Post process the AllGathered data
auto world_size = collective::GetWorldSize();
std::vector<GradientPairInt64> host_histogram(histogram.size());
for (auto i = 0; i < histogram.size(); i++) {
double grad = 0.0;
double hess = 0.0;
for (auto rank = 0; rank < world_size; ++rank) {
auto idx = rank * histogram.size() + i;
grad += hist_aggr[idx * 2];
hess += hist_aggr[idx * 2 + 1];
}
GradientPairPrecise hist_item(grad, hess);
GradientPairPrecise hist_item_empty(0.0, 0.0);
if (collective::GetRank() != 0) {
host_histogram[i] = rounding.ToFixedPoint(hist_item_empty);
} else {
host_histogram[i] = rounding.ToFixedPoint(hist_item);
}
#else
LOG(FATAL) << error::NoFederated();
#endif
}

// copy the aggregated histogram back to GPU memory
// at this point, the histogram contains full information from all parties
dh::safe_cuda(cudaMemcpyAsync(histogram.data(), host_histogram.data(),
histogram.size() * sizeof(GradientPairInt64),
cudaMemcpyHostToDevice));
#else
LOG(FATAL) << error::NoFederated();
#endif

}
} // namespace xgboost::tree
Loading