More refactoring to take advantage of collective aggregators #9081

Merged · 4 commits · Apr 25, 2023
Changes from all commits
8 changes: 8 additions & 0 deletions include/xgboost/data.h
@@ -196,6 +196,14 @@ class MetaInfo {
*/
bool IsVerticalFederated() const;

/*!
* \brief A convenient method to check if the MetaInfo should contain labels.
*
* Normally we assume labels are available everywhere. The only exception is in vertical federated
* learning where labels are only available on worker 0.
*/
bool ShouldHaveLabels() const;

private:
void SetInfoFromHost(Context const& ctx, StringView key, Json arr);
void SetInfoFromCUDA(Context const& ctx, StringView key, Json arr);
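The new MetaInfo::ShouldHaveLabels() helper declared above (implemented in src/data/data.cc further down) lets callers guard label-dependent work without repeating the vertical-federated/rank-0 condition. Below is a minimal sketch of the intended call pattern, modelled on the quantile_obj.cu change later in this diff; the ValidateTargets wrapper is hypothetical and only for illustration.

#include <xgboost/data.h>
#include <dmlc/logging.h>

void ValidateTargets(xgboost::MetaInfo const& info) {
  // Only workers that actually hold labels (everyone, except non-zero ranks in
  // vertical federated learning) should run label-shape checks.
  if (info.ShouldHaveLabels()) {
    CHECK_EQ(info.labels.Shape(1), 1) << "Multi-target is not yet supported by the quantile loss.";
  }
}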
10 changes: 4 additions & 6 deletions src/collective/aggregator.h
@@ -31,18 +31,16 @@ namespace collective {
* @param buffer The buffer storing the results.
* @param size The size of the buffer.
* @param function The function used to calculate the results.
* @param args Arguments to the function.
*/
template <typename Function, typename T, typename... Args>
void ApplyWithLabels(MetaInfo const& info, T* buffer, size_t size, Function&& function,
Args&&... args) {
template <typename Function>
void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&& function) {
Member:
Is type erasure necessary? (void*)

Contributor Author:
When doing the broadcast, we don't need to know the type, only the size (https://github.com/dmlc/xgboost/blob/master/src/collective/communicator-inl.h#L128), so here we don't really care about the type either.

if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the calculation is done there and result
// broadcast to other workers.
std::string message;
if (collective::GetRank() == 0) {
try {
std::forward<Function>(function)(std::forward<Args>(args)...);
std::forward<Function>(function)();
} catch (dmlc::Error& e) {
message = e.what();
}
@@ -55,7 +53,7 @@ void ApplyWithLabels(MetaInfo const& info, T* buffer, size_t size, Function&& fu
LOG(FATAL) << &message[0];
}
} else {
std::forward<Function>(function)(std::forward<Args>(args)...);
std::forward<Function>(function)();
}
}
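With the variadic arguments dropped, callers now capture whatever state they need in the lambda and hand ApplyWithLabels a type-erased result buffer plus its size in bytes. Here is a minimal sketch of the new calling convention, modelled on the adaptive.cc hunk further down; ComputeLeafQuantiles and the lambda body are hypothetical stand-ins.

#include <vector>
#include <xgboost/data.h>
#include "../collective/aggregator.h"

void ComputeLeafQuantiles(xgboost::MetaInfo const& info, std::vector<float>* p_quantiles) {
  auto& quantiles = *p_quantiles;
  // In vertical federated learning the lambda runs only on worker 0 (which holds
  // the labels); the filled buffer is then broadcast byte-wise to the other
  // workers, which is why only a void* and a byte count are needed.
  xgboost::collective::ApplyWithLabels(
      info, static_cast<void*>(quantiles.data()), quantiles.size() * sizeof(float),
      [&] { /* fill `quantiles` from info.labels */ });
}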

10 changes: 4 additions & 6 deletions src/common/hist_util.cc
@@ -45,20 +45,18 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b

if (!use_sorted) {
HostSketchContainer container(max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
HostSketchContainer::UseGroup(info),
m->Info().IsColumnSplit(), n_threads);
HostSketchContainer::UseGroup(info), n_threads);
for (auto const& page : m->GetBatches<SparsePage>()) {
container.PushRowPage(page, info, hessian);
}
container.MakeCuts(&out);
container.MakeCuts(m->Info(), &out);
} else {
SortedSketchContainer container{max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
HostSketchContainer::UseGroup(info),
m->Info().IsColumnSplit(), n_threads};
HostSketchContainer::UseGroup(info), n_threads};
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
container.PushColPage(page, info, hessian);
}
container.MakeCuts(&out);
container.MakeCuts(m->Info(), &out);
}

return out;
26 changes: 12 additions & 14 deletions src/common/quantile.cc
@@ -6,6 +6,7 @@
#include <limits>
#include <utility>

#include "../collective/aggregator.h"
#include "../collective/communicator-inl.h"
#include "../data/adapter.h"
#include "categorical.h"
@@ -18,13 +19,12 @@ template <typename WQSketch>
SketchContainerImpl<WQSketch>::SketchContainerImpl(std::vector<bst_row_t> columns_size,
int32_t max_bins,
Span<FeatureType const> feature_types,
bool use_group, bool col_split,
bool use_group,
int32_t n_threads)
: feature_types_(feature_types.cbegin(), feature_types.cend()),
columns_size_{std::move(columns_size)},
max_bins_{max_bins},
use_group_ind_{use_group},
col_split_{col_split},
n_threads_{n_threads} {
monitor_.Init(__func__);
CHECK_NE(columns_size_.size(), 0);
@@ -202,10 +202,10 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
}

template <typename WQSketch>
void SketchContainerImpl<WQSketch>::AllreduceCategories() {
void SketchContainerImpl<WQSketch>::AllreduceCategories(MetaInfo const& info) {
auto world_size = collective::GetWorldSize();
auto rank = collective::GetRank();
if (world_size == 1 || col_split_) {
if (world_size == 1 || info.IsColumnSplit()) {
return;
}

@@ -273,6 +273,7 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories() {

template <typename WQSketch>
void SketchContainerImpl<WQSketch>::AllReduce(
MetaInfo const& info,
std::vector<typename WQSketch::SummaryContainer> *p_reduced,
std::vector<int32_t>* p_num_cuts) {
monitor_.Start(__func__);
@@ -281,7 +282,7 @@ void SketchContainerImpl<WQSketch>::AllReduce(
collective::Allreduce<collective::Operation::kMax>(&n_columns, 1);
CHECK_EQ(n_columns, sketches_.size()) << "Number of columns differs across workers";

AllreduceCategories();
AllreduceCategories(info);

auto& num_cuts = *p_num_cuts;
CHECK_EQ(num_cuts.size(), 0);
@@ -292,10 +293,7 @@ void SketchContainerImpl<WQSketch>::AllReduce(

// Prune the intermediate num cuts for synchronization.
std::vector<bst_row_t> global_column_size(columns_size_);
if (!col_split_) {
collective::Allreduce<collective::Operation::kSum>(global_column_size.data(),
global_column_size.size());
}
collective::GlobalSum(info, &global_column_size);

ParallelFor(sketches_.size(), n_threads_, [&](size_t i) {
int32_t intermediate_num_cuts = static_cast<int32_t>(
@@ -316,7 +314,7 @@ void SketchContainerImpl<WQSketch>::AllReduce(
});

auto world = collective::GetWorldSize();
if (world == 1 || col_split_) {
if (world == 1 || info.IsColumnSplit()) {
monitor_.Stop(__func__);
return;
}
@@ -382,11 +380,11 @@ auto AddCategories(std::set<float> const &categories, HistogramCuts *cuts) {
}

template <typename WQSketch>
void SketchContainerImpl<WQSketch>::MakeCuts(HistogramCuts* cuts) {
void SketchContainerImpl<WQSketch>::MakeCuts(MetaInfo const& info, HistogramCuts* cuts) {
monitor_.Start(__func__);
std::vector<typename WQSketch::SummaryContainer> reduced;
std::vector<int32_t> num_cuts;
this->AllReduce(&reduced, &num_cuts);
this->AllReduce(info, &reduced, &num_cuts);

cuts->min_vals_.HostVector().resize(sketches_.size(), 0.0f);
std::vector<typename WQSketch::SummaryContainer> final_summaries(reduced.size());
@@ -443,8 +441,8 @@ template class SketchContainerImpl<WXQuantileSketch<float, float>>;

HostSketchContainer::HostSketchContainer(int32_t max_bins, common::Span<FeatureType const> ft,
std::vector<size_t> columns_size, bool use_group,
bool col_split, int32_t n_threads)
: SketchContainerImpl{columns_size, max_bins, ft, use_group, col_split, n_threads} {
int32_t n_threads)
: SketchContainerImpl{columns_size, max_bins, ft, use_group, n_threads} {
monitor_.Init(__func__);
ParallelFor(sketches_.size(), n_threads_, Sched::Auto(), [&](auto i) {
auto n_bins = std::min(static_cast<size_t>(max_bins_), columns_size_[i]);
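In the AllReduce hunk above, collective::GlobalSum replaces the explicit col_split_ branch around the Allreduce. A plausible sketch of what it does here, inferred only from the code it replaces; the real helper lives in src/collective/aggregator.h and its exact signature may differ.

#include <vector>
#include <xgboost/data.h>
#include "../collective/communicator-inl.h"

template <typename T>
void GlobalSumSketch(xgboost::MetaInfo const& info, std::vector<T>* values) {
  // When the data is split by column there is nothing to sum across workers,
  // mirroring the `col_split_` check this call replaces.
  if (info.IsColumnSplit()) {
    return;
  }
  xgboost::collective::Allreduce<xgboost::collective::Operation::kSum>(values->data(),
                                                                       values->size());
}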
16 changes: 7 additions & 9 deletions src/common/quantile.h
@@ -789,7 +789,6 @@ class SketchContainerImpl {
std::vector<bst_row_t> columns_size_;
int32_t max_bins_;
bool use_group_ind_{false};
bool col_split_;
int32_t n_threads_;
bool has_categorical_{false};
Monitor monitor_;
@@ -802,7 +801,7 @@ class SketchContainerImpl {
* \param use_group whether is assigned to group to data instance.
*/
SketchContainerImpl(std::vector<bst_row_t> columns_size, int32_t max_bins,
common::Span<FeatureType const> feature_types, bool use_group, bool col_split,
common::Span<FeatureType const> feature_types, bool use_group,
int32_t n_threads);

static bool UseGroup(MetaInfo const &info) {
@@ -829,7 +828,7 @@ class SketchContainerImpl {
std::vector<bst_row_t> *p_sketches_scan,
std::vector<typename WQSketch::Entry> *p_global_sketches);
// Merge sketches from all workers.
void AllReduce(std::vector<typename WQSketch::SummaryContainer> *p_reduced,
void AllReduce(MetaInfo const& info, std::vector<typename WQSketch::SummaryContainer> *p_reduced,
std::vector<int32_t> *p_num_cuts);

template <typename Batch, typename IsValid>
@@ -883,11 +882,11 @@ class SketchContainerImpl {
/* \brief Push a CSR matrix. */
void PushRowPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian = {});

void MakeCuts(HistogramCuts* cuts);
void MakeCuts(MetaInfo const& info, HistogramCuts* cuts);

private:
// Merge all categories from other workers.
void AllreduceCategories();
void AllreduceCategories(MetaInfo const& info);
};

class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, float>> {
@@ -896,8 +895,7 @@ class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, fl

public:
HostSketchContainer(int32_t max_bins, common::Span<FeatureType const> ft,
std::vector<size_t> columns_size, bool use_group, bool col_split,
int32_t n_threads);
std::vector<size_t> columns_size, bool use_group, int32_t n_threads);

template <typename Batch>
void PushAdapterBatch(Batch const &batch, size_t base_rowid, MetaInfo const &info, float missing);
@@ -993,9 +991,9 @@ class SortedSketchContainer : public SketchContainerImpl<WXQuantileSketch<float,

public:
explicit SortedSketchContainer(int32_t max_bins, common::Span<FeatureType const> ft,
std::vector<size_t> columns_size, bool use_group, bool col_split,
std::vector<size_t> columns_size, bool use_group,
int32_t n_threads)
: SketchContainerImpl{columns_size, max_bins, ft, use_group, col_split, n_threads} {
: SketchContainerImpl{columns_size, max_bins, ft, use_group, n_threads} {
monitor_.Init(__func__);
sketches_.resize(columns_size.size());
size_t i = 0;
4 changes: 4 additions & 0 deletions src/data/data.cc
@@ -774,6 +774,10 @@ bool MetaInfo::IsVerticalFederated() const {
return collective::IsFederated() && IsColumnSplit();
}

bool MetaInfo::ShouldHaveLabels() const {
return !IsVerticalFederated() || collective::GetRank() == 0;
}

using DMatrixThreadLocal =
dmlc::ThreadLocalStore<std::map<DMatrix const *, XGBAPIThreadLocalEntry>>;

4 changes: 2 additions & 2 deletions src/data/iterative_dmatrix.cc
@@ -213,7 +213,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
SyncFeatureType(&h_ft);
p_sketch.reset(new common::HostSketchContainer{
batch_param_.max_bin, h_ft, column_sizes, !proxy->Info().group_ptr_.empty(),
proxy->Info().IsColumnSplit(), ctx_.Threads()});
ctx_.Threads()});
}
HostAdapterDispatch(proxy, [&](auto const& batch) {
proxy->Info().num_nonzero_ = batch_nnz[i];
@@ -228,7 +228,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
CHECK_EQ(accumulated_rows, Info().num_row_);

CHECK(p_sketch);
p_sketch->MakeCuts(&cuts);
p_sketch->MakeCuts(Info(), &cuts);
}
if (!h_ft.empty()) {
CHECK_EQ(h_ft.size(), n_features);
70 changes: 33 additions & 37 deletions src/objective/adaptive.cc
@@ -99,44 +99,40 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
auto h_predt = linalg::MakeTensorView(ctx, predt.ConstHostSpan(), info.num_row_,
predt.Size() / info.num_row_);

if (!info.IsVerticalFederated() || collective::GetRank() == 0) {
// loop over each leaf
common::ParallelFor(quantiles.size(), ctx->Threads(), [&](size_t k) {
auto nidx = h_node_idx[k];
CHECK(tree[nidx].IsLeaf());
CHECK_LT(k + 1, h_node_ptr.size());
size_t n = h_node_ptr[k + 1] - h_node_ptr[k];
auto h_row_set = common::Span<size_t const>{ridx}.subspan(h_node_ptr[k], n);

auto h_labels = info.labels.HostView().Slice(linalg::All(), IdxY(info, group_idx));
auto h_weights = linalg::MakeVec(&info.weights_);

auto iter = common::MakeIndexTransformIter([&](size_t i) -> float {
auto row_idx = h_row_set[i];
return h_labels(row_idx) - h_predt(row_idx, group_idx);
collective::ApplyWithLabels(
info, static_cast<void*>(quantiles.data()), quantiles.size() * sizeof(float), [&] {
// loop over each leaf
common::ParallelFor(quantiles.size(), ctx->Threads(), [&](size_t k) {
auto nidx = h_node_idx[k];
CHECK(tree[nidx].IsLeaf());
CHECK_LT(k + 1, h_node_ptr.size());
size_t n = h_node_ptr[k + 1] - h_node_ptr[k];
auto h_row_set = common::Span<size_t const>{ridx}.subspan(h_node_ptr[k], n);

auto h_labels = info.labels.HostView().Slice(linalg::All(), IdxY(info, group_idx));
auto h_weights = linalg::MakeVec(&info.weights_);

auto iter = common::MakeIndexTransformIter([&](size_t i) -> float {
auto row_idx = h_row_set[i];
return h_labels(row_idx) - h_predt(row_idx, group_idx);
});
auto w_it = common::MakeIndexTransformIter([&](size_t i) -> float {
auto row_idx = h_row_set[i];
return h_weights(row_idx);
});

float q{0};
if (info.weights_.Empty()) {
q = common::Quantile(ctx, alpha, iter, iter + h_row_set.size());
} else {
q = common::WeightedQuantile(ctx, alpha, iter, iter + h_row_set.size(), w_it);
}
if (std::isnan(q)) {
CHECK(h_row_set.empty());
}
quantiles.at(k) = q;
});
});
auto w_it = common::MakeIndexTransformIter([&](size_t i) -> float {
auto row_idx = h_row_set[i];
return h_weights(row_idx);
});

float q{0};
if (info.weights_.Empty()) {
q = common::Quantile(ctx, alpha, iter, iter + h_row_set.size());
} else {
q = common::WeightedQuantile(ctx, alpha, iter, iter + h_row_set.size(), w_it);
}
if (std::isnan(q)) {
CHECK(h_row_set.empty());
}
quantiles.at(k) = q;
});
}

if (info.IsVerticalFederated()) {
collective::Broadcast(static_cast<void*>(quantiles.data()), quantiles.size() * sizeof(float),
0);
}

UpdateLeafValues(&quantiles, nidx, info, learning_rate, p_tree);
}
2 changes: 1 addition & 1 deletion src/objective/quantile_obj.cu
@@ -36,7 +36,7 @@ class QuantileRegression : public ObjFunction {
bst_target_t Targets(MetaInfo const& info) const override {
auto const& alpha = param_.quantile_alpha.Get();
CHECK_EQ(alpha.size(), alpha_.Size()) << "The objective is not yet configured.";
if (!info.IsVerticalFederated() || collective::GetRank() == 0) {
if (info.ShouldHaveLabels()) {
CHECK_EQ(info.labels.Shape(1), 1)
<< "Multi-target is not yet supported by the quantile loss.";
}
16 changes: 8 additions & 8 deletions tests/cpp/common/test_quantile.cc
@@ -73,7 +73,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) {
auto hess = Span<float const>{hessian};

ContainerType<use_column> sketch_distributed(n_bins, m->Info().feature_types.ConstHostSpan(),
column_size, false, false, AllThreadsForTest());
column_size, false, AllThreadsForTest());

if (use_column) {
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
@@ -86,15 +86,15 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) {
}

HistogramCuts distributed_cuts;
sketch_distributed.MakeCuts(&distributed_cuts);
sketch_distributed.MakeCuts(m->Info(), &distributed_cuts);

// Generate cuts for single node environment
collective::Finalize();
CHECK_EQ(collective::GetWorldSize(), 1);
std::for_each(column_size.begin(), column_size.end(), [=](auto& size) { size *= world; });
m->Info().num_row_ = world * rows;
ContainerType<use_column> sketch_on_single_node(n_bins, m->Info().feature_types.ConstHostSpan(),
column_size, false, false, AllThreadsForTest());
column_size, false, AllThreadsForTest());
m->Info().num_row_ = rows;

for (auto rank = 0; rank < world; ++rank) {
@@ -117,7 +117,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) {
}

HistogramCuts single_node_cuts;
sketch_on_single_node.MakeCuts(&single_node_cuts);
sketch_on_single_node.MakeCuts(m->Info(), &single_node_cuts);

auto const& sptrs = single_node_cuts.Ptrs();
auto const& dptrs = distributed_cuts.Ptrs();
@@ -205,7 +205,7 @@ void DoTestColSplitQuantile(size_t rows, size_t cols) {
HistogramCuts distributed_cuts;
{
ContainerType<use_column> sketch_distributed(n_bins, m->Info().feature_types.ConstHostSpan(),
column_size, false, true, AllThreadsForTest());
column_size, false, AllThreadsForTest());

std::vector<float> hessian(rows, 1.0);
auto hess = Span<float const>{hessian};
@@ -219,7 +219,7 @@ void DoTestColSplitQuantile(size_t rows, size_t cols) {
}
}

sketch_distributed.MakeCuts(&distributed_cuts);
sketch_distributed.MakeCuts(m->Info(), &distributed_cuts);
}

// Generate cuts for single node environment
@@ -228,7 +228,7 @@ void DoTestColSplitQuantile(size_t rows, size_t cols) {
HistogramCuts single_node_cuts;
{
ContainerType<use_column> sketch_on_single_node(n_bins, m->Info().feature_types.ConstHostSpan(),
column_size, false, false, AllThreadsForTest());
column_size, false, AllThreadsForTest());

std::vector<float> hessian(rows, 1.0);
auto hess = Span<float const>{hessian};
@@ -242,7 +242,7 @@ void DoTestColSplitQuantile(size_t rows, size_t cols) {
}
}

sketch_on_single_node.MakeCuts(&single_node_cuts);
sketch_on_single_node.MakeCuts(m->Info(), &single_node_cuts);
}

auto const& sptrs = single_node_cuts.Ptrs();