Skip to content

Commit

Permalink
extract hist vec, mark interface ops, code ready for processor interface
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueXu77 committed Apr 22, 2024
1 parent 86e3969 commit 8e40b71
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/tree/hist/evaluate_splits.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ class HistEvaluator {
// based on global histogram. Other parties will receive the final best split information
// Hence the below computation is not performed by the passive parties for secure vertical
// All other cases need it
if (! ((is_col_split_) && (is_secure_) && (collective::GetRank() == 0))) {
if (!((is_col_split_) && (is_secure_) && (collective::GetRank() == 0))) {
// Evaluate the splits for each feature
common::ParallelFor2d(space, n_threads, [&](size_t nidx_in_set, common::Range1d r) {
auto tidx = omp_get_thread_num();
Expand Down
22 changes: 18 additions & 4 deletions src/tree/hist/histogram.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,27 @@ class HistogramBuilder {
collective::Allreduce<collective::Operation::kSum>(
reinterpret_cast<double *>(this->hist_[first_nidx].data()), n);
} else {
// Secure mode, we need to call interface to perform AllReduce
// Secure mode, we need to call interface to perform encryption and decryption
// for simplicity, use AllGather instead of AllReduce
// note that aggregation will be performed at server side
// note that the actual aggregation will be performed at server side
auto first_nidx = nodes_to_build.front();
std::size_t n = n_total_bins * nodes_to_build.size() * 2;
collective::Allreduce<collective::Operation::kSum>(
reinterpret_cast<double *>(this->hist_[first_nidx].data()), n);
auto hist_to_aggr = std::vector<double>();
for (int hist_idx = 0; hist_idx < n; hist_idx++) {
double hist_item = reinterpret_cast<double *>(this->hist_[first_nidx].data())[hist_idx];
hist_to_aggr.push_back(hist_item);
}

// For FL with secure horizontal
//auto hist_buf_ckks = processor_instance->EncodeHist(hist_to_aggr);
//auto hist_entries = collective::AllgatherV(hist_buf_ckks);
//std::vector<double> hist_aggr = processor_instance->DecodeHist(hist_entries.data(), n);

// Assign the aggregated histogram back to the local histogram
collective::Allreduce<collective::Operation::kSum>(hist_to_aggr.data(), n);
for (int hist_idx = 0; hist_idx < n; hist_idx++) {
reinterpret_cast<double *>(this->hist_[first_nidx].data())[hist_idx] = hist_to_aggr[hist_idx];
}
}
}

Expand Down

0 comments on commit 8e40b71

Please sign in to comment.