Federated plugin for histogram. (dmlc#10534)
trivialfis authored Jul 15, 2024
1 parent ed47ff5 commit 78e4533
Showing 15 changed files with 623 additions and 231 deletions.
3 changes: 2 additions & 1 deletion plugin/federated/CMakeLists.txt
@@ -37,7 +37,8 @@ target_link_libraries(federated_client INTERFACE federated_proto)

# Rabit engine for Federated Learning.
target_sources(
objxgboost PRIVATE federated_tracker.cc federated_comm.cc federated_coll.cc federated_plugin.cc
objxgboost PRIVATE
federated_plugin.cc federated_hist.cc federated_tracker.cc federated_comm.cc federated_coll.cc
)
if(USE_CUDA)
target_sources(objxgboost PRIVATE federated_comm.cu federated_coll.cu)
10 changes: 7 additions & 3 deletions plugin/federated/federated_coll.cc
@@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost contributors
* Copyright 2023-2024, XGBoost contributors
*/
#include "federated_coll.h"

@@ -8,11 +8,15 @@

#include <algorithm> // for copy_n

#include "../../src/collective/allgather.h"
#include "../../src/common/common.h" // for AssertGPUSupport
#include "federated_comm.h" // for FederatedComm
#include "xgboost/collective/result.h" // for Result

#if !defined(XGBOOST_USE_CUDA)

#include "../../src/common/common.h" // for AssertGPUSupport

#endif // !defined(XGBOOST_USE_CUDA)

namespace xgboost::collective {
namespace {
[[nodiscard]] Result GetGRPCResult(std::string const &name, grpc::Status const &status) {
159 changes: 159 additions & 0 deletions plugin/federated/federated_hist.cc
@@ -0,0 +1,159 @@
/**
* Copyright 2024, XGBoost contributors
*/
#include "federated_hist.h"

#include "../../src/collective/allgather.h" // for AllgatherV
#include "../../src/collective/communicator-inl.h" // for GetRank
#include "../../src/tree/hist/histogram.h" // for SubtractHistParallel, BuildSampleHistograms

namespace xgboost::tree {
namespace {
// Copy the bins into a dense matrix.
auto CopyBinsToDense(Context const *ctx, GHistIndexMatrix const &gidx) {
auto n_samples = gidx.Size();
auto n_features = gidx.Features();
std::vector<bst_bin_t> bins(n_samples * n_features);
auto bins_view = linalg::MakeTensorView(ctx, bins, n_samples, n_features);
common::ParallelFor(n_samples, ctx->Threads(), [&](auto ridx) {
for (bst_feature_t fidx = 0; fidx < n_features; fidx++) {
bins_view(ridx, fidx) = gidx.GetGindex(ridx, fidx);
}
});
return bins;
}
} // namespace

template <bool any_missing>
void FederataedHistPolicy::DoBuildLocalHistograms(
common::BlockedSpace2d const &space, GHistIndexMatrix const &gidx,
std::vector<bst_node_t> const &nodes_to_build,
common::RowSetCollection const &row_set_collection, common::Span<GradientPair const> gpair_h,
bool force_read_by_column, common::ParallelGHistBuilder *p_buffer) {
if (is_col_split_) {
// Copy the gidx information to the secure worker for encrypted histogram
// computation. This is copied because we don't want the plugin to handle the bin
// compression, which is internal to XGBoost.

// FIXME: this can be done during reset.
if (!is_gidx_initialized_) {
auto bins = CopyBinsToDense(ctx_, gidx);
auto cuts = gidx.Cuts().Ptrs();
plugin_->Reset(cuts, bins);
is_gidx_initialized_ = true;
}

// Share the row set collection without copying.
std::vector<std::uint64_t const *> ptrs(nodes_to_build.size());
std::vector<std::size_t> sizes(nodes_to_build.size());
std::vector<bst_node_t> nodes(nodes_to_build.size());
for (std::size_t i = 0; i < nodes_to_build.size(); ++i) {
auto nidx = nodes_to_build[i];
ptrs[i] = row_set_collection[nidx].begin();
sizes[i] = row_set_collection[nidx].Size();
nodes[i] = nidx;
}
hist_data_ = this->plugin_->BuildEncryptedHistVert(ptrs, sizes, nodes);
} else {
BuildSampleHistograms<any_missing>(this->ctx_->Threads(), space, gidx, nodes_to_build,
row_set_collection, gpair_h, force_read_by_column, p_buffer);
}
}

template void FederataedHistPolicy::DoBuildLocalHistograms<true>(
common::BlockedSpace2d const &space, GHistIndexMatrix const &gidx,
std::vector<bst_node_t> const &nodes_to_build,
common::RowSetCollection const &row_set_collection, common::Span<GradientPair const> gpair_h,
bool force_read_by_column, common::ParallelGHistBuilder *buffer);
template void FederataedHistPolicy::DoBuildLocalHistograms<false>(
common::BlockedSpace2d const &space, GHistIndexMatrix const &gidx,
std::vector<bst_node_t> const &nodes_to_build,
common::RowSetCollection const &row_set_collection, common::Span<GradientPair const> gpair_h,
bool force_read_by_column, common::ParallelGHistBuilder *buffer);

namespace {
// The label owner needs to gather the result from all workers.
void GatherWorkerHist(common::Span<double> hist_aggr, std::int32_t n_workers,
std::vector<bst_node_t> const &nodes_to_build, bst_bin_t n_total_bins,
tree::BoundedHistCollection *p_hist) {
bst_idx_t worker_size = hist_aggr.size() / n_workers;
bst_node_t n_nodes = nodes_to_build.size();
auto &hist = *p_hist;
// for each worker
for (auto widx = 0; widx < n_workers; ++widx) {
auto worker_hist = hist_aggr.subspan(widx * worker_size, worker_size);
// for each node
for (bst_node_t nidx_in_set = 0; nidx_in_set < n_nodes; ++nidx_in_set) {
auto hist_size = n_total_bins * kHist2F64; // Histogram size for one node.
auto hist_src = worker_hist.subspan(hist_size * nidx_in_set, hist_size);
auto hist_src_g = common::RestoreType<GradientPairPrecise>(hist_src);
auto hist_dst = hist[nodes_to_build[nidx_in_set]];
CHECK_EQ(hist_src_g.size(), hist_dst.size());
common::IncrementHist(hist_dst, hist_src_g, 0, hist_dst.size());
}
}
}
} // namespace

void FederataedHistPolicy::DoSyncHistogram(common::BlockedSpace2d const &space,
std::vector<bst_node_t> const &nodes_to_build,
std::vector<bst_node_t> const &nodes_to_trick,
common::ParallelGHistBuilder *p_buffer,
tree::BoundedHistCollection *p_hist) {
auto n_total_bins = p_buffer->TotalBins();
CHECK(!nodes_to_build.empty());

auto &hist = *p_hist;
if (is_col_split_) {
// Under secure vertical mode, we perform an allgather to obtain the global histogram.
// Note that only the label owner (rank == 0) needs the global histogram.

// Perform AllGather
HostDeviceVector<std::int8_t> hist_entries;
std::vector<std::int64_t> recv_segments;
collective::SafeColl(
collective::AllgatherV(ctx_, linalg::MakeVec(hist_data_), &recv_segments, &hist_entries));

// Call the plugin here to get the resulting histogram. Histograms from all workers
// are gathered to the label owner.
common::Span<double> hist_aggr =
plugin_->SyncEncryptedHistVert(common::RestoreType<std::uint8_t>(hist_entries.HostSpan()));

// Update histogram for the label owner
if (collective::GetRank() == 0) {
std::int32_t n_workers = collective::GetWorldSize();
CHECK_EQ(hist_aggr.size() % n_workers, 0);
// Initialize histogram. For the normal case, this is done by the parallel hist
// buffer. We should try to unify the code paths.
for (auto nidx : nodes_to_build) {
auto hist_dst = hist[nidx];
std::fill_n(hist_dst.data(), hist_dst.size(), GradientPairPrecise{});
}
GatherWorkerHist(hist_aggr, n_workers, nodes_to_build, n_total_bins, p_hist);
}
} else {
common::ParallelFor2d(space, this->ctx_->Threads(), [&](std::size_t node, common::Range1d r) {
// Merging histograms from each thread.
p_buffer->ReduceHist(node, r.begin(), r.end());
});
// Encrypted mode: we need to call the plugin to perform encryption and decryption.
auto first_nidx = nodes_to_build.front();
std::size_t n = n_total_bins * nodes_to_build.size() * kHist2F64;
auto src_hist = common::Span{reinterpret_cast<double const *>(hist[first_nidx].data()), n};
auto hist_buf = plugin_->BuildEncryptedHistHori(src_hist);

// allgather
HostDeviceVector<std::int8_t> hist_entries;
std::vector<std::int64_t> recv_segments;
auto rc =
collective::AllgatherV(ctx_, linalg::MakeVec(hist_buf), &recv_segments, &hist_entries);
collective::SafeColl(rc);

auto hist_aggr =
plugin_->SyncEncryptedHistHori(common::RestoreType<std::uint8_t>(hist_entries.HostSpan()));
// Assign the aggregated histogram back to the local histogram
auto hist_dst = reinterpret_cast<double *>(hist[first_nidx].data());
std::copy_n(hist_aggr.data(), hist_aggr.size(), hist_dst);
}
}
} // namespace xgboost::tree
58 changes: 58 additions & 0 deletions plugin/federated/federated_hist.h
@@ -0,0 +1,58 @@
/**
* Copyright 2024, XGBoost contributors
*/
#pragma once
#include <cstdint> // for int32_t
#include <vector> // for vector

#include "../../src/collective/comm_group.h" // for GlobalCommGroup
#include "../../src/common/hist_util.h" // for ParallelGHistBuilder
#include "../../src/common/row_set.h" // for RowSetCollection
#include "../../src/common/threading_utils.h" // for BlockedSpace2d
#include "../../src/data/gradient_index.h" // for GHistIndexMatrix
#include "../../src/tree/hist/hist_cache.h" // for BoundedHistCollection
#include "federated_comm.h" // for FederatedComm
#include "xgboost/base.h" // for GradientPair
#include "xgboost/context.h" // for Context
#include "xgboost/span.h" // for Span
#include "xgboost/tree_model.h" // for RegTree

namespace xgboost::tree {
/**
* @brief Federated histogram build policy
*/
class FederataedHistPolicy {
// fixme: duplicated code
bool is_col_split_{false};
bool is_distributed_{false};
decltype(std::declval<collective::FederatedComm>().EncryptionPlugin()) plugin_;
xgboost::common::Span<std::uint8_t> hist_data_;
// Only initialize the aggregation context once
bool is_gidx_initialized_{false};
Context const* ctx_;

public:
void DoReset(Context const *ctx, bool is_distributed, bool is_col_split) {
this->is_distributed_ = is_distributed;
CHECK(is_distributed);
this->ctx_ = ctx;
this->is_col_split_ = is_col_split;
auto const &comm = collective::GlobalCommGroup()->Ctx(ctx, DeviceOrd::CPU());
auto const &fed = dynamic_cast<collective::FederatedComm const &>(comm);
plugin_ = fed.EncryptionPlugin();
CHECK(is_distributed_) << "Unreachable. Single node training can not be federated.";
}

template <bool any_missing>
void DoBuildLocalHistograms(common::BlockedSpace2d const &space, GHistIndexMatrix const &gidx,
std::vector<bst_node_t> const &nodes_to_build,
common::RowSetCollection const &row_set_collection,
common::Span<GradientPair const> gpair_h, bool force_read_by_column,
common::ParallelGHistBuilder *p_buffer);

void DoSyncHistogram(common::BlockedSpace2d const &space,
std::vector<bst_node_t> const &nodes_to_build,
std::vector<bst_node_t> const &nodes_to_trick,
common::ParallelGHistBuilder *p_buffer, tree::BoundedHistCollection *p_hist);
};
} // namespace xgboost::tree
6 changes: 6 additions & 0 deletions plugin/federated/federated_plugin.h
@@ -12,6 +12,12 @@
* - Build histogram for vertical federated learning.
* - Build histogram for horizontal federated learning.
*
* Since we don't require the plugin to have network capability, the synchronization is
* performed in XGBoost. As a result, the build procedure is divided into four steps:
* first, build a local histogram, then encrypt it with the plugin. Afterward, control
* returns to XGBoost, which is responsible for the synchronization. Lastly, the plugin
* receives the synchronization result and returns the decrypted histogram.
*
* See below function prototypes for details. All prototypes are for C functions that are
* suitable for `dlopen`.
*/
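
To make the four-step flow above concrete, here is a minimal, self-contained C++ sketch of the hand-off between XGBoost and a plugin. The `MockPlugin` type, its `BuildEncryptedHist`/`SyncEncryptedHist` methods, the identity "encryption", and the in-process stand-in for the allgather are illustrative assumptions, not the real plugin ABI; they only mirror the shape of the calls used in federated_hist.cc.

// Sketch of the four-step flow with a mock, pass-through "encryption" plugin.
// MockPlugin and its methods are illustrative assumptions, not the real plugin ABI.
#include <cstddef>   // for size_t
#include <cstdint>   // for uint8_t
#include <iostream>  // for cout
#include <vector>    // for vector

struct MockPlugin {
  // Step 2: the plugin "encrypts" the local histogram (identity transform here).
  std::vector<std::uint8_t> BuildEncryptedHist(std::vector<double> const &hist) {
    auto const *p = reinterpret_cast<std::uint8_t const *>(hist.data());
    return {p, p + hist.size() * sizeof(double)};
  }
  // Step 4: the plugin receives the gathered buffers and returns the decrypted,
  // aggregated histogram.
  std::vector<double> SyncEncryptedHist(std::vector<std::vector<std::uint8_t>> const &gathered) {
    std::vector<double> out;
    for (auto const &buf : gathered) {
      auto const *p = reinterpret_cast<double const *>(buf.data());
      std::size_t n = buf.size() / sizeof(double);
      if (out.empty()) out.assign(n, 0.0);
      for (std::size_t i = 0; i < n; ++i) out[i] += p[i];
    }
    return out;
  }
};

int main() {
  MockPlugin plugin;
  // Step 1: each worker builds its local histogram (toy values).
  std::vector<std::vector<double>> local = {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}};
  // Step 2 on every worker, then step 3: XGBoost gathers the encrypted buffers.
  // The real allgather is a network collective; it is simulated in-process here.
  std::vector<std::vector<std::uint8_t>> gathered;
  for (auto const &hist : local) gathered.push_back(plugin.BuildEncryptedHist(hist));
  // Step 4: hand the gathered buffers back to the plugin.
  auto global_hist = plugin.SyncEncryptedHist(gathered);
  for (double v : global_hist) std::cout << v << " ";  // prints: 5 7 9
  std::cout << std::endl;
  return 0;
}

In the diff above, the horizontal path of DoSyncHistogram follows this shape: BuildEncryptedHistHori, then collective::AllgatherV, then SyncEncryptedHistHori; the vertical path uses the *Vert counterparts, split across DoBuildLocalHistograms and DoSyncHistogram.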
2 changes: 1 addition & 1 deletion src/data/gradient_index.h
@@ -7,8 +7,8 @@

#include <algorithm> // for min
#include <atomic> // for atomic
#include <cinttypes> // for uint32_t
#include <cstddef> // for size_t
#include <cstdint> // for uint32_t
#include <memory> // for make_unique
#include <vector>

4 changes: 2 additions & 2 deletions src/tree/hist/histogram.cc
@@ -1,5 +1,5 @@
/**
* Copyright 2023 by XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#include "histogram.h"

Expand All @@ -10,7 +10,7 @@

#include "../../common/transform_iterator.h" // for MakeIndexTransformIter
#include "expand_entry.h" // for MultiExpandEntry, CPUExpandEntry
#include "xgboost/logging.h" // for CHECK_NE
#include "xgboost/logging.h" // for CHECK_EQ
#include "xgboost/span.h" // for Span
#include "xgboost/tree_model.h" // for RegTree
