[coll] Pass context to various functions. #9772

Merged · 2 commits · Nov 8, 2023
2 changes: 1 addition & 1 deletion include/xgboost/data.h
@@ -178,7 +178,7 @@ class MetaInfo {
* in vertical federated learning, since each worker loads its own list of columns,
* we need to sum them.
*/
-void SynchronizeNumberOfColumns();
+void SynchronizeNumberOfColumns(Context const* ctx);

/*! \brief Whether the data is split row-wise. */
bool IsRowSplit() const {
12 changes: 6 additions & 6 deletions include/xgboost/linalg.h
@@ -582,20 +582,20 @@ auto MakeTensorView(Context const *ctx, Container &data, S &&...shape) { // NOL
return TensorView<T, sizeof...(S)>{data, in_shape, ctx->Device()};
}

-template <typename T, typename... S>
-LINALG_HD auto MakeTensorView(DeviceOrd device, common::Span<T> data, S &&...shape) {
+template <typename T, decltype(common::dynamic_extent) ext, typename... S>
+LINALG_HD auto MakeTensorView(DeviceOrd device, common::Span<T, ext> data, S &&...shape) {
std::size_t in_shape[sizeof...(S)];
detail::IndexToArr(in_shape, std::forward<S>(shape)...);
return TensorView<T, sizeof...(S)>{data, in_shape, device};
}

-template <typename T, typename... S>
-auto MakeTensorView(Context const *ctx, common::Span<T> data, S &&...shape) {
+template <typename T, decltype(common::dynamic_extent) ext, typename... S>
+auto MakeTensorView(Context const *ctx, common::Span<T, ext> data, S &&...shape) {
return MakeTensorView(ctx->Device(), data, std::forward<S>(shape)...);
}

-template <typename T, typename... S>
-auto MakeTensorView(Context const *ctx, Order order, common::Span<T> data, S &&...shape) {
+template <typename T, decltype(common::dynamic_extent) ext, typename... S>
+auto MakeTensorView(Context const *ctx, Order order, common::Span<T, ext> data, S &&...shape) {
std::size_t in_shape[sizeof...(S)];
detail::IndexToArr(in_shape, std::forward<S>(shape)...);
return TensorView<T, sizeof...(S)>{data, in_shape, ctx->Device(), order};
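
These overloads previously accepted only `common::Span<T>` with its default dynamic extent; by deducing the extent as a template parameter, spans that carry a compile-time extent now bind directly. A minimal usage sketch (the buffer, the `Demo` function, and the shape are hypothetical, assuming the public xgboost headers):

```cpp
#include <xgboost/context.h>
#include <xgboost/linalg.h>
#include <xgboost/span.h>

namespace xgboost {
void Demo() {
  float buf[6] = {0, 1, 2, 3, 4, 5};
  // A span with a static extent of 6. Before this change it had to decay to
  // Span<float> (dynamic extent) before MakeTensorView would accept it.
  common::Span<float, 6> fixed{buf};
  // Deduces T = float, ext = 6, and yields a 2x3 TensorView on the CPU.
  auto view = linalg::MakeTensorView(DeviceOrd::CPU(), fixed, 2, 3);
  (void)view;
}
}  // namespace xgboost
```
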
6 changes: 3 additions & 3 deletions plugin/federated/federated_coll.cc
@@ -29,7 +29,7 @@ namespace {
auto stub = fed->Handle();

BroadcastRequest request;
-request.set_sequence_number(*sequence_number++);
+request.set_sequence_number((*sequence_number)++);
request.set_rank(comm.Rank());
if (comm.Rank() != root) {
request.set_send_buffer(nullptr, 0);
@@ -90,9 +90,9 @@ Coll *FederatedColl::MakeCUDAVar() {
[[nodiscard]] Result FederatedColl::Broadcast(Comm const &comm, common::Span<std::int8_t> data,
std::int32_t root) {
if (comm.Rank() == root) {
-return BroadcastImpl(comm, &sequence_number_, data, root);
+return BroadcastImpl(comm, &this->sequence_number_, data, root);
} else {
-return BroadcastImpl(comm, &sequence_number_, data, root);
+return BroadcastImpl(comm, &this->sequence_number_, data, root);
}
}

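
The first hunk is the substantive fix: postfix `++` binds tighter than unary `*`, so the old `*sequence_number++` dereferenced the pointer and then advanced the pointer itself, never incrementing the underlying counter. A standalone illustration of the difference:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  std::uint64_t seq = 41;
  std::uint64_t* p = &seq;
  // (*p)++ uses the current value and then increments the pointed-to counter.
  std::uint64_t sent = (*p)++;
  assert(sent == 41);
  assert(seq == 42);  // with *p++ the counter would still be 41
  return 0;
}
```
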
3 changes: 3 additions & 0 deletions src/collective/allreduce.cc
@@ -62,6 +62,9 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,

Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func const& op,
ArrayInterfaceHandler::Type type) {
+if (comm.World() == 1) {
+return Success();
+}
return DispatchDType(type, [&](auto t) {
using T = decltype(t);
// Divide the data into segments according to the number of workers.
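
The new guard is a degenerate-case short-circuit: with one worker the local buffer is already the full reduction, and the ring logic below it would otherwise split the buffer into per-worker segments and exchange them with non-existent neighbours. A toy stand-in for the control flow (not the actual xgboost API):

```cpp
#include <vector>

// Toy model of the guard: an allreduce over a single worker must return the
// input unchanged, so it can skip the segment exchange entirely.
std::vector<int> ToyAllreduce(std::vector<int> data, int world_size) {
  if (world_size == 1) {
    return data;  // mirrors `return Success();` before the ring exchange
  }
  // ... scatter-reduce around the ring, then allgather, would go here ...
  return data;
}
```
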
3 changes: 2 additions & 1 deletion src/collective/comm.cu
@@ -10,6 +10,7 @@
#include <sstream> // for stringstream
#include <vector> // for vector

#include "../common/cuda_context.cuh" // for CUDAContext
#include "../common/device_helpers.cuh" // for DefaultStream
#include "../common/type.h" // for EraseType
#include "broadcast.h" // for Broadcast
@@ -60,7 +61,7 @@ Comm* Comm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const {
NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl)
: Comm{root.TrackerInfo().host, root.TrackerInfo().port, root.Timeout(), root.Retry(),
root.TaskID()},
-stream_{dh::DefaultStream()} {
+stream_{ctx->CUDACtx()->Stream()} {
this->world_ = root.World();
this->rank_ = root.Rank();
this->domain_ = root.Domain();
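
Switching `NCCLComm` from the default stream to the context's stream matters because NCCL collectives execute asynchronously on whichever stream they are launched on; taking the stream from the `Context` keeps them ordered with the rest of that context's GPU work. A hedged sketch of a launch site (buffer names and the communicator handle are hypothetical; assumes the standard NCCL C API):

```cpp
// ncclAllReduce enqueues on the given stream and returns immediately, so the
// stream choice decides what the collective synchronizes with.
ncclAllReduce(send_buf, recv_buf, count, ncclFloat, ncclSum, nccl_comm,
              /*stream=*/ctx->CUDACtx()->Stream());
```
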
2 changes: 1 addition & 1 deletion src/collective/comm_group.cc
@@ -105,7 +105,7 @@ CommGroup::CommGroup()
}

std::unique_ptr<collective::CommGroup>& GlobalCommGroup() {
-static std::unique_ptr<collective::CommGroup> sptr;
+static thread_local std::unique_ptr<collective::CommGroup> sptr;
if (!sptr) {
Json config{Null{}};
sptr.reset(CommGroup::Create(config));
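
Making the cached `CommGroup` `thread_local` gives each thread its own lazily constructed communicator instead of one instance shared by all threads. The pattern in isolation, with `Widget` as a stand-in for `CommGroup`:

```cpp
#include <memory>

struct Widget {};  // stand-in for collective::CommGroup

Widget& PerThreadInstance() {
  // One instance per thread, constructed on that thread's first call; no
  // locking is needed because nothing is shared across threads.
  thread_local std::unique_ptr<Widget> instance;
  if (!instance) {
    instance = std::make_unique<Widget>();
  }
  return *instance;
}
```
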
3 changes: 2 additions & 1 deletion src/common/device_helpers.cuh
@@ -480,7 +480,8 @@ struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
cub::CachingDeviceAllocator& GetGlobalCachingAllocator() {
// Configure allocator with maximum cached bin size of ~1GB and no limit on
// maximum cached bytes
-thread_local cub::CachingDeviceAllocator *allocator = new cub::CachingDeviceAllocator(2, 9, 29);
+thread_local std::unique_ptr<cub::CachingDeviceAllocator> allocator{
+std::make_unique<cub::CachingDeviceAllocator>(2, 9, 29)};
return *allocator;
}
pointer allocate(size_t n) { // NOLINT
4 changes: 2 additions & 2 deletions src/common/hist_util.cc
@@ -51,7 +51,7 @@ HistogramCuts SketchOnDMatrix(Context const *ctx, DMatrix *m, bst_bin_t max_bins
for (auto const &page : m->GetBatches<SparsePage>()) {
container.PushRowPage(page, info, hessian);
}
-container.MakeCuts(m->Info(), &out);
+container.MakeCuts(ctx, m->Info(), &out);
} else {
SortedSketchContainer container{ctx,
max_bins,
@@ -61,7 +61,7 @@
for (auto const &page : m->GetBatches<SortedCSCPage>(ctx)) {
container.PushColPage(page, info, hessian);
}
-container.MakeCuts(m->Info(), &out);
+container.MakeCuts(ctx, m->Info(), &out);
}

return out;
2 changes: 1 addition & 1 deletion src/common/hist_util.cu
@@ -359,7 +359,7 @@ HistogramCuts DeviceSketchWithHessian(Context const* ctx, DMatrix* p_fmat, bst_b
}
}

-sketch_container.MakeCuts(&cuts, p_fmat->Info().IsColumnSplit());
+sketch_container.MakeCuts(ctx, &cuts, p_fmat->Info().IsColumnSplit());
return cuts;
}
} // namespace xgboost::common
27 changes: 12 additions & 15 deletions src/common/quantile.cc
@@ -11,9 +11,7 @@
#include "categorical.h"
#include "hist_util.h"

-namespace xgboost {
-namespace common {
-
+namespace xgboost::common {
template <typename WQSketch>
SketchContainerImpl<WQSketch>::SketchContainerImpl(Context const *ctx,
std::vector<bst_row_t> columns_size,
@@ -129,7 +127,7 @@ struct QuantileAllreduce {
* \param rank rank of target worker
* \param fidx feature idx
*/
-auto Values(int32_t rank, bst_feature_t fidx) const {
+[[nodiscard]] auto Values(int32_t rank, bst_feature_t fidx) const {
// get span for worker
auto wsize = worker_indptr[rank + 1] - worker_indptr[rank];
auto worker_values = global_values.subspan(worker_indptr[rank], wsize);
@@ -145,7 +143,7 @@

template <typename WQSketch>
void SketchContainerImpl<WQSketch>::GatherSketchInfo(
-MetaInfo const& info,
+Context const *, MetaInfo const &info,
std::vector<typename WQSketch::SummaryContainer> const &reduced,
std::vector<size_t> *p_worker_segments, std::vector<bst_row_t> *p_sketches_scan,
std::vector<typename WQSketch::Entry> *p_global_sketches) {
@@ -206,7 +204,7 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
}

template <typename WQSketch>
-void SketchContainerImpl<WQSketch>::AllreduceCategories(MetaInfo const& info) {
+void SketchContainerImpl<WQSketch>::AllreduceCategories(Context const*, MetaInfo const& info) {
auto world_size = collective::GetWorldSize();
auto rank = collective::GetRank();
if (world_size == 1 || info.IsColumnSplit()) {
@@ -274,16 +272,15 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(MetaInfo const& info) {

template <typename WQSketch>
void SketchContainerImpl<WQSketch>::AllReduce(
-MetaInfo const& info,
-std::vector<typename WQSketch::SummaryContainer> *p_reduced,
-std::vector<int32_t>* p_num_cuts) {
+Context const *ctx, MetaInfo const &info,
+std::vector<typename WQSketch::SummaryContainer> *p_reduced, std::vector<int32_t> *p_num_cuts) {
monitor_.Start(__func__);

size_t n_columns = sketches_.size();
collective::Allreduce<collective::Operation::kMax>(&n_columns, 1);
CHECK_EQ(n_columns, sketches_.size()) << "Number of columns differs across workers";

-AllreduceCategories(info);
+AllreduceCategories(ctx, info);

auto& num_cuts = *p_num_cuts;
CHECK_EQ(num_cuts.size(), 0);
@@ -324,7 +321,7 @@ void SketchContainerImpl<WQSketch>::AllReduce(
std::vector<bst_row_t> sketches_scan((n_columns + 1) * world, 0);

std::vector<typename WQSketch::Entry> global_sketches;
-this->GatherSketchInfo(info, reduced, &worker_segments, &sketches_scan, &global_sketches);
+this->GatherSketchInfo(ctx, info, reduced, &worker_segments, &sketches_scan, &global_sketches);

std::vector<typename WQSketch::SummaryContainer> final_sketches(n_columns);

@@ -383,11 +380,12 @@ auto AddCategories(std::set<float> const &categories, HistogramCuts *cuts) {
}

template <typename WQSketch>
-void SketchContainerImpl<WQSketch>::MakeCuts(MetaInfo const &info, HistogramCuts *p_cuts) {
+void SketchContainerImpl<WQSketch>::MakeCuts(Context const *ctx, MetaInfo const &info,
+HistogramCuts *p_cuts) {
monitor_.Start(__func__);
std::vector<typename WQSketch::SummaryContainer> reduced;
std::vector<int32_t> num_cuts;
-this->AllReduce(info, &reduced, &num_cuts);
+this->AllReduce(ctx, info, &reduced, &num_cuts);

p_cuts->min_vals_.HostVector().resize(sketches_.size(), 0.0f);
std::vector<typename WQSketch::SummaryContainer> final_summaries(reduced.size());
@@ -496,5 +494,4 @@ void SortedSketchContainer::PushColPage(SparsePage const &page, MetaInfo const &
});
monitor_.Stop(__func__);
}
-} // namespace common
-} // namespace xgboost
+} // namespace xgboost::common
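
The namespace edits at the top and bottom of this file (and the matching ones in quantile.cu below) use the C++17 nested namespace definition, which collapses the open/close pairs without changing any enclosed name:

```cpp
// Pre-C++17: nested namespaces are opened and closed one level at a time.
namespace xgboost { namespace common { /* ... */ } }
// C++17 equivalent, as used in this patch.
namespace xgboost::common { /* ... */ }
```
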
13 changes: 5 additions & 8 deletions src/common/quantile.cu
@@ -22,9 +22,7 @@
#include "transform_iterator.h" // MakeIndexTransformIter
#include "xgboost/span.h"

-namespace xgboost {
-namespace common {
-
+namespace xgboost::common {
using WQSketch = HostSketchContainer::WQSketch;
using SketchEntry = WQSketch::Entry;

@@ -501,7 +499,7 @@ void SketchContainer::FixError() {
});
}

-void SketchContainer::AllReduce(bool is_column_split) {
+void SketchContainer::AllReduce(Context const*, bool is_column_split) {
dh::safe_cuda(cudaSetDevice(device_.ordinal));
auto world = collective::GetWorldSize();
if (world == 1 || is_column_split) {
@@ -582,13 +580,13 @@ struct InvalidCatOp {
};
} // anonymous namespace

-void SketchContainer::MakeCuts(HistogramCuts* p_cuts, bool is_column_split) {
+void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool is_column_split) {
timer_.Start(__func__);
dh::safe_cuda(cudaSetDevice(device_.ordinal));
p_cuts->min_vals_.Resize(num_columns_);

// Sync between workers.
-this->AllReduce(is_column_split);
+this->AllReduce(ctx, is_column_split);

// Prune to final number of bins.
this->Prune(num_bins_ + 1);
@@ -731,5 +729,4 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts, bool is_column_split) {
p_cuts->SetCategorical(this->has_categorical_, max_cat);
timer_.Stop(__func__);
}
-} // namespace common
-} // namespace xgboost
+} // namespace xgboost::common
4 changes: 2 additions & 2 deletions src/common/quantile.cuh
@@ -151,9 +151,9 @@ class SketchContainer {
Span<SketchEntry const> that);

/* \brief Merge quantiles from other GPU workers. */
-void AllReduce(bool is_column_split);
+void AllReduce(Context const* ctx, bool is_column_split);
/* \brief Create the final histogram cut values. */
-void MakeCuts(HistogramCuts* cuts, bool is_column_split);
+void MakeCuts(Context const* ctx, HistogramCuts* cuts, bool is_column_split);

Span<SketchEntry const> Data() const {
return {this->Current().data().get(), this->Current().size()};
9 changes: 5 additions & 4 deletions src/common/quantile.h
@@ -827,13 +827,14 @@ class SketchContainerImpl {
return group_ind;
}
// Gather sketches from all workers.
-void GatherSketchInfo(MetaInfo const& info,
+void GatherSketchInfo(Context const *ctx, MetaInfo const &info,
std::vector<typename WQSketch::SummaryContainer> const &reduced,
std::vector<bst_row_t> *p_worker_segments,
std::vector<bst_row_t> *p_sketches_scan,
std::vector<typename WQSketch::Entry> *p_global_sketches);
// Merge sketches from all workers.
-void AllReduce(MetaInfo const& info, std::vector<typename WQSketch::SummaryContainer> *p_reduced,
+void AllReduce(Context const *ctx, MetaInfo const &info,
+std::vector<typename WQSketch::SummaryContainer> *p_reduced,
std::vector<int32_t> *p_num_cuts);

template <typename Batch, typename IsValid>
@@ -887,11 +888,11 @@ class SketchContainerImpl {
/* \brief Push a CSR matrix. */
void PushRowPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian = {});

-void MakeCuts(MetaInfo const& info, HistogramCuts* cuts);
+void MakeCuts(Context const *ctx, MetaInfo const &info, HistogramCuts *cuts);

private:
// Merge all categories from other workers.
-void AllreduceCategories(MetaInfo const& info);
+void AllreduceCategories(Context const* ctx, MetaInfo const& info);
};

class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, float>> {
2 changes: 1 addition & 1 deletion src/data/data.cc
@@ -745,7 +745,7 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
}
}

-void MetaInfo::SynchronizeNumberOfColumns() {
+void MetaInfo::SynchronizeNumberOfColumns(Context const*) {
if (IsColumnSplit()) {
collective::Allreduce<collective::Operation::kSum>(&num_col_, 1);
} else {
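
The hunk is truncated before the `else` branch; for orientation, here is a sketch of the whole function as of this change. The `else` body is an assumption based on the surrounding code, not part of this diff: under a column-wise (vertical) split each worker holds different columns, so the widths are summed, while otherwise a reduction makes workers with empty shards agree on the width.

```cpp
// Sketch, not verbatim: only the lines shown in the hunk above are from the PR.
void MetaInfo::SynchronizeNumberOfColumns(Context const*) {
  if (IsColumnSplit()) {
    // Vertical/federated split: each worker loads its own list of columns.
    collective::Allreduce<collective::Operation::kSum>(&num_col_, 1);
  } else {
    // Assumed: row-wise split, reconciling on the widest shard.
    collective::Allreduce<collective::Operation::kMax>(&num_col_, 1);
  }
}
```
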
12 changes: 6 additions & 6 deletions src/data/iterative_dmatrix.cc
@@ -95,7 +95,7 @@ void GetCutsFromRef(Context const* ctx, std::shared_ptr<DMatrix> ref, bst_featur

namespace {
// Synchronize feature type in case of empty DMatrix
-void SyncFeatureType(std::vector<FeatureType>* p_h_ft) {
+void SyncFeatureType(Context const*, std::vector<FeatureType>* p_h_ft) {
if (!collective::IsDistributed()) {
return;
}
@@ -193,7 +193,7 @@ void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p,
// From here on Info() has the correct data shape
Info().num_row_ = accumulated_rows;
Info().num_nonzero_ = nnz;
-Info().SynchronizeNumberOfColumns();
+Info().SynchronizeNumberOfColumns(ctx);
CHECK(std::none_of(column_sizes.cbegin(), column_sizes.cend(), [&](auto f) {
return f > accumulated_rows;
})) << "Something went wrong during iteration.";
@@ -213,9 +213,9 @@
while (iter.Next()) {
if (!p_sketch) {
h_ft = proxy->Info().feature_types.ConstHostVector();
-SyncFeatureType(&h_ft);
-p_sketch.reset(new common::HostSketchContainer{ctx, p.max_bin, h_ft, column_sizes,
-!proxy->Info().group_ptr_.empty()});
+SyncFeatureType(ctx, &h_ft);
+p_sketch = std::make_unique<common::HostSketchContainer>(ctx, p.max_bin, h_ft, column_sizes,
+!proxy->Info().group_ptr_.empty());
}
HostAdapterDispatch(proxy, [&](auto const& batch) {
proxy->Info().num_nonzero_ = batch_nnz[i];
@@ -230,7 +230,7 @@
CHECK_EQ(accumulated_rows, Info().num_row_);

CHECK(p_sketch);
-p_sketch->MakeCuts(Info(), &cuts);
+p_sketch->MakeCuts(ctx, Info(), &cuts);
}
if (!h_ft.empty()) {
CHECK_EQ(h_ft.size(), n_features);
4 changes: 2 additions & 2 deletions src/data/iterative_dmatrix.cu
@@ -105,7 +105,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
sketch_containers.clear();
sketch_containers.shrink_to_fit();

-final_sketch.MakeCuts(&cuts, this->info_.IsColumnSplit());
+final_sketch.MakeCuts(ctx, &cuts, this->info_.IsColumnSplit());
} else {
GetCutsFromRef(ctx, ref, Info().num_col_, p, &cuts);
}
@@ -167,7 +167,7 @@

iter.Reset();
// Synchronise worker columns
-info_.SynchronizeNumberOfColumns();
+info_.SynchronizeNumberOfColumns(ctx);
}

BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(Context const* ctx,
2 changes: 1 addition & 1 deletion src/data/simple_dmatrix.cc
@@ -283,7 +283,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread,
// Synchronise worker columns
info_.data_split_mode = data_split_mode;
ReindexFeatures(&ctx);
-info_.SynchronizeNumberOfColumns();
+info_.SynchronizeNumberOfColumns(&ctx);

if (adapter->NumRows() == kAdapterUnknownSize) {
using IteratorAdapterT =
2 changes: 1 addition & 1 deletion src/data/simple_dmatrix.cu
@@ -42,7 +42,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, std::int32_t nthr
info_.num_row_ = adapter->NumRows();
// Synchronise worker columns
info_.data_split_mode = data_split_mode;
-info_.SynchronizeNumberOfColumns();
+info_.SynchronizeNumberOfColumns(&ctx);

this->fmat_ctx_ = ctx;
}
2 changes: 1 addition & 1 deletion src/data/sparse_page_dmatrix.cc
@@ -97,7 +97,7 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
this->info_.num_col_ = n_features;
this->info_.num_nonzero_ = nnz;

-info_.SynchronizeNumberOfColumns();
+info_.SynchronizeNumberOfColumns(&ctx);
CHECK_NE(info_.num_col_, 0);

fmat_ctx_ = ctx;