From 7b4e47d78b52ce50ca0c32d60c5dffb390f67dff Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 10 Apr 2019 02:47:46 +0800 Subject: [PATCH] Add a small test. --- src/tree/updater_gpu_hist.cu | 6 ++++-- tests/cpp/tree/test_gpu_hist.cu | 26 ++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 92bfe2952bf4..a44f74a1ba0a 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -365,16 +365,18 @@ __global__ void EvaluateSplitKernel( * * \summary Data storage for node histograms on device. Automatically expands. * + * \tparam GradientSumT histogram entry type. + * \tparam kStopGrowingSize Do not grow beyond this size + * * \author Rory * \date 28/07/2018 */ -template +template class DeviceHistogram { private: /*! \brief Map nidx to starting index of its histogram. */ std::map nidx_map_; thrust::device_vector data_; - static constexpr size_t kStopGrowingSize = 1 << 26; // Do not grow beyond this size int n_bins_; int device_id_; diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index d7e4b2654e79..3dab672910b4 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -20,6 +20,32 @@ namespace xgboost { namespace tree { +TEST(GpuHist, DeviceHistogram) { + // Ensures node histograms can be re-allocated after storage hits `kStopGrowingSize` and is Reset(). 
+ dh::SaveCudaContext{ + [&]() { + dh::safe_cuda(cudaSetDevice(0)); + constexpr size_t kNbins = 128; + constexpr size_t kNNodes = 4; + constexpr size_t kStopGrowing = kNNodes * kNbins * sizeof(GradientPairPrecise); + DeviceHistogram histogram; + histogram.Init(0, kNbins); + for (int i = 0; i < kNNodes; ++i) { + histogram.AllocateHistogram(i); + } + histogram.Reset(); + ASSERT_EQ(histogram.Data().size() * 8u, kStopGrowing); + for (int i = 0; i < kNNodes; ++i) { + histogram.AllocateHistogram(i); + } + for (int i = 0; i < kNNodes; ++i) { + ASSERT_TRUE(histogram.HistogramExists(i)); + } + } + }; + +} + template void BuildGidx(DeviceShard* shard, int n_rows, int n_cols, bst_float sparsity=0) {