diff --git a/src/common/host_device_vector.cc b/src/common/host_device_vector.cc
index 38d3a3c27d0a..a3e3ca3c50f0 100644
--- a/src/common/host_device_vector.cc
+++ b/src/common/host_device_vector.cc
@@ -154,10 +154,13 @@ bool HostDeviceVector<T>::DeviceCanAccess(int device, GPUAccess access) const {
 }
 
 template <typename T>
-void HostDeviceVector<T>::Reshard(const GPUDistribution& distribution) const { }
+void HostDeviceVector<T>::Shard(const GPUDistribution& distribution) const { }
 
 template <typename T>
-void HostDeviceVector<T>::Reshard(GPUSet devices) const { }
+void HostDeviceVector<T>::Shard(GPUSet devices) const { }
+
+template <typename T>
+void HostDeviceVector<T>::Reshard(const GPUDistribution &distribution, bool preserve) { }
 
 // explicit instantiations are required, as HostDeviceVector<T> isn't header-only
 template class HostDeviceVector<bst_float>;
diff --git a/src/common/host_device_vector.cu b/src/common/host_device_vector.cu
index 5e7634501444..bb4ff95fdb90 100644
--- a/src/common/host_device_vector.cu
+++ b/src/common/host_device_vector.cu
@@ -318,7 +318,7 @@ struct HostDeviceVectorImpl {
     // Data is on device;
     if (distribution_ != other->distribution_) {
       distribution_ = GPUDistribution();
-      Reshard(other->Distribution());
+      Shard(other->Distribution());
       size_d_ = other->size_d_;
     }
     dh::ExecuteIndexShards(&shards_, [&](int i, DeviceShard& shard) {
@@ -358,19 +358,27 @@ struct HostDeviceVectorImpl {
     return data_h_;
   }
 
-  void Reshard(const GPUDistribution& distribution) {
+  void Shard(const GPUDistribution& distribution) {
     if (distribution_ == distribution) { return; }
-    CHECK(distribution_.IsEmpty() || distribution.IsEmpty());
-    if (distribution.IsEmpty()) {
-      LazySyncHost(GPUAccess::kWrite);
-    }
+    CHECK(distribution_.IsEmpty());
     distribution_ = distribution;
     InitShards();
   }
 
-  void Reshard(GPUSet new_devices) {
+  void Shard(GPUSet new_devices) {
     if (distribution_.Devices() == new_devices) { return; }
-    Reshard(GPUDistribution::Block(new_devices));
+    Shard(GPUDistribution::Block(new_devices));
+  }
+
+  void Reshard(const GPUDistribution &distribution, bool preserve) {
+    if (distribution_ == distribution) { return; }
+    if (preserve) {
+      LazySyncHost(GPUAccess::kWrite);
+    }
+    distribution_ = distribution;
+    shards_.clear();
+    perm_h_.Grant(kWrite);
+    InitShards();
   }
 
   void Resize(size_t new_size, T v) {
@@ -586,13 +594,18 @@ bool HostDeviceVector<T>::DeviceCanAccess(int device, GPUAccess access) const {
 }
 
 template <typename T>
-void HostDeviceVector<T>::Reshard(GPUSet new_devices) const {
-  impl_->Reshard(new_devices);
+void HostDeviceVector<T>::Shard(GPUSet new_devices) const {
+  impl_->Shard(new_devices);
+}
+
+template <typename T>
+void HostDeviceVector<T>::Shard(const GPUDistribution &distribution) const {
+  impl_->Shard(distribution);
 }
 
 template <typename T>
-void HostDeviceVector<T>::Reshard(const GPUDistribution& distribution) const {
-  impl_->Reshard(distribution);
+void HostDeviceVector<T>::Reshard(const GPUDistribution &distribution, bool preserve) {
+  impl_->Reshard(distribution, preserve);
 }
 
 template <typename T>
diff --git a/src/common/host_device_vector.h b/src/common/host_device_vector.h
index 425cbff53e36..92a7ed7e3f1c 100644
--- a/src/common/host_device_vector.h
+++ b/src/common/host_device_vector.h
@@ -14,7 +14,7 @@
  * Initialization/Allocation:
  * One can choose to initialize the vector on CPU or GPU during constructor.
  * (use the 'devices' argument) Or, can choose to use the 'Resize' method to
- * allocate/resize memory explicitly, and use the 'Reshard' method
+ * allocate/resize memory explicitly, and use the 'Shard' method
  * to specify the devices.
  *
  * Accessing underlying data:
@@ -98,6 +98,8 @@ class GPUDistribution {
       offsets_(std::move(offsets)) {}
 
  public:
+  static GPUDistribution Empty() { return GPUDistribution(); }
+
   static GPUDistribution Block(GPUSet devices) { return GPUDistribution(devices); }
 
   static GPUDistribution Overlap(GPUSet devices, int overlap) {
@@ -250,11 +252,15 @@ class HostDeviceVector {
 
   /*!
    * \brief Specify memory distribution.
-   *
-   * If GPUSet::Empty() is used, all data will be drawn back to CPU.
    */
-  void Reshard(const GPUDistribution& distribution) const;
-  void Reshard(GPUSet devices) const;
+  void Shard(const GPUDistribution &distribution) const;
+  void Shard(GPUSet devices) const;
+
+  /*!
+   * \brief Change memory distribution, syncing data back to the host first unless preserve is false.
+   */
+  void Reshard(const GPUDistribution &distribution, bool preserve = true);
+
   void Resize(size_t new_size, T v = T());
 
  private:
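For reviewers, a minimal usage sketch of the split API declared above (not part of the patch; the element type, vector contents, and device set are made up for illustration). `Shard()` performs the initial placement and now refuses to change a non-empty distribution, while `Reshard()` changes an established distribution and, by default, syncs data back to the host before re-initializing the shards:

```cpp
// Sketch only, assuming the patched host_device_vector.h is on the include path.
#include "common/host_device_vector.h"

void ShardVsReshardSketch() {
  using namespace xgboost;
  HostDeviceVector<int> vec {0, 1, 2, 3};
  GPUSet devices = GPUSet::Range(0, 1);

  // Initial placement: allowed because the current distribution is empty.
  vec.Shard(GPUDistribution::Block(devices));

  // Changing an established distribution goes through Reshard(); the default
  // preserve=true syncs data back to the host before dropping the old shards.
  vec.Reshard(GPUDistribution::Granular(devices, 2));

  // preserve=false skips the sync, so pending device-side writes are discarded.
  vec.Reshard(GPUDistribution::Empty(), /*preserve=*/false);
}
```

This split is why most call sites in the rest of the patch are a mechanical `Reshard` to `Shard` rename: they only ever set an initial distribution.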
diff --git a/src/common/transform.h b/src/common/transform.h
index 841c56cb7cc4..62ef433efe50 100644
--- a/src/common/transform.h
+++ b/src/common/transform.h
@@ -57,13 +57,13 @@ class Transform {
   template <typename Functor>
   struct Evaluator {
    public:
-    Evaluator(Functor func, Range range, GPUSet devices, bool reshard) :
+    Evaluator(Functor func, Range range, GPUSet devices, bool shard) :
       func_(func), range_{std::move(range)},
-      reshard_{reshard},
+      shard_{shard},
       distribution_{std::move(GPUDistribution::Block(devices))} {}
     Evaluator(Functor func, Range range, GPUDistribution dist,
-              bool reshard) :
-      func_(func), range_{std::move(range)}, reshard_{reshard},
+              bool shard) :
+      func_(func), range_{std::move(range)}, shard_{shard},
       distribution_{std::move(dist)} {}
 
     /*!
@@ -106,25 +106,25 @@ class Transform {
       return Span<T const> {_vec->ConstHostPointer(),
            static_cast<typename Span<T const>::index_type>(_vec->Size())};
     }
-    // Recursive unpack for Reshard.
+    // Recursive unpack for Shard.
     template <typename T>
-    void UnpackReshard(GPUDistribution dist, const HostDeviceVector<T>* vector) const {
-      vector->Reshard(dist);
+    void UnpackShard(GPUDistribution dist, const HostDeviceVector<T> *vector) const {
+      vector->Shard(dist);
     }
     template <typename Head, typename... Rest>
-    void UnpackReshard(GPUDistribution dist,
-                       const HostDeviceVector<Head>* _vector,
-                       const HostDeviceVector<Rest>*... _vectors) const {
-      _vector->Reshard(dist);
-      UnpackReshard(dist, _vectors...);
+    void UnpackShard(GPUDistribution dist,
+                     const HostDeviceVector<Head> *_vector,
+                     const HostDeviceVector<Rest> *... _vectors) const {
+      _vector->Shard(dist);
+      UnpackShard(dist, _vectors...);
     }
 
 #if defined(__CUDACC__)
     template <typename std::enable_if<CompiledWithCuda>::type* = nullptr,
               typename... HDV>
     void LaunchCUDA(Functor _func, HDV*... _vectors) const {
-      if (reshard_)
-        UnpackReshard(distribution_, _vectors...);
+      if (shard_)
+        UnpackShard(distribution_, _vectors...);
 
       GPUSet devices = distribution_.Devices();
       size_t range_size = *range_.end() - *range_.begin();
@@ -170,8 +170,8 @@ class Transform {
     Functor func_;
     /*! \brief Range object specifying parallel threads index range. */
     Range range_;
-    /*! \brief Whether resharding for vectors is required. */
-    bool reshard_;
+    /*! \brief Whether sharding for vectors is required. */
+    bool shard_;
     GPUDistribution distribution_;
   };
 
@@ -187,19 +187,19 @@ class Transform {
   * \param range   Range object specifying parallel threads index range.
   * \param devices GPUSet specifying GPUs to use, when compiling for CPU,
   *                  this should be GPUSet::Empty().
-  * \param reshard Whether Reshard for HostDeviceVector is needed.
+  * \param shard   Whether Shard for HostDeviceVector is needed.
   */
  template <typename Functor>
  static Evaluator<Functor> Init(Functor func, Range const range,
                                 GPUSet const devices,
-                                bool const reshard = true) {
-    return Evaluator<Functor> {func, std::move(range), std::move(devices), reshard};
+                                bool const shard = true) {
+    return Evaluator<Functor> {func, std::move(range), std::move(devices), shard};
  }
  template <typename Functor>
  static Evaluator<Functor> Init(Functor func, Range const range,
                                 GPUDistribution const dist,
-                                bool const reshard = true) {
-    return Evaluator<Functor> {func, std::move(range), std::move(dist), reshard};
+                                bool const shard = true) {
+    return Evaluator<Functor> {func, std::move(range), std::move(dist), shard};
  }
};
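Since only the flag name changes here, a hedged sketch of how the renamed `shard` flag is typically exercised may help review. Everything other than `Transform`, `Range`, `Span`, `GPUSet`, and `HostDeviceVector` below is illustrative (the function name, lambda, and include paths are assumptions, not code from this patch); with `shard=true`, `Eval()` shards every vector argument via `UnpackShard` before launching:

```cpp
// Sketch only: assumes the patched transform.h / host_device_vector.h headers.
#include "common/host_device_vector.h"
#include "common/transform.h"

void DoubleInPlaceSketch(xgboost::HostDeviceVector<xgboost::bst_float>* io_vec,
                         xgboost::GPUSet devices) {
  using namespace xgboost;
  common::Transform<>::Init(
      [=] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _vec) {
        _vec[_idx] *= 2.0f;  // runs on each device shard, or on the CPU path
      },
      common::Range{0, static_cast<int64_t>(io_vec->Size())},
      devices,
      /*shard=*/true)  // renamed flag: shard the vectors before launching
      .Eval(io_vec);
}
```

Passing `false` is for call sites where the vectors are already distributed the way the kernel expects, which is unchanged by this patch.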
diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu
index a9221be849bf..3ecc7915837b 100644
--- a/src/metric/elementwise_metric.cu
+++ b/src/metric/elementwise_metric.cu
@@ -111,9 +111,9 @@ class ElementWiseMetricsReduction {
       allocators_.clear();
       allocators_.resize(devices.Size());
     }
-    preds.Reshard(devices);
-    labels.Reshard(devices);
-    weights.Reshard(devices);
+    preds.Shard(devices);
+    labels.Shard(devices);
+    weights.Shard(devices);
     std::vector<PackedReduceResult> res_per_device(devices.Size());
 
 #pragma omp parallel for schedule(static, 1) if (devices.Size() > 1)
diff --git a/src/metric/multiclass_metric.cu b/src/metric/multiclass_metric.cu
index 88af0014ed5a..4723b9371d91 100644
--- a/src/metric/multiclass_metric.cu
+++ b/src/metric/multiclass_metric.cu
@@ -134,9 +134,9 @@ class MultiClassMetricsReduction {
       allocators_.clear();
       allocators_.resize(devices.Size());
     }
-    preds.Reshard(GPUDistribution::Granular(devices, n_class));
-    labels.Reshard(devices);
-    weights.Reshard(devices);
+    preds.Shard(GPUDistribution::Granular(devices, n_class));
+    labels.Shard(devices);
+    weights.Shard(devices);
     std::vector<PackedReduceResult> res_per_device(devices.Size());
 
 #pragma omp parallel for schedule(static, 1) if (devices.Size() > 1)
diff --git a/src/objective/multiclass_obj.cu b/src/objective/multiclass_obj.cu
index dcea023702b7..79ead482d709 100644
--- a/src/objective/multiclass_obj.cu
+++ b/src/objective/multiclass_obj.cu
@@ -39,7 +39,7 @@ struct SoftmaxMultiClassParam : public dmlc::Parameter<SoftmaxMultiClassParam> {
         .describe("gpu to use for objective function evaluation");
   }
 };
-// TODO(trivialfis): Currently the resharding in softmax is less than ideal
+// TODO(trivialfis): Currently the sharding in softmax is less than ideal
 // due to repeated copying data between CPU and GPUs.  Maybe we just use single
 // GPU?
 class SoftmaxMultiClassObj : public ObjFunction {
@@ -63,11 +63,11 @@ class SoftmaxMultiClassObj : public ObjFunction {
     const int nclass = param_.num_class;
     const auto ndata = static_cast<int64_t>(preds.Size() / nclass);
 
-    out_gpair->Reshard(GPUDistribution::Granular(devices_, nclass));
-    info.labels_.Reshard(GPUDistribution::Block(devices_));
-    info.weights_.Reshard(GPUDistribution::Block(devices_));
-    preds.Reshard(GPUDistribution::Granular(devices_, nclass));
-    label_correct_.Reshard(GPUDistribution::Block(devices_));
+    out_gpair->Shard(GPUDistribution::Granular(devices_, nclass));
+    info.labels_.Shard(GPUDistribution::Block(devices_));
+    info.weights_.Shard(GPUDistribution::Block(devices_));
+    preds.Shard(GPUDistribution::Granular(devices_, nclass));
+    label_correct_.Shard(GPUDistribution::Block(devices_));
 
     out_gpair->Resize(preds.Size());
     label_correct_.Fill(1);
@@ -136,8 +136,8 @@ class SoftmaxMultiClassObj : public ObjFunction {
           common::Range{0, ndata}, GPUDistribution::Granular(devices_, nclass))
           .Eval(io_preds);
     } else {
-      io_preds->Reshard(GPUDistribution::Granular(devices_, nclass));
-      max_preds_.Reshard(GPUDistribution::Block(devices_));
+      io_preds->Shard(GPUDistribution::Granular(devices_, nclass));
+      max_preds_.Shard(GPUDistribution::Block(devices_));
       common::Transform<>::Init(
          [=] XGBOOST_DEVICE(size_t _idx, common::Span<const bst_float> _preds,
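A note on the distributions chosen above: predictions carry `n_class` contiguous values per row, so they are sharded with `GPUDistribution::Granular(devices, n_class)` to keep every row's class scores on a single device, while labels and weights are one value per row and use plain block sharding. The helper below is only an illustration of that invariant; it is not the actual `Granular` implementation and the rounding may differ:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Illustration only: block-style shard sizes where every shard is a whole
// number of rows, i.e. a multiple of n_class elements.
std::vector<std::size_t> GranularShardSizesSketch(std::size_t n_rows,
                                                  std::size_t n_class,
                                                  std::size_t n_devices) {
  std::vector<std::size_t> sizes(n_devices, 0);
  std::size_t rows_per_device = (n_rows + n_devices - 1) / n_devices;
  for (std::size_t d = 0; d < n_devices; ++d) {
    std::size_t begin = std::min(d * rows_per_device, n_rows);
    std::size_t end = std::min(begin + rows_per_device, n_rows);
    sizes[d] = (end - begin) * n_class;  // whole rows only
  }
  return sizes;
}
```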
diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu
index 579bf5844b35..0fcd0270ec66 100644
--- a/src/predictor/gpu_predictor.cu
+++ b/src/predictor/gpu_predictor.cu
@@ -327,11 +327,11 @@ class GPUPredictor : public xgboost::Predictor {
     for (const auto &batch : dmat->GetRowBatches()) {
       CHECK_EQ(i_batch, 0) << "External memory not supported";
 
-      // out_preds have been resharded and resized in InitOutPredictions()
-      batch.offset.Reshard(GPUDistribution::Overlap(devices_, 1));
+      // out_preds have been sharded and resized in InitOutPredictions()
+      batch.offset.Shard(GPUDistribution::Overlap(devices_, 1));
       std::vector<size_t> device_offsets;
       DeviceOffsets(batch.offset, &device_offsets);
-      batch.data.Reshard(GPUDistribution::Explicit(devices_, device_offsets));
+      batch.data.Shard(GPUDistribution::Explicit(devices_, device_offsets));
       dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
         shard.PredictInternal(batch, dmat->Info(), out_preds, model,
                               h_tree_segments, h_nodes, tree_begin, tree_end);
@@ -373,7 +373,7 @@ class GPUPredictor : public xgboost::Predictor {
     size_t n_classes = model.param.num_output_group;
     size_t n = n_classes * info.num_row_;
     const HostDeviceVector<bst_float>& base_margin = info.base_margin_;
-    out_preds->Reshard(GPUDistribution::Granular(devices_, n_classes));
+    out_preds->Shard(GPUDistribution::Granular(devices_, n_classes));
     out_preds->Resize(n);
     if (base_margin.Size() != 0) {
       CHECK_EQ(out_preds->Size(), n);
@@ -392,7 +392,7 @@ class GPUPredictor : public xgboost::Predictor {
     const HostDeviceVector<bst_float>& y = it->second.predictions;
     if (y.Size() != 0) {
       monitor_.StartCuda("PredictFromCache");
-      out_preds->Reshard(y.Distribution());
+      out_preds->Shard(y.Distribution());
       out_preds->Resize(y.Size());
       out_preds->Copy(y);
       monitor_.StopCuda("PredictFromCache");
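The two distributions used for the row batch above follow from the CSR layout: `batch.offset` holds `num_rows + 1` row pointers, and the device responsible for rows `[row_begin, row_end)` also needs the closing pointer at `row_end`, hence `Overlap(devices_, 1)`; `batch.data` is then carved up at exactly the element extents those pointers describe, hence `Explicit(devices_, device_offsets)`. A small hedged illustration of that relationship (the function and its names are made up, not code from the patch):

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Illustration only: the data extent one device must hold, given the CSR row
// pointers and its assigned row range. Adjacent devices share one boundary
// pointer, which is why the offsets are sharded with an overlap of 1.
std::pair<std::size_t, std::size_t> DataExtentForRowsSketch(
    const std::vector<std::size_t>& row_offsets,  // size == num_rows + 1
    std::size_t row_begin, std::size_t row_end) {
  return {row_offsets[row_begin], row_offsets[row_end]};
}
```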
diff --git a/src/tree/updater_gpu.cu b/src/tree/updater_gpu.cu
index 3348ebf7cd37..3e2f3235c77b 100644
--- a/src/tree/updater_gpu.cu
+++ b/src/tree/updater_gpu.cu
@@ -566,7 +566,7 @@ class GPUMaker : public TreeUpdater {
   int maxNodes_;
   int maxLeaves_;
 
-  // devices are only used for resharding the HostDeviceVector passed as a parameter;
+  // devices are only used for sharding the HostDeviceVector passed as a parameter;
   // the algorithm works with a single GPU only
   GPUSet devices_;
 
@@ -594,7 +594,7 @@ class GPUMaker : public TreeUpdater {
     float lr = param_.learning_rate;
     param_.learning_rate = lr / trees.size();
 
-    gpair->Reshard(devices_);
+    gpair->Shard(devices_);
 
     try {
       // build tree
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 9d74abde2a3d..a980661e7aa9 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -836,7 +836,7 @@ struct DeviceShard {
     for (auto i = 0ull; i < nidxs.size(); i++) {
       auto nidx = nidxs[i];
       auto p_feature_set = column_sampler.GetFeatureSet(tree.GetDepth(nidx));
-      p_feature_set->Reshard(GPUSet(device_id, 1));
+      p_feature_set->Shard(GPUSet(device_id, 1));
       auto d_feature_set = p_feature_set->DeviceSpan(device_id);
       auto d_split_candidates =
           d_split_candidates_all.subspan(i * num_columns, d_feature_set.size());
@@ -1527,7 +1527,7 @@ class GPUHistMakerSpecialised{
       return false;
     }
     monitor_.StartCuda("UpdatePredictionCache");
-    p_out_preds->Reshard(dist_.Devices());
+    p_out_preds->Shard(dist_.Devices());
     dh::ExecuteIndexShards(
         &shards_,
        [&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
diff --git a/tests/cpp/common/test_host_device_vector.cu b/tests/cpp/common/test_host_device_vector.cu
index bac9b026a043..1abcdadf0d68 100644
--- a/tests/cpp/common/test_host_device_vector.cu
+++ b/tests/cpp/common/test_host_device_vector.cu
@@ -23,7 +23,7 @@ void InitHostDeviceVector(size_t n, const GPUDistribution& distribution,
                           HostDeviceVector<int> *v) {
   // create the vector
   GPUSet devices = distribution.Devices();
-  v->Reshard(distribution);
+  v->Shard(distribution);
   v->Resize(n);
 
   ASSERT_EQ(v->Size(), n);
@@ -178,7 +178,7 @@ TEST(HostDeviceVector, TestCopy) {
   SetCudaSetDeviceHandler(nullptr);
 }
 
-TEST(HostDeviceVector, Reshard) {
+TEST(HostDeviceVector, Shard) {
   std::vector<int> h_vec (2345);
   for (size_t i = 0; i < h_vec.size(); ++i) {
     h_vec[i] = i;
@@ -186,12 +186,12 @@ TEST(HostDeviceVector, Reshard) {
   HostDeviceVector<int> vec (h_vec);
   auto devices = GPUSet::Range(0, 1);
 
-  vec.Reshard(devices);
+  vec.Shard(devices);
   ASSERT_EQ(vec.DeviceSize(0), h_vec.size());
   ASSERT_EQ(vec.Size(), h_vec.size());
   auto span = vec.DeviceSpan(0);  // sync to device
 
-  vec.Reshard(GPUSet::Empty());  // pull back to cpu, empty devices.
+  vec.Reshard(GPUDistribution::Empty());  // pull back to cpu, empty devices.
   ASSERT_EQ(vec.Size(), h_vec.size());
   ASSERT_TRUE(vec.Devices().IsEmpty());
 
@@ -199,9 +199,48 @@ TEST(HostDeviceVector, Reshard) {
   ASSERT_TRUE(std::equal(h_vec_1.cbegin(), h_vec_1.cend(), h_vec.cbegin()));
 }
 
+TEST(HostDeviceVector, Reshard) {
+  std::vector<int> h_vec (2345);
+  for (size_t i = 0; i < h_vec.size(); ++i) {
+    h_vec[i] = i;
+  }
+  HostDeviceVector<int> vec (h_vec);
+  auto devices = GPUSet::Range(0, 1);
+
+  vec.Shard(devices);
+  ASSERT_EQ(vec.DeviceSize(0), h_vec.size());
+  ASSERT_EQ(vec.Size(), h_vec.size());
+  auto span = vec.DeviceSpan(0);  // sync to device
+  PlusOne(&vec);
+
+  // GPU data is preserved.
+  vec.Reshard(GPUDistribution::Empty());
+  ASSERT_EQ(vec.Size(), h_vec.size());
+  ASSERT_TRUE(vec.Devices().IsEmpty());
+
+  auto h_vec_1 = vec.HostVector();
+  for (size_t i = 0; i < h_vec_1.size(); ++i) {
+    ASSERT_EQ(h_vec_1.at(i), i + 1);
+  }
+
+  vec.Reshard(GPUDistribution::Block(devices));
+  span = vec.DeviceSpan(0);  // sync to device
+  PlusOne(&vec);
+
+  vec.Reshard(GPUDistribution::Empty(), /*preserve=*/false);
+  ASSERT_EQ(vec.Size(), h_vec.size());
+  ASSERT_TRUE(vec.Devices().IsEmpty());
+
+  auto h_vec_2 = vec.HostVector();
+  for (size_t i = 0; i < h_vec_2.size(); ++i) {
+    // The second `PlusOne()` has no effect.
+    ASSERT_EQ(h_vec_2.at(i), i + 1);
+  }
+}
+
 TEST(HostDeviceVector, Span) {
   HostDeviceVector<float> vec {1.0f, 2.0f, 3.0f, 4.0f};
-  vec.Reshard(GPUSet{0, 1});
+  vec.Shard(GPUSet{0, 1});
   auto span = vec.DeviceSpan(0);
   ASSERT_EQ(vec.DeviceSize(0), span.size());
   ASSERT_EQ(vec.DevicePointer(0), span.data());
@@ -212,7 +251,7 @@ TEST(HostDeviceVector, Span) {
 
 // Multi-GPUs' test
 #if defined(XGBOOST_USE_NCCL)
-TEST(HostDeviceVector, MGPU_Reshard) {
+TEST(HostDeviceVector, MGPU_Shard) {
   auto devices = GPUSet::AllVisible();
   if (devices.Size() < 2) {
     LOG(WARNING) << "Not testing in multi-gpu environment.";
     return;
   }
@@ -229,7 +268,7 @@ TEST(HostDeviceVector, MGPU_Reshard) {
   std::vector<size_t> devices_size (devices.Size());
 
   // From CPU to GPUs.
-  vec.Reshard(devices);
+  vec.Shard(devices);
   size_t total_size = 0;
   for (size_t i = 0; i < devices.Size(); ++i) {
     total_size += vec.DeviceSize(i);
@@ -238,16 +277,16 @@ TEST(HostDeviceVector, MGPU_Reshard) {
   ASSERT_EQ(total_size, h_vec.size());
   ASSERT_EQ(total_size, vec.Size());
 
-  // Reshard from devices to devices with different distribution.
+  // Shard from devices to devices with different distribution.
   EXPECT_ANY_THROW(
-      vec.Reshard(GPUDistribution::Granular(devices, 12)));
+      vec.Shard(GPUDistribution::Granular(devices, 12)));
 
   // All data is drawn back to CPU
-  vec.Reshard(GPUSet::Empty());
+  vec.Reshard(GPUDistribution::Empty());
   ASSERT_TRUE(vec.Devices().IsEmpty());
   ASSERT_EQ(vec.Size(), h_vec.size());
 
-  vec.Reshard(GPUDistribution::Granular(devices, 12));
+  vec.Shard(GPUDistribution::Granular(devices, 12));
   total_size = 0;
   for (size_t i = 0; i < devices.Size(); ++i) {
     total_size += vec.DeviceSize(i);
@@ -256,6 +295,67 @@ TEST(HostDeviceVector, MGPU_Reshard) {
   ASSERT_EQ(total_size, h_vec.size());
   ASSERT_EQ(total_size, vec.Size());
 }
+
+TEST(HostDeviceVector, MGPU_Reshard) {
+  auto devices = GPUSet::AllVisible();
+  if (devices.Size() < 2) {
+    LOG(WARNING) << "Not testing in multi-gpu environment.";
+    return;
+  }
+
+  std::vector<int> h_vec (2345);
+  for (size_t i = 0; i < h_vec.size(); ++i) {
+    h_vec[i] = i;
+  }
+  HostDeviceVector<int> vec (h_vec);
+
+  // Data size for each device.
+  std::vector<size_t> devices_size (devices.Size());
+
+  // From CPU to GPUs.
+  vec.Shard(devices);
+  for (size_t i = 0; i < devices.Size(); ++i) {
+    auto span = vec.DeviceSpan(i);  // sync to device
+  }
+  PlusOne(&vec);
+
+  // Reshard is allowed for an already sharded vector.
+  vec.Reshard(GPUDistribution::Overlap(devices, 7));
+  size_t total_size = 0;
+  for (size_t i = 0; i < devices.Size(); ++i) {
+    total_size += vec.DeviceSize(i);
+    devices_size[i] = vec.DeviceSize(i);
+  }
+  size_t overlap = 7 * (devices.Size() - 1);
+  ASSERT_EQ(total_size, h_vec.size() + overlap);
+  ASSERT_EQ(total_size, vec.Size() + overlap);
+
+  auto h_vec_1 = vec.HostVector();
+  for (size_t i = 0; i < h_vec_1.size(); ++i) {
+    ASSERT_EQ(h_vec_1.at(i), i + 1);
+  }
+
+  for (size_t i = 0; i < devices.Size(); ++i) {
+    auto span = vec.DeviceSpan(i);  // sync to device
+  }
+  PlusOne(&vec);
+
+  vec.Reshard(GPUDistribution::Overlap(devices, 11), /*preserve=*/false);
+  total_size = 0;
+  for (size_t i = 0; i < devices.Size(); ++i) {
+    total_size += vec.DeviceSize(i);
+    devices_size[i] = vec.DeviceSize(i);
+  }
+  overlap = 11 * (devices.Size() - 1);
+  ASSERT_EQ(total_size, h_vec.size() + overlap);
+  ASSERT_EQ(total_size, vec.Size() + overlap);
+
+  auto h_vec_2 = vec.HostVector();
+  for (size_t i = 0; i < h_vec_2.size(); ++i) {
+    // The second `PlusOne()` has no effect.
+    ASSERT_EQ(h_vec_2.at(i), i + 1);
+  }
+}
 #endif
 
 }  // namespace common
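The two new tests call a `PlusOne()` helper that is defined elsewhere in `test_host_device_vector.cu` and does not appear in this diff. For the `preserve=false` assertions to hold it has to modify the device copy rather than the host copy; a hypothetical sketch along those lines is shown below (the device selection, iteration scheme, and lambda are assumptions about the real helper, not its actual implementation):

```cpp
// Hypothetical sketch of PlusOne(); the real helper lives elsewhere in the
// test file and may be implemented differently.
#include <cuda_runtime.h>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>

void PlusOneSketch(xgboost::HostDeviceVector<int>* vec) {
  auto devices = vec->Devices();
  for (int i = 0; i < static_cast<int>(devices.Size()); ++i) {
    // Assumes shard index i maps to CUDA device i for this sketch.
    cudaSetDevice(i);
    thrust::device_ptr<int> begin(vec->DevicePointer(i));
    thrust::transform(begin, begin + vec->DeviceSize(i), begin,
                      [] __device__(int v) { return v + 1; });
  }
}
```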
diff --git a/tests/cpp/common/test_transform_range.cu b/tests/cpp/common/test_transform_range.cu
index 45f1f312253b..29172937fec2 100644
--- a/tests/cpp/common/test_transform_range.cu
+++ b/tests/cpp/common/test_transform_range.cu
@@ -22,10 +22,10 @@ TEST(Transform, MGPU_Basic) {
                                       GPUDistribution::Block(GPUSet::Empty())};
   out_vec.Fill(0);
 
-  in_vec.Reshard(GPUDistribution::Granular(devices, 8));
-  out_vec.Reshard(GPUDistribution::Block(devices));
+  in_vec.Shard(GPUDistribution::Granular(devices, 8));
+  out_vec.Shard(GPUDistribution::Block(devices));
 
-  // Granularity is different, resharding will throw.
+  // Granularity is different, sharding will throw.
   EXPECT_ANY_THROW(
       Transform<>::Init(TestTransformRange{}, Range{0, size}, devices)
       .Eval(&out_vec, &in_vec));