Skip to content

Commit

Permalink
add the missing aggr piece to secure hori GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueXu77 committed Jul 18, 2024
1 parent b448dff commit 7d6d592
Showing 1 changed file with 44 additions and 48 deletions.
92 changes: 44 additions & 48 deletions src/tree/updater_gpu_hist.cu
Original file line number Diff line number Diff line change
Expand Up @@ -628,61 +628,57 @@ struct GPUHistMakerDevice {
auto d_node_hist = hist.GetNodeHistogram(nidx).data();
using ReduceT = typename std::remove_pointer<decltype(d_node_hist)>::type::ValueT;
auto hist_vec = linalg::MakeVec(reinterpret_cast<ReduceT*>(d_node_hist), n, ctx_->Device());

// print out the first histogram with iterator
//std::vector<int64_t> entry(hist_vec.Values().size());
//dh::CopyDeviceSpanToVector(&entry, hist_vec.Values());
//printf("Non-enc: Rank %d Before AllReduce: %ld\n", collective::GetRank(), entry[0]);

auto rc = collective::GlobalSum(
ctx_, info_, hist_vec);
SafeColl(rc);

// print out the first histogram with iterator
//dh::CopyDeviceSpanToVector(&entry, hist_vec.Values());
//printf("Non-enc: Rank %d After AllReduce: %ld\n", collective::GetRank(), entry[0]);

monitor.Stop("AllReduce");
}

/**
 * \brief Sum the histograms of one or more nodes across all workers through the
 *        federated encryption plugin (secure horizontal training).
 *
 * The node histogram is copied from the device to the host, handed to the plugin
 * to build an encrypted buffer, exchanged with every worker via AllgatherV,
 * handed back to the plugin, summed across ranks on the host, and finally copied
 * back over the original device histogram.
 *
 * \param nidx            Index of the first node whose histogram is reduced.
 * \param num_histograms  Number of histograms covered by the reduction; the
 *                        element count is TotalBins() * 2 * num_histograms.
 */
void AllReduceHistEncrypted(int nidx, int num_histograms) {
  monitor.Start("AllReduceEncrypted");

  // Get the encryption plugin from the federated communicator.
  // The dynamic_cast on a reference throws if the communicator is not federated.
  decltype(std::declval<collective::FederatedComm>().EncryptionPlugin()) plugin;
  auto const &comm = collective::GlobalCommGroup()->Ctx(ctx_, DeviceOrd::CPU());
  auto const &fed = dynamic_cast<collective::FederatedComm const &>(comm);
  plugin = fed.EncryptionPlugin();

  // View the node histogram as a flat vector of n scalar values.
  std::size_t n = page->Cuts().TotalBins() * 2 * num_histograms;
  auto d_node_hist = hist.GetNodeHistogram(nidx).data();
  using ReduceT = typename std::remove_pointer<decltype(d_node_hist)>::type::ValueT;
  auto hist_vec = linalg::MakeVec(reinterpret_cast<ReduceT *>(d_node_hist), n, ctx_->Device());

  // Copy the histogram out of GPU memory as raw bytes.
  // NOTE(review): the cudaMemcpy return codes in this function are not checked;
  // consider wrapping them in the file's CUDA error-checking helper.
  common::Span<std::int8_t> erased = common::EraseType(hist_vec.Values());
  std::vector<std::int8_t> h_data(erased.size());
  cudaMemcpy(h_data.data(), erased.data(), erased.size(), cudaMemcpyDeviceToHost);

  // Call the encryption plugin. The host bytes are presented to it as n doubles.
  // NOTE(review): this presumes sizeof(ReduceT) == sizeof(double) -- confirm.
  auto src_hist = common::Span{reinterpret_cast<double const *>(h_data.data()), n};
  auto hist_buf = plugin->BuildEncryptedHistHori(src_hist);

  // Gather the encrypted buffers of all workers.
  HostDeviceVector<std::int8_t> hist_entries;
  std::vector<std::int64_t> recv_segments;
  auto rc = collective::AllgatherV(ctx_, linalg::MakeVec(hist_buf), &recv_segments, &hist_entries);
  collective::SafeColl(rc);

  // Call the encryption plugin to process the gathered histograms.
  auto hist_aggr =
      plugin->SyncEncryptedHistHori(common::RestoreType<std::uint8_t>(hist_entries.HostSpan()));

  // Reinterpret the plugin output as int64_t and sum the per-rank segments into
  // the first segment. Rank j's values are read from [j * n, (j + 1) * n).
  // NOTE(review): assumes hist_aggr holds at least num_ranks * n entries -- this
  // is not verified here.
  auto hist_aggr_64 =
      common::Span{reinterpret_cast<std::int64_t *>(hist_aggr.data()), hist_aggr.size()};
  int num_ranks = collective::GlobalCommGroup()->World();
  // Rank-outer / bin-inner order keeps both reads and writes contiguous; every
  // bin still receives the same additions in the same rank order as before.
  for (int j = 1; j < num_ranks; ++j) {
    std::size_t const offset = static_cast<std::size_t>(j) * n;
    for (std::size_t i = 0; i < n; ++i) {
      hist_aggr_64[i] = hist_aggr_64[i] + hist_aggr_64[offset + i];
    }
  }

  // Copy the aggregated histogram (the first segment) back to GPU memory.
  cudaMemcpy(erased.data(), hist_aggr_64.data(), erased.size(), cudaMemcpyHostToDevice);

  monitor.Stop("AllReduceEncrypted");
}

/**
Expand Down

0 comments on commit 7d6d592

Please sign in to comment.