From 397ade38d2906e4c39a88ce8927d6fbaed801d74 Mon Sep 17 00:00:00 2001
From: Zhihong Zhang
Date: Thu, 1 Aug 2024 19:06:36 -0400
Subject: [PATCH] Added check for passive when sync histo for vertical and
 removed some nested blocks

---
 src/tree/gpu_hist/histogram.cu | 181 +++++++++++++++++----------------
 1 file changed, 92 insertions(+), 89 deletions(-)

diff --git a/src/tree/gpu_hist/histogram.cu b/src/tree/gpu_hist/histogram.cu
index 68427d21322c..93678496cda1 100644
--- a/src/tree/gpu_hist/histogram.cu
+++ b/src/tree/gpu_hist/histogram.cu
@@ -394,99 +394,102 @@ void DeviceHistogramBuilder::BuildHistogram(Context const* ctx,
     // Regular training, build histogram locally
     this->p_impl_->BuildHistogram(ctx->CUDACtx(), matrix, feature_groups, gpair, ridx, histogram,
                                   rounding);
-  } else {
+    return;
+  }
 #if defined(XGBOOST_USE_FEDERATED)
-    // Encrypted vertical, build histogram using federated plugin
-    auto const &comm = collective::GlobalCommGroup()->Ctx(ctx, DeviceOrd::CPU());
-    auto const &fed = dynamic_cast<collective::FederatedComm const &>(comm);
-    auto plugin = fed.EncryptionPlugin();
-
-    // Transmit matrix to plugin
-    if (!is_aggr_context_initialized_) {
-      // Get cutptrs
-      std::vector<uint32_t> h_cuts_ptr(matrix.feature_segments.size());
-      dh::CopyDeviceSpanToVector(&h_cuts_ptr, matrix.feature_segments);
-      common::Span<std::uint32_t const> cutptrs =
-          common::Span<std::uint32_t const>(h_cuts_ptr.data(), h_cuts_ptr.size());
-
-      // Get bin_idx matrix
-      auto kRows = matrix.n_rows;
-      auto kCols = matrix.NumFeatures();
-      std::vector<int32_t> h_bin_idx(kRows * kCols);
-      // Access GPU matrix data
-      thrust::device_vector<int32_t> matrix_d(kRows * kCols);
-      dh::LaunchN(kRows * kCols, ReadMatrixFunction(matrix, kCols, matrix_d.data().get()));
-      thrust::copy(matrix_d.begin(), matrix_d.end(), h_bin_idx.begin());
-      common::Span<std::int32_t const> bin_idx =
-          common::Span<std::int32_t const>(h_bin_idx.data(), h_bin_idx.size());
-
-      // Initialize plugin context
-      plugin->Reset(h_cuts_ptr, h_bin_idx);
-      is_aggr_context_initialized_ = true;
-    }
+  // Encrypted vertical, build histogram using federated plugin
+  auto const &comm = collective::GlobalCommGroup()->Ctx(ctx, DeviceOrd::CPU());
+  auto const &fed = dynamic_cast<collective::FederatedComm const &>(comm);
+  auto plugin = fed.EncryptionPlugin();
+
+  // Transmit matrix to plugin
+  if (!is_aggr_context_initialized_) {
+    // Get cutptrs
+    std::vector<uint32_t> h_cuts_ptr(matrix.feature_segments.size());
+    dh::CopyDeviceSpanToVector(&h_cuts_ptr, matrix.feature_segments);
+    common::Span<std::uint32_t const> cutptrs =
+        common::Span<std::uint32_t const>(h_cuts_ptr.data(), h_cuts_ptr.size());
+
+    // Get bin_idx matrix
+    auto kRows = matrix.n_rows;
+    auto kCols = matrix.NumFeatures();
+    std::vector<int32_t> h_bin_idx(kRows * kCols);
+    // Access GPU matrix data
+    thrust::device_vector<int32_t> matrix_d(kRows * kCols);
+    dh::LaunchN(kRows * kCols, ReadMatrixFunction(matrix, kCols, matrix_d.data().get()));
+    thrust::copy(matrix_d.begin(), matrix_d.end(), h_bin_idx.begin());
+    common::Span<std::int32_t const> bin_idx =
+        common::Span<std::int32_t const>(h_bin_idx.data(), h_bin_idx.size());
+
+    // Initialize plugin context
+    plugin->Reset(h_cuts_ptr, h_bin_idx);
+    is_aggr_context_initialized_ = true;
+  }
 
-    // get row indices from device
-    std::vector<uint32_t> h_ridx(ridx.size());
-    dh::CopyDeviceSpanToVector(&h_ridx, ridx);
-    // necessary conversions to fit plugin expectations
-    std::vector<uint64_t> h_ridx_64(ridx.size());
-    for (int i = 0; i < ridx.size(); i++) {
-      h_ridx_64[i] = h_ridx[i];
-    }
-    std::vector<std::uint64_t const *> ptrs(1);
-    std::vector<std::size_t> sizes(1);
-    std::vector<bst_node_t> nodes(1);
-    ptrs[0] = reinterpret_cast<std::uint64_t const *>(h_ridx_64.data());
-    sizes[0] = h_ridx_64.size();
-    nodes[0] = 0;
-
-    // Transmit row indices to plugin and get encrypted histogram
-    auto hist_data =
-        plugin->BuildEncryptedHistVert(ptrs, sizes, nodes);
-
-    // Perform AllGather
-    HostDeviceVector<std::int8_t> hist_entries;
-    std::vector<std::int64_t> recv_segments;
-    collective::SafeColl(
-        collective::AllgatherV(ctx, linalg::MakeVec(hist_data),
-                               &recv_segments, &hist_entries));
-
-    // Call the plugin here to get the resulting histogram. Histogram from all workers are
-    // gathered to the label owner
-    common::Span<double> hist_aggr =
-        plugin->SyncEncryptedHistVert(
-            common::RestoreType<std::uint8_t>(hist_entries.HostSpan()));
-
-    // Post process the AllGathered data
-    // This is only needed by Rank 0
-    if (collective::GetRank() == 0) {
-      auto world_size = collective::GetWorldSize();
-      std::vector<GradientPairInt64> host_histogram(histogram.size());
-      for (auto i = 0; i < histogram.size(); i++) {
-        double grad = 0.0;
-        double hess = 0.0;
-        for (auto rank = 0; rank < world_size; ++rank) {
-          auto idx = rank * histogram.size() + i;
-          grad += hist_aggr[idx * 2];
-          hess += hist_aggr[idx * 2 + 1];
-        }
-        GradientPairPrecise hist_item(grad, hess);
-        GradientPairPrecise hist_item_empty(0.0, 0.0);
-        if (collective::GetRank() != 0) {
-          host_histogram[i] = rounding.ToFixedPoint(hist_item_empty);
-        } else {
-          host_histogram[i] = rounding.ToFixedPoint(hist_item);
-        }
-      }
+  // get row indices from device
+  std::vector<uint32_t> h_ridx(ridx.size());
+  dh::CopyDeviceSpanToVector(&h_ridx, ridx);
+  // necessary conversions to fit plugin expectations
+  std::vector<uint64_t> h_ridx_64(ridx.size());
+  for (int i = 0; i < ridx.size(); i++) {
+    h_ridx_64[i] = h_ridx[i];
+  }
+  std::vector<std::uint64_t const *> ptrs(1);
+  std::vector<std::size_t> sizes(1);
+  std::vector<bst_node_t> nodes(1);
+  ptrs[0] = reinterpret_cast<std::uint64_t const *>(h_ridx_64.data());
+  sizes[0] = h_ridx_64.size();
+  nodes[0] = 0;
+
+  // Transmit row indices to plugin and get encrypted histogram
+  auto hist_data = plugin->BuildEncryptedHistVert(ptrs, sizes, nodes);
+
+  // Perform AllGather
+  HostDeviceVector<std::int8_t> hist_entries;
+  std::vector<std::int64_t> recv_segments;
+  collective::SafeColl(
+      collective::AllgatherV(ctx, linalg::MakeVec(hist_data),
+                             &recv_segments, &hist_entries));
+
+  if (collective::GetRank() != 0) {
+    // This is only needed for the label owner
+    return;
+  }
 
-      // copy the aggregated histogram back to GPU memory
-      // at this point, the histogram contains full information from all parties
-      dh::safe_cuda(cudaMemcpyAsync(histogram.data(), host_histogram.data(),
-                                    histogram.size() * sizeof(GradientPairInt64),
-                                    cudaMemcpyHostToDevice));
-    }
-#else
-    LOG(FATAL) << error::NoFederated();
-#endif
-  }
+  // Call the plugin here to get the resulting histogram. Histogram from all workers are
+  // gathered to the label owner.
+  common::Span<double> hist_aggr =
+      plugin->SyncEncryptedHistVert(
+          common::RestoreType<std::uint8_t>(hist_entries.HostSpan()));
+
+  // Post process the AllGathered data
+  auto world_size = collective::GetWorldSize();
+  std::vector<GradientPairInt64> host_histogram(histogram.size());
+  for (auto i = 0; i < histogram.size(); i++) {
+    double grad = 0.0;
+    double hess = 0.0;
+    for (auto rank = 0; rank < world_size; ++rank) {
+      auto idx = rank * histogram.size() + i;
+      grad += hist_aggr[idx * 2];
+      hess += hist_aggr[idx * 2 + 1];
+    }
+    GradientPairPrecise hist_item(grad, hess);
+    GradientPairPrecise hist_item_empty(0.0, 0.0);
+    if (collective::GetRank() != 0) {
+      host_histogram[i] = rounding.ToFixedPoint(hist_item_empty);
+    } else {
+      host_histogram[i] = rounding.ToFixedPoint(hist_item);
+    }
+  }
+
+  // copy the aggregated histogram back to GPU memory
+  // at this point, the histogram contains full information from all parties
+  dh::safe_cuda(cudaMemcpyAsync(histogram.data(), host_histogram.data(),
+                                histogram.size() * sizeof(GradientPairInt64),
+                                cudaMemcpyHostToDevice));
+#else
+  LOG(FATAL) << error::NoFederated();
+#endif
 }
 }  // namespace xgboost::tree
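
Note on the aggregation step (illustration, not part of the patch): the post-processing loop above assumes the buffer returned by SyncEncryptedHistVert is a flat array of doubles laid out rank-major with interleaved gradient/hessian slots, i.e. the contribution of a given rank to bin i sits at idx = rank * histogram.size() + i, with the gradient at idx * 2 and the hessian at idx * 2 + 1, and only rank 0 (the label owner) folds the per-rank slices into the final histogram. The minimal host-side sketch below re-creates that indexing under the same layout assumption; GradHess and AggregateFlatHistogram are made-up names for illustration and are not XGBoost or plugin APIs.

// Illustrative sketch only: aggregate a flat, rank-major, (grad, hess)-interleaved buffer.
#include <cstddef>
#include <iostream>
#include <vector>

struct GradHess {
  double grad{0.0};
  double hess{0.0};
};

// Fold the per-rank slices of `flat` (world_size * num_bins * 2 doubles) into one histogram,
// mirroring what rank 0 does in the patch before converting to fixed point.
std::vector<GradHess> AggregateFlatHistogram(std::vector<double> const& flat,
                                             std::size_t world_size, std::size_t num_bins) {
  std::vector<GradHess> out(num_bins);
  for (std::size_t bin = 0; bin < num_bins; ++bin) {
    for (std::size_t rank = 0; rank < world_size; ++rank) {
      std::size_t idx = rank * num_bins + bin;  // rank-major position of this bin
      out[bin].grad += flat[idx * 2];           // even slot: gradient
      out[bin].hess += flat[idx * 2 + 1];       // odd slot: hessian
    }
  }
  return out;
}

int main() {
  // Two ranks, two bins: rank 0 contributes (1,2) and (3,4); rank 1 contributes (5,6) and (7,8).
  std::vector<double> flat{1, 2, 3, 4, 5, 6, 7, 8};
  auto hist = AggregateFlatHistogram(flat, /*world_size=*/2, /*num_bins=*/2);
  for (auto const& bin : hist) {
    std::cout << bin.grad << " " << bin.hess << "\n";  // prints "6 8" then "10 12"
  }
  return 0;
}

In the real code the aggregated (grad, hess) pairs are additionally quantised with rounding.ToFixedPoint and copied back to the device histogram, which this sketch leaves out.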