
[coll] Pass context to various functions. (#9772)
* [coll] Pass context to various functions.

In the future, the `Context` object will be required for collective operations. This PR passes the context object to some of the required functions to prepare for swapping out the implementation.
trivialfis authored Nov 8, 2023
1 parent 6c0a190 commit 06bdc15
Showing 45 changed files with 275 additions and 255 deletions.
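For orientation, the change is mechanical throughout: functions on the path to a collective call gain a leading `Context const*` parameter, which callers forward instead of reaching for global state. A minimal sketch of that calling convention, with hypothetical stand-in names rather than xgboost's real API:

```cpp
#include <iostream>

// Hypothetical stand-ins for illustration only; not xgboost's types.
struct Context {
  int device{-1};  // -1 = CPU in this sketch
};

// A future collective op that will consult the context (device, stream, ...).
void Allreduce(Context const* ctx, double* values, int n) {
  (void)ctx;  // unused for now, mirroring the preparatory nature of this PR
  (void)values;
  (void)n;
}

// Each layer accepts the context and forwards it down the call chain.
void SynchronizeNumberOfColumns(Context const* ctx, double* num_col) {
  Allreduce(ctx, num_col, 1);
}

int main() {
  Context ctx;
  double num_col = 10;
  SynchronizeNumberOfColumns(&ctx, &num_col);
  std::cout << num_col << '\n';
}
```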
2 changes: 1 addition & 1 deletion include/xgboost/data.h
@@ -178,7 +178,7 @@ class MetaInfo {
* in vertical federated learning, since each worker loads its own list of columns,
* we need to sum them.
*/
- void SynchronizeNumberOfColumns();
+ void SynchronizeNumberOfColumns(Context const* ctx);

/*! \brief Whether the data is split row-wise. */
bool IsRowSplit() const {
12 changes: 6 additions & 6 deletions include/xgboost/linalg.h
@@ -582,20 +582,20 @@ auto MakeTensorView(Context const *ctx, Container &data, S &&...shape) { // NOL
return TensorView<T, sizeof...(S)>{data, in_shape, ctx->Device()};
}

- template <typename T, typename... S>
- LINALG_HD auto MakeTensorView(DeviceOrd device, common::Span<T> data, S &&...shape) {
+ template <typename T, decltype(common::dynamic_extent) ext, typename... S>
+ LINALG_HD auto MakeTensorView(DeviceOrd device, common::Span<T, ext> data, S &&...shape) {
std::size_t in_shape[sizeof...(S)];
detail::IndexToArr(in_shape, std::forward<S>(shape)...);
return TensorView<T, sizeof...(S)>{data, in_shape, device};
}

- template <typename T, typename... S>
- auto MakeTensorView(Context const *ctx, common::Span<T> data, S &&...shape) {
+ template <typename T, decltype(common::dynamic_extent) ext, typename... S>
+ auto MakeTensorView(Context const *ctx, common::Span<T, ext> data, S &&...shape) {
return MakeTensorView(ctx->Device(), data, std::forward<S>(shape)...);
}

- template <typename T, typename... S>
- auto MakeTensorView(Context const *ctx, Order order, common::Span<T> data, S &&...shape) {
+ template <typename T, decltype(common::dynamic_extent) ext, typename... S>
+ auto MakeTensorView(Context const *ctx, Order order, common::Span<T, ext> data, S &&...shape) {
std::size_t in_shape[sizeof...(S)];
detail::IndexToArr(in_shape, std::forward<S>(shape)...);
return TensorView<T, sizeof...(S)>{data, in_shape, ctx->Device(), order};
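The `linalg.h` hunks generalize these overloads from `common::Span<T>` (dynamic extent only) to `common::Span<T, ext>`, so spans with a compile-time extent deduce directly instead of first decaying to dynamic extent. A minimal usage sketch, assuming xgboost's public headers (`xgboost/context.h`, `xgboost/linalg.h`, `xgboost/span.h`):

```cpp
#include <array>

#include "xgboost/context.h"  // Context
#include "xgboost/linalg.h"   // MakeTensorView (as changed in this diff)
#include "xgboost/span.h"     // Span

namespace xgboost {
void Example(Context const* ctx) {
  std::array<float, 6> storage{0, 1, 2, 3, 4, 5};
  // A fixed-extent span: with the old signature only Span<T> (dynamic
  // extent) matched, so this argument needed a conversion; now the extent
  // is deduced as the new template parameter.
  common::Span<float, 6> fixed{storage.data(), storage.size()};
  auto view = linalg::MakeTensorView(ctx, fixed, 2, 3);  // 2x3 view
  (void)view;
}
}  // namespace xgboost
```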
6 changes: 3 additions & 3 deletions plugin/federated/federated_coll.cc
@@ -29,7 +29,7 @@ namespace {
auto stub = fed->Handle();

BroadcastRequest request;
- request.set_sequence_number(*sequence_number++);
+ request.set_sequence_number((*sequence_number)++);
request.set_rank(comm.Rank());
if (comm.Rank() != root) {
request.set_send_buffer(nullptr, 0);
@@ -90,9 +90,9 @@ Coll *FederatedColl::MakeCUDAVar() {
[[nodiscard]] Result FederatedColl::Broadcast(Comm const &comm, common::Span<std::int8_t> data,
std::int32_t root) {
if (comm.Rank() == root) {
- return BroadcastImpl(comm, &sequence_number_, data, root);
+ return BroadcastImpl(comm, &this->sequence_number_, data, root);
} else {
- return BroadcastImpl(comm, &sequence_number_, data, root);
+ return BroadcastImpl(comm, &this->sequence_number_, data, root);
}
}

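The `federated_coll.cc` fix is an operator-precedence correction worth spelling out: postfix `++` binds tighter than unary `*`, so `*sequence_number++` advanced the pointer and returned its old target, while the intent is to bump the sequence counter itself. A self-contained illustration:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  std::uint64_t seq[2] = {7, 8};
  std::uint64_t* p = seq;

  // Postfix ++ applies to the pointer: read seq[0], then advance p.
  std::uint64_t got = *p++;
  assert(got == 7 && p == seq + 1 && seq[0] == 7);

  p = seq;
  // Parenthesised form increments the value p points to.
  std::uint64_t got2 = (*p)++;
  assert(got2 == 7 && seq[0] == 8 && p == seq);
  return 0;
}
```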
3 changes: 3 additions & 0 deletions src/collective/allreduce.cc
@@ -62,6 +62,9 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,

Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func const& op,
ArrayInterfaceHandler::Type type) {
+ if (comm.World() == 1) {
+   return Success();
+ }
return DispatchDType(type, [&](auto t) {
using T = decltype(t);
// Divide the data into segments according to the number of workers.
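Two things happen in the `allreduce.cc` hunk: a single-worker early exit (an allreduce over one participant is the identity, so the ring algorithm can be skipped entirely), and dispatch from a runtime dtype to a compile-time type. The dispatch idiom, sketched with illustrative names rather than xgboost's `DispatchDType`:

```cpp
#include <cstdint>
#include <iostream>

enum class DType { kF32, kI64 };

// Map a runtime enum to a compile-time type by calling a generic
// lambda with a value-initialised tag of the selected type.
template <typename Fn>
void Dispatch(DType t, Fn&& fn) {
  switch (t) {
    case DType::kF32: fn(float{}); break;
    case DType::kI64: fn(std::int64_t{}); break;
  }
}

int main() {
  Dispatch(DType::kI64, [](auto tag) {
    using T = decltype(tag);         // T is std::int64_t here
    std::cout << sizeof(T) << '\n';  // prints 8
  });
}
```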
3 changes: 2 additions & 1 deletion src/collective/comm.cu
@@ -10,6 +10,7 @@
#include <sstream> // for stringstream
#include <vector> // for vector

#include "../common/cuda_context.cuh" // for CUDAContext
#include "../common/device_helpers.cuh" // for DefaultStream
#include "../common/type.h" // for EraseType
#include "broadcast.h" // for Broadcast
@@ -60,7 +61,7 @@ Comm* Comm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const {
NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl)
: Comm{root.TrackerInfo().host, root.TrackerInfo().port, root.Timeout(), root.Retry(),
root.TaskID()},
- stream_{dh::DefaultStream()} {
+ stream_{ctx->CUDACtx()->Stream()} {
this->world_ = root.World();
this->rank_ = root.Rank();
this->domain_ = root.Domain();
2 changes: 1 addition & 1 deletion src/collective/comm_group.cc
@@ -105,7 +105,7 @@ CommGroup::CommGroup()
}

std::unique_ptr<collective::CommGroup>& GlobalCommGroup() {
- static std::unique_ptr<collective::CommGroup> sptr;
+ static thread_local std::unique_ptr<collective::CommGroup> sptr;
if (!sptr) {
Json config{Null{}};
sptr.reset(CommGroup::Create(config));
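The `GlobalCommGroup()` change swaps a process-wide function-local static for a `thread_local` one, so each thread lazily builds and owns its own `CommGroup`. The semantics, shown on a toy counter (an illustrative stand-in, not xgboost code):

```cpp
#include <iostream>
#include <memory>
#include <thread>

// Each calling thread gets its own lazily-initialised instance.
int& Counter() {
  static thread_local std::unique_ptr<int> sptr;
  if (!sptr) {
    sptr = std::make_unique<int>(0);
  }
  return *sptr;
}

int main() {
  Counter() = 1;
  std::thread t([] {
    std::cout << Counter() << '\n';  // prints 0: a fresh per-thread instance
  });
  t.join();
  std::cout << Counter() << '\n';    // prints 1: the main thread's instance
}
```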
3 changes: 2 additions & 1 deletion src/common/device_helpers.cuh
@@ -480,7 +480,8 @@ struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
cub::CachingDeviceAllocator& GetGlobalCachingAllocator() {
// Configure allocator with maximum cached bin size of ~1GB and no limit on
// maximum cached bytes
- thread_local cub::CachingDeviceAllocator *allocator = new cub::CachingDeviceAllocator(2, 9, 29);
+ thread_local std::unique_ptr<cub::CachingDeviceAllocator> allocator{
+     std::make_unique<cub::CachingDeviceAllocator>(2, 9, 29)};
return *allocator;
}
pointer allocate(size_t n) { // NOLINT
4 changes: 2 additions & 2 deletions src/common/hist_util.cc
@@ -51,7 +51,7 @@ HistogramCuts SketchOnDMatrix(Context const *ctx, DMatrix *m, bst_bin_t max_bins
for (auto const &page : m->GetBatches<SparsePage>()) {
container.PushRowPage(page, info, hessian);
}
- container.MakeCuts(m->Info(), &out);
+ container.MakeCuts(ctx, m->Info(), &out);
} else {
SortedSketchContainer container{ctx,
max_bins,
@@ -61,7 +61,7 @@ HistogramCuts SketchOnDMatrix(Context const *ctx, DMatrix *m, bst_bin_t max_bins
for (auto const &page : m->GetBatches<SortedCSCPage>(ctx)) {
container.PushColPage(page, info, hessian);
}
- container.MakeCuts(m->Info(), &out);
+ container.MakeCuts(ctx, m->Info(), &out);
}

return out;
2 changes: 1 addition & 1 deletion src/common/hist_util.cu
@@ -359,7 +359,7 @@ HistogramCuts DeviceSketchWithHessian(Context const* ctx, DMatrix* p_fmat, bst_b
}
}

- sketch_container.MakeCuts(&cuts, p_fmat->Info().IsColumnSplit());
+ sketch_container.MakeCuts(ctx, &cuts, p_fmat->Info().IsColumnSplit());
return cuts;
}
} // namespace xgboost::common
27 changes: 12 additions & 15 deletions src/common/quantile.cc
@@ -11,9 +11,7 @@
#include "categorical.h"
#include "hist_util.h"

- namespace xgboost {
- namespace common {
-
+ namespace xgboost::common {
template <typename WQSketch>
SketchContainerImpl<WQSketch>::SketchContainerImpl(Context const *ctx,
std::vector<bst_row_t> columns_size,
@@ -129,7 +127,7 @@ struct QuantileAllreduce {
* \param rank rank of target worker
* \param fidx feature idx
*/
- auto Values(int32_t rank, bst_feature_t fidx) const {
+ [[nodiscard]] auto Values(int32_t rank, bst_feature_t fidx) const {
// get span for worker
auto wsize = worker_indptr[rank + 1] - worker_indptr[rank];
auto worker_values = global_values.subspan(worker_indptr[rank], wsize);
@@ -145,7 +143,7 @@

template <typename WQSketch>
void SketchContainerImpl<WQSketch>::GatherSketchInfo(
- MetaInfo const& info,
+ Context const *, MetaInfo const &info,
std::vector<typename WQSketch::SummaryContainer> const &reduced,
std::vector<size_t> *p_worker_segments, std::vector<bst_row_t> *p_sketches_scan,
std::vector<typename WQSketch::Entry> *p_global_sketches) {
@@ -206,7 +204,7 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
}

template <typename WQSketch>
- void SketchContainerImpl<WQSketch>::AllreduceCategories(MetaInfo const& info) {
+ void SketchContainerImpl<WQSketch>::AllreduceCategories(Context const*, MetaInfo const& info) {
auto world_size = collective::GetWorldSize();
auto rank = collective::GetRank();
if (world_size == 1 || info.IsColumnSplit()) {
@@ -274,16 +272,15 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(MetaInfo const& info) {

template <typename WQSketch>
void SketchContainerImpl<WQSketch>::AllReduce(
- MetaInfo const& info,
- std::vector<typename WQSketch::SummaryContainer> *p_reduced,
- std::vector<int32_t>* p_num_cuts) {
+ Context const *ctx, MetaInfo const &info,
+ std::vector<typename WQSketch::SummaryContainer> *p_reduced, std::vector<int32_t> *p_num_cuts) {
monitor_.Start(__func__);

size_t n_columns = sketches_.size();
collective::Allreduce<collective::Operation::kMax>(&n_columns, 1);
CHECK_EQ(n_columns, sketches_.size()) << "Number of columns differs across workers";

- AllreduceCategories(info);
+ AllreduceCategories(ctx, info);

auto& num_cuts = *p_num_cuts;
CHECK_EQ(num_cuts.size(), 0);
@@ -324,7 +321,7 @@ void SketchContainerImpl<WQSketch>::AllReduce(
std::vector<bst_row_t> sketches_scan((n_columns + 1) * world, 0);

std::vector<typename WQSketch::Entry> global_sketches;
- this->GatherSketchInfo(info, reduced, &worker_segments, &sketches_scan, &global_sketches);
+ this->GatherSketchInfo(ctx, info, reduced, &worker_segments, &sketches_scan, &global_sketches);

std::vector<typename WQSketch::SummaryContainer> final_sketches(n_columns);

@@ -383,11 +380,12 @@ auto AddCategories(std::set<float> const &categories, HistogramCuts *cuts) {
}

template <typename WQSketch>
- void SketchContainerImpl<WQSketch>::MakeCuts(MetaInfo const &info, HistogramCuts *p_cuts) {
+ void SketchContainerImpl<WQSketch>::MakeCuts(Context const *ctx, MetaInfo const &info,
+                                              HistogramCuts *p_cuts) {
monitor_.Start(__func__);
std::vector<typename WQSketch::SummaryContainer> reduced;
std::vector<int32_t> num_cuts;
- this->AllReduce(info, &reduced, &num_cuts);
+ this->AllReduce(ctx, info, &reduced, &num_cuts);

p_cuts->min_vals_.HostVector().resize(sketches_.size(), 0.0f);
std::vector<typename WQSketch::SummaryContainer> final_summaries(reduced.size());
@@ -496,5 +494,4 @@ void SortedSketchContainer::PushColPage(SparsePage const &page, MetaInfo const &
});
monitor_.Stop(__func__);
}
- } // namespace common
- } // namespace xgboost
+ } // namespace xgboost::common
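The quantile machinery above gathers every worker's sketch entries into one flat buffer and addresses it CSR-style: `worker_indptr` holds per-worker offsets, and `QuantileAllreduce::Values` takes a subspan for the worker before indexing by feature. A one-level version of that lookup, with illustrative data (xgboost adds a second, per-feature level inside each worker's slice):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Entries from all workers, back to back: worker 0 owns {1,2,3},
  // worker 1 owns {4,5}. worker_indptr[r]..worker_indptr[r+1] is r's slice.
  std::vector<float> global_values{1, 2, 3, 4, 5};
  std::vector<std::size_t> worker_indptr{0, 3, 5};

  int rank = 1;
  std::size_t begin = worker_indptr[rank];
  std::size_t size = worker_indptr[rank + 1] - worker_indptr[rank];
  for (std::size_t i = 0; i < size; ++i) {
    std::cout << global_values[begin + i] << ' ';  // prints: 4 5
  }
  std::cout << '\n';
}
```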
13 changes: 5 additions & 8 deletions src/common/quantile.cu
@@ -22,9 +22,7 @@
#include "transform_iterator.h" // MakeIndexTransformIter
#include "xgboost/span.h"

- namespace xgboost {
- namespace common {
-
+ namespace xgboost::common {
using WQSketch = HostSketchContainer::WQSketch;
using SketchEntry = WQSketch::Entry;

@@ -501,7 +499,7 @@ void SketchContainer::FixError() {
});
}

- void SketchContainer::AllReduce(bool is_column_split) {
+ void SketchContainer::AllReduce(Context const*, bool is_column_split) {
dh::safe_cuda(cudaSetDevice(device_.ordinal));
auto world = collective::GetWorldSize();
if (world == 1 || is_column_split) {
@@ -582,13 +580,13 @@ struct InvalidCatOp {
};
} // anonymous namespace

- void SketchContainer::MakeCuts(HistogramCuts* p_cuts, bool is_column_split) {
+ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool is_column_split) {
timer_.Start(__func__);
dh::safe_cuda(cudaSetDevice(device_.ordinal));
p_cuts->min_vals_.Resize(num_columns_);

// Sync between workers.
- this->AllReduce(is_column_split);
+ this->AllReduce(ctx, is_column_split);

// Prune to final number of bins.
this->Prune(num_bins_ + 1);
@@ -731,5 +729,4 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts, bool is_column_split) {
p_cuts->SetCategorical(this->has_categorical_, max_cat);
timer_.Stop(__func__);
}
- } // namespace common
- } // namespace xgboost
+ } // namespace xgboost::common
4 changes: 2 additions & 2 deletions src/common/quantile.cuh
@@ -151,9 +151,9 @@ class SketchContainer {
Span<SketchEntry const> that);

/* \brief Merge quantiles from other GPU workers. */
- void AllReduce(bool is_column_split);
+ void AllReduce(Context const* ctx, bool is_column_split);
/* \brief Create the final histogram cut values. */
- void MakeCuts(HistogramCuts* cuts, bool is_column_split);
+ void MakeCuts(Context const* ctx, HistogramCuts* cuts, bool is_column_split);

Span<SketchEntry const> Data() const {
return {this->Current().data().get(), this->Current().size()};
9 changes: 5 additions & 4 deletions src/common/quantile.h
@@ -827,13 +827,14 @@ class SketchContainerImpl {
return group_ind;
}
// Gather sketches from all workers.
- void GatherSketchInfo(MetaInfo const& info,
+ void GatherSketchInfo(Context const *ctx, MetaInfo const &info,
std::vector<typename WQSketch::SummaryContainer> const &reduced,
std::vector<bst_row_t> *p_worker_segments,
std::vector<bst_row_t> *p_sketches_scan,
std::vector<typename WQSketch::Entry> *p_global_sketches);
// Merge sketches from all workers.
- void AllReduce(MetaInfo const& info, std::vector<typename WQSketch::SummaryContainer> *p_reduced,
+ void AllReduce(Context const *ctx, MetaInfo const &info,
+                std::vector<typename WQSketch::SummaryContainer> *p_reduced,
std::vector<int32_t> *p_num_cuts);

template <typename Batch, typename IsValid>
@@ -887,11 +888,11 @@ class SketchContainerImpl {
/* \brief Push a CSR matrix. */
void PushRowPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian = {});

- void MakeCuts(MetaInfo const& info, HistogramCuts* cuts);
+ void MakeCuts(Context const *ctx, MetaInfo const &info, HistogramCuts *cuts);

private:
// Merge all categories from other workers.
- void AllreduceCategories(MetaInfo const& info);
+ void AllreduceCategories(Context const* ctx, MetaInfo const& info);
};

class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, float>> {
2 changes: 1 addition & 1 deletion src/data/data.cc
@@ -745,7 +745,7 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
}
}

- void MetaInfo::SynchronizeNumberOfColumns() {
+ void MetaInfo::SynchronizeNumberOfColumns(Context const*) {
if (IsColumnSplit()) {
collective::Allreduce<collective::Operation::kSum>(&num_col_, 1);
} else {
12 changes: 6 additions & 6 deletions src/data/iterative_dmatrix.cc
@@ -95,7 +95,7 @@ void GetCutsFromRef(Context const* ctx, std::shared_ptr<DMatrix> ref, bst_featur

namespace {
// Synchronize feature type in case of empty DMatrix
- void SyncFeatureType(std::vector<FeatureType>* p_h_ft) {
+ void SyncFeatureType(Context const*, std::vector<FeatureType>* p_h_ft) {
if (!collective::IsDistributed()) {
return;
}
Expand Down Expand Up @@ -193,7 +193,7 @@ void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p,
// From here on Info() has the correct data shape
Info().num_row_ = accumulated_rows;
Info().num_nonzero_ = nnz;
- Info().SynchronizeNumberOfColumns();
+ Info().SynchronizeNumberOfColumns(ctx);
CHECK(std::none_of(column_sizes.cbegin(), column_sizes.cend(), [&](auto f) {
return f > accumulated_rows;
})) << "Something went wrong during iteration.";
@@ -213,9 +213,9 @@
while (iter.Next()) {
if (!p_sketch) {
h_ft = proxy->Info().feature_types.ConstHostVector();
- SyncFeatureType(&h_ft);
- p_sketch.reset(new common::HostSketchContainer{ctx, p.max_bin, h_ft, column_sizes,
-                                                !proxy->Info().group_ptr_.empty()});
+ SyncFeatureType(ctx, &h_ft);
+ p_sketch = std::make_unique<common::HostSketchContainer>(ctx, p.max_bin, h_ft, column_sizes,
+                                                          !proxy->Info().group_ptr_.empty());
}
HostAdapterDispatch(proxy, [&](auto const& batch) {
proxy->Info().num_nonzero_ = batch_nnz[i];
@@ -230,7 +230,7 @@
CHECK_EQ(accumulated_rows, Info().num_row_);

CHECK(p_sketch);
- p_sketch->MakeCuts(Info(), &cuts);
+ p_sketch->MakeCuts(ctx, Info(), &cuts);
}
if (!h_ft.empty()) {
CHECK_EQ(h_ft.size(), n_features);
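Alongside the context threading, `InitFromCPU` replaces `p_sketch.reset(new ...{...})` with `std::make_unique`. Note the brace-init became paren-init: `make_unique` forwards its arguments into a parenthesised constructor call. A minimal before/after on a toy type (not xgboost's):

```cpp
#include <memory>
#include <vector>

struct Sketcher {
  Sketcher(int max_bin, bool has_groups) : bins(max_bin), grouped{has_groups} {}
  std::vector<int> bins;
  bool grouped;
};

int main() {
  std::unique_ptr<Sketcher> p;
  p.reset(new Sketcher{32, false});          // old style: bare new, brace-init
  p = std::make_unique<Sketcher>(32, true);  // new style: paren-init, no bare new
}
```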
4 changes: 2 additions & 2 deletions src/data/iterative_dmatrix.cu
@@ -105,7 +105,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
sketch_containers.clear();
sketch_containers.shrink_to_fit();

- final_sketch.MakeCuts(&cuts, this->info_.IsColumnSplit());
+ final_sketch.MakeCuts(ctx, &cuts, this->info_.IsColumnSplit());
} else {
GetCutsFromRef(ctx, ref, Info().num_col_, p, &cuts);
}
@@ -167,7 +167,7 @@

iter.Reset();
// Synchronise worker columns
- info_.SynchronizeNumberOfColumns();
+ info_.SynchronizeNumberOfColumns(ctx);
}

BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(Context const* ctx,
2 changes: 1 addition & 1 deletion src/data/simple_dmatrix.cc
@@ -283,7 +283,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread,
// Synchronise worker columns
info_.data_split_mode = data_split_mode;
ReindexFeatures(&ctx);
- info_.SynchronizeNumberOfColumns();
+ info_.SynchronizeNumberOfColumns(&ctx);

if (adapter->NumRows() == kAdapterUnknownSize) {
using IteratorAdapterT =
2 changes: 1 addition & 1 deletion src/data/simple_dmatrix.cu
@@ -42,7 +42,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, std::int32_t nthr
info_.num_row_ = adapter->NumRows();
// Synchronise worker columns
info_.data_split_mode = data_split_mode;
- info_.SynchronizeNumberOfColumns();
+ info_.SynchronizeNumberOfColumns(&ctx);

this->fmat_ctx_ = ctx;
}
2 changes: 1 addition & 1 deletion src/data/sparse_page_dmatrix.cc
@@ -97,7 +97,7 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
this->info_.num_col_ = n_features;
this->info_.num_nonzero_ = nnz;

- info_.SynchronizeNumberOfColumns();
+ info_.SynchronizeNumberOfColumns(&ctx);
CHECK_NE(info_.num_col_, 0);

fmat_ctx_ = ctx;