Skip to content

Commit

Permalink
code lint corrections
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueXu77 committed Jul 19, 2024
1 parent 7d6d592 commit 7421cfa
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions src/tree/updater_gpu_hist.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
#include "../collective/communicator-inl.h"
#include "../collective/allgather.h" // for AllgatherV

#include <stdio.h>
namespace xgboost::tree {
#if !defined(GTEST_TEST)
DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
Expand Down Expand Up @@ -660,14 +659,17 @@ struct GPUHistMakerDevice {
// allgather
HostDeviceVector<std::int8_t> hist_entries;
std::vector<std::int64_t> recv_segments;
auto rc = collective::AllgatherV(ctx_, linalg::MakeVec(hist_buf), &recv_segments, &hist_entries);
auto rc = collective::AllgatherV(ctx_, linalg::MakeVec(hist_buf),
&recv_segments, &hist_entries);
collective::SafeColl(rc);

// call the encryption plugin to decode the histograms
auto hist_aggr = plugin->SyncEncryptedHistHori(common::RestoreType<std::uint8_t>(hist_entries.HostSpan()));
auto hist_aggr = plugin->SyncEncryptedHistHori(
common::RestoreType<std::uint8_t>(hist_entries.HostSpan()));

// reinterpret the aggregated histogram as a int64_t and aggregate
auto hist_aggr_64 = common::Span{reinterpret_cast<std::int64_t *>(hist_aggr.data()), hist_aggr.size()};
auto hist_aggr_64 = common::Span{
reinterpret_cast<std::int64_t *>(hist_aggr.data()), hist_aggr.size()};
int num_ranks = collective::GlobalCommGroup()->World();
for (size_t i = 0; i < n; i++) {
for (int j = 1; j < num_ranks; j++) {
Expand Down Expand Up @@ -718,8 +720,7 @@ struct GPUHistMakerDevice {
// If secure horizontal, perform AllReduce by calling the encryption plugin
if (collective::IsDistributed() && info_.IsRowSplit() && collective::IsEncrypted()) {
this->AllReduceHistEncrypted(hist_nidx.at(0), hist_nidx.size());
}
else {
} else {
this->AllReduceHist(hist_nidx.at(0), hist_nidx.size());
}

Expand All @@ -733,8 +734,7 @@ struct GPUHistMakerDevice {
this->BuildHist(subtraction_trick_nidx);
if (collective::IsDistributed() && info_.IsRowSplit() && collective::IsEncrypted()) {
this->AllReduceHistEncrypted(subtraction_trick_nidx, 1);
}
else {
} else {
this->AllReduceHist(subtraction_trick_nidx, 1);
}
}
Expand Down Expand Up @@ -812,8 +812,7 @@ struct GPUHistMakerDevice {
this->BuildHist(kRootNIdx);
if (collective::IsDistributed() && info_.IsRowSplit() && collective::IsEncrypted()) {
this->AllReduceHistEncrypted(kRootNIdx, 1);
}
else {
} else {
this->AllReduceHist(kRootNIdx, 1);
}

Expand Down

0 comments on commit 7421cfa

Please sign in to comment.