diff --git a/src/tree/gpu_hist/feature_groups.cu b/src/tree/gpu_hist/feature_groups.cu
new file mode 100644
index 000000000000..5a2c8ee6cbd8
--- /dev/null
+++ b/src/tree/gpu_hist/feature_groups.cu
@@ -0,0 +1,64 @@
+/*!
+ * Copyright 2020 by XGBoost Contributors
+ */
+
+#include <xgboost/base.h>
+#include <algorithm>
+#include <vector>
+
+#include "feature_groups.cuh"
+
+#include "../../common/device_helpers.cuh"
+#include "../../common/hist_util.h"
+
+namespace xgboost {
+namespace tree {
+
+FeatureGroups::FeatureGroups(const common::HistogramCuts& cuts, bool is_dense,
+                             size_t shm_size, size_t bin_size) {
+  // Only use a single feature group for sparse matrices.
+  bool single_group = !is_dense;
+  if (single_group) {
+    InitSingle(cuts);
+    return;
+  }
+
+  std::vector<int>& feature_segments_h = feature_segments.HostVector();
+  std::vector<int>& bin_segments_h = bin_segments.HostVector();
+  feature_segments_h.push_back(0);
+  bin_segments_h.push_back(0);
+
+  const std::vector<uint32_t>& cut_ptrs = cuts.Ptrs();
+  int max_shmem_bins = shm_size / bin_size;
+  max_group_bins = 0;
+
+  for (size_t i = 2; i < cut_ptrs.size(); ++i) {
+    int last_start = bin_segments_h.back();
+    if (cut_ptrs[i] - last_start > max_shmem_bins) {
+      feature_segments_h.push_back(i - 1);
+      bin_segments_h.push_back(cut_ptrs[i - 1]);
+      max_group_bins = std::max(max_group_bins,
+                                bin_segments_h.back() - last_start);
+    }
+  }
+  feature_segments_h.push_back(cut_ptrs.size() - 1);
+  bin_segments_h.push_back(cut_ptrs.back());
+  max_group_bins = std::max(max_group_bins,
+                            bin_segments_h.back() -
+                            bin_segments_h[bin_segments_h.size() - 2]);
+}
+
+void FeatureGroups::InitSingle(const common::HistogramCuts& cuts) {
+  std::vector<int>& feature_segments_h = feature_segments.HostVector();
+  feature_segments_h.push_back(0);
+  feature_segments_h.push_back(cuts.Ptrs().size() - 1);
+
+  std::vector<int>& bin_segments_h = bin_segments.HostVector();
+  bin_segments_h.push_back(0);
+  bin_segments_h.push_back(cuts.TotalBins());
+
+  max_group_bins = cuts.TotalBins();
+}
+
+}  // namespace tree
+}  // namespace xgboost
diff --git a/src/tree/gpu_hist/feature_groups.cuh b/src/tree/gpu_hist/feature_groups.cuh
new file mode 100644
index 000000000000..3af230c2ccf6
--- /dev/null
+++ b/src/tree/gpu_hist/feature_groups.cuh
@@ -0,0 +1,119 @@
+/*!
+ * Copyright 2020 by XGBoost Contributors
+ */
+#ifndef FEATURE_GROUPS_CUH_
+#define FEATURE_GROUPS_CUH_
+
+#include <xgboost/host_device_vector.h>
+#include <xgboost/span.h>
+
+namespace xgboost {
+
+// Forward declarations.
+namespace common {
+class HistogramCuts;
+}  // namespace common
+
+namespace tree {
+
+/** \brief FeatureGroup is a feature group. It is defined by a range of
+    consecutive feature indices, and also contains a range of all bin indices
+    associated with those features. */
+struct FeatureGroup {
+  __host__ __device__ FeatureGroup(int start_feature_, int num_features_,
+                                   int start_bin_, int num_bins_) :
+    start_feature(start_feature_), num_features(num_features_),
+    start_bin(start_bin_), num_bins(num_bins_) {}
+  /** The first feature of the group. */
+  int start_feature;
+  /** The number of features in the group. */
+  int num_features;
+  /** The first bin in the group. */
+  int start_bin;
+  /** The number of bins in the group. */
+  int num_bins;
+};
+
+/** \brief FeatureGroupsAccessor is a non-owning accessor for FeatureGroups. */
+struct FeatureGroupsAccessor {
+  FeatureGroupsAccessor(common::Span<const int> feature_segments_,
+                        common::Span<const int> bin_segments_, int max_group_bins_) :
+    feature_segments(feature_segments_), bin_segments(bin_segments_),
+    max_group_bins(max_group_bins_) {}
+
+  common::Span<const int> feature_segments;
+  common::Span<const int> bin_segments;
+  int max_group_bins;
+
+  /** \brief Gets the number of feature groups. */
+  __host__ __device__ int NumGroups() const {
+    return feature_segments.size() - 1;
+  }
+
+  /** \brief Gets the information about a feature group with index i. */
+  __host__ __device__ FeatureGroup operator[](int i) const {
+    return {feature_segments[i], feature_segments[i + 1] - feature_segments[i],
+            bin_segments[i], bin_segments[i + 1] - bin_segments[i]};
+  }
+};
+
+/** \brief FeatureGroups contains information that defines a split of features
+    into groups. Bins of a single feature group typically fit into shared
+    memory, so the histogram for the features of a single group can be computed
+    faster.
+
+    \notes Known limitations:
+
+    - splitting features into groups currently works only for dense matrices,
+      where it is easy to get a feature value in a row by its index; for sparse
+      matrices, the structure contains only a single group containing all
+      features;
+
+    - if a single feature requires more bins than fit into shared memory, the
+      histogram is computed in global memory even if there are multiple feature
+      groups; note that this is unlikely to occur in practice, as the default
+      number of bins per feature is 256, whereas a thread block with 48 KiB
+      shared memory can contain 3072 bins if each gradient sum component is a
+      64-bit floating-point value (double)
+*/
+struct FeatureGroups {
+  /** Group cuts for features. Size equals to (number of groups + 1). */
+  HostDeviceVector<int> feature_segments;
+  /** Group cuts for bins. Size equals to (number of groups + 1). */
+  HostDeviceVector<int> bin_segments;
+  /** Maximum number of bins in a group. Useful to compute the amount of dynamic
+      shared memory when launching a kernel. */
+  int max_group_bins;
+
+  /** Creates feature groups by splitting features into groups.
+      \param cuts Histogram cuts that give the number of bins per feature.
+      \param is_dense Whether the data matrix is dense.
+      \param shm_size Available size of shared memory per thread block (in
+             bytes) used to compute feature groups.
+      \param bin_size Size of a single bin of the histogram. */
+  FeatureGroups(const common::HistogramCuts& cuts, bool is_dense,
+                size_t shm_size, size_t bin_size);
+
+  /** Creates a single feature group containing all features and bins.
+      \notes This is used as a fallback for sparse matrices, and is also useful
+      for testing.
+   */
+  explicit FeatureGroups(const common::HistogramCuts& cuts) {
+    InitSingle(cuts);
+  }
+
+  FeatureGroupsAccessor DeviceAccessor(int device) const {
+    feature_segments.SetDevice(device);
+    bin_segments.SetDevice(device);
+    return {feature_segments.ConstDeviceSpan(), bin_segments.ConstDeviceSpan(),
+            max_group_bins};
+  }
+
+private:
+  void InitSingle(const common::HistogramCuts& cuts);
+};
+
+}  // namespace tree
+}  // namespace xgboost
+
+#endif  // FEATURE_GROUPS_CUH_
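[Editor's note] The constructor above packs consecutive features into a group until the group's bins would exceed what fits into shared memory (shm_size / bin_size). A minimal standalone sketch of that loop, not part of the patch: it uses plain std::vector and toy numbers instead of the XGBoost HistogramCuts/HostDeviceVector types, and omits the max_group_bins bookkeeping for brevity.

// grouping_sketch.cc -- illustration only, mirrors FeatureGroups' grouping loop.
#include <cstdio>
#include <vector>

int main() {
  // cuts.Ptrs() analogue: 6 features, 4 bins each.
  std::vector<int> cut_ptrs{0, 4, 8, 12, 16, 20, 24};
  int max_shmem_bins = 8;  // shm_size / bin_size (assumed toy value)

  std::vector<int> feature_segments{0}, bin_segments{0};
  for (size_t i = 2; i < cut_ptrs.size(); ++i) {
    // Start a new group once the current one would overflow the bin budget.
    if (cut_ptrs[i] - bin_segments.back() > max_shmem_bins) {
      feature_segments.push_back(static_cast<int>(i) - 1);
      bin_segments.push_back(cut_ptrs[i - 1]);
    }
  }
  feature_segments.push_back(static_cast<int>(cut_ptrs.size()) - 1);
  bin_segments.push_back(cut_ptrs.back());

  // Expected output: three groups of two features, eight bins each.
  for (size_t g = 0; g + 1 < feature_segments.size(); ++g) {
    std::printf("group %zu: features [%d, %d), bins [%d, %d)\n", g,
                feature_segments[g], feature_segments[g + 1],
                bin_segments[g], bin_segments[g + 1]);
  }
  return 0;
}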
diff --git a/src/tree/gpu_hist/histogram.cu b/src/tree/gpu_hist/histogram.cu
index edc3046d1dcb..0169833d296c 100644
--- a/src/tree/gpu_hist/histogram.cu
+++ b/src/tree/gpu_hist/histogram.cu
@@ -102,23 +102,26 @@ template GradientPair CreateRoundingFactor(common::Span<GradientPair> gpai
 
 template <typename GradientSumT>
 __global__ void SharedMemHistKernel(EllpackDeviceAccessor matrix,
+                                    FeatureGroupsAccessor feature_groups,
                                     common::Span<const RowPartitioner::RowIndexT> d_ridx,
                                     GradientSumT* __restrict__ d_node_hist,
                                     const GradientPair* __restrict__ d_gpair,
-                                    size_t n_elements,
                                     GradientSumT const rounding,
                                     bool use_shared_memory_histograms) {
   using T = typename GradientSumT::ValueT;
   extern __shared__ char smem[];
+  FeatureGroup group = feature_groups[blockIdx.y];
   GradientSumT* smem_arr = reinterpret_cast<GradientSumT*>(smem);  // NOLINT
   if (use_shared_memory_histograms) {
-    dh::BlockFill(smem_arr, matrix.NumBins(), GradientSumT());
+    dh::BlockFill(smem_arr, group.num_bins, GradientSumT());
     __syncthreads();
   }
+  int feature_stride = matrix.is_dense ? group.num_features : matrix.row_stride;
+  size_t n_elements = feature_stride * d_ridx.size();
   for (auto idx : dh::GridStrideRange(static_cast<size_t>(0), n_elements)) {
-    int ridx = d_ridx[idx / matrix.row_stride];
-    int gidx =
-        matrix.gidx_iter[ridx * matrix.row_stride + idx % matrix.row_stride];
+    int ridx = d_ridx[idx / feature_stride];
+    int gidx = matrix.gidx_iter[ridx * matrix.row_stride + group.start_feature +
+                                idx % feature_stride];
     if (gidx != matrix.NumBins()) {
       GradientSumT truncated {
         TruncateWithRoundingFactor<T>(rounding.GetGrad(), d_gpair[ridx].GetGrad()),
@@ -127,7 +130,8 @@ __global__ void SharedMemHistKernel(EllpackDeviceAccessor matrix,
       // If we are not using shared memory, accumulate the values directly into
      // global memory
       GradientSumT* atomic_add_ptr =
-        use_shared_memory_histograms ? smem_arr : d_node_hist;
+          use_shared_memory_histograms ? smem_arr : d_node_hist;
+      gidx = use_shared_memory_histograms ? gidx - group.start_bin : gidx;
       dh::AtomicAddGpair(atomic_add_ptr + gidx, truncated);
     }
   }
@@ -135,18 +139,21 @@ __global__ void SharedMemHistKernel(EllpackDeviceAccessor matrix,
   if (use_shared_memory_histograms) {
     // Write shared memory back to global memory
     __syncthreads();
-    for (auto i : dh::BlockStrideRange(static_cast<size_t>(0), matrix.NumBins())) {
-      GradientSumT truncated {
-        TruncateWithRoundingFactor<T>(rounding.GetGrad(), smem_arr[i].GetGrad()),
-        TruncateWithRoundingFactor<T>(rounding.GetHess(), smem_arr[i].GetHess()),
+    for (auto i : dh::BlockStrideRange(0, group.num_bins)) {
+      GradientSumT truncated{
+        TruncateWithRoundingFactor<T>(rounding.GetGrad(),
+                                      smem_arr[i].GetGrad()),
+        TruncateWithRoundingFactor<T>(rounding.GetHess(),
+                                      smem_arr[i].GetHess()),
       };
-      dh::AtomicAddGpair(d_node_hist + i, truncated);
+      dh::AtomicAddGpair(d_node_hist + group.start_bin + i, truncated);
     }
   }
 }
 
 template <typename GradientSumT>
 void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
+                            FeatureGroupsAccessor const& feature_groups,
                             common::Span<GradientPair const> gpair,
                             common::Span<const uint32_t> d_ridx,
                             common::Span<GradientSumT> histogram,
@@ -155,7 +162,7 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
   int device = 0;
   dh::safe_cuda(cudaGetDevice(&device));
   int max_shared_memory = dh::MaxSharedMemoryOptin(device);
-  size_t smem_size = sizeof(GradientSumT) * matrix.NumBins();
+  size_t smem_size = sizeof(GradientSumT) * feature_groups.max_group_bins;
   bool shared = smem_size <= max_shared_memory;
   smem_size = shared ? smem_size : 0;
 
@@ -169,6 +176,7 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
 
   // determine the launch configuration
   unsigned block_threads = shared ? 1024 : 256;
+  int num_groups = feature_groups.NumGroups();
   int n_mps = 0;
   dh::safe_cuda(cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device));
   int n_blocks_per_mp = 0;
@@ -176,15 +184,31 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
     (&n_blocks_per_mp, kernel, block_threads, smem_size));
   unsigned grid_size = n_blocks_per_mp * n_mps;
 
-  auto n_elements = d_ridx.size() * matrix.row_stride;
-  dh::LaunchKernel {grid_size, block_threads, smem_size} (
-      kernel, matrix, d_ridx, histogram.data(), gpair.data(), n_elements,
-      rounding, shared);
+  // TODO(canonizer): This is really a hack, find a better way to distribute the
+  // data among thread blocks.
+  // The intention is to generate enough thread blocks to fill the GPU, but
+  // avoid having too many thread blocks, as this is less efficient when the
+  // number of rows is low. At least one thread block per feature group is
+  // required.
+  // The number of thread blocks:
+  // - for num_groups <= num_groups_threshold, around grid_size * num_groups
+  // - for num_groups_threshold <= num_groups <= num_groups_threshold * grid_size,
+  //   around grid_size * num_groups_threshold
+  // - for num_groups_threshold * grid_size <= num_groups, around num_groups
+  int num_groups_threshold = 4;
+  grid_size = common::DivRoundUp(grid_size,
+      common::DivRoundUp(num_groups, num_groups_threshold));
+
+  dh::LaunchKernel {dim3(grid_size, num_groups), block_threads, smem_size} (
+      kernel,
+      matrix, feature_groups, d_ridx, histogram.data(), gpair.data(), rounding,
+      shared);
   dh::safe_cuda(cudaGetLastError());
 }
 
 template void BuildGradientHistogram<GradientPair>(
     EllpackDeviceAccessor const& matrix,
+    FeatureGroupsAccessor const& feature_groups,
     common::Span<GradientPair const> gpair,
     common::Span<const uint32_t> ridx,
     common::Span<GradientPair> histogram,
@@ -192,6 +216,7 @@ template void BuildGradientHistogram<GradientPair>(
 
 template void BuildGradientHistogram<GradientPairPrecise>(
     EllpackDeviceAccessor const& matrix,
+    FeatureGroupsAccessor const& feature_groups,
     common::Span<GradientPair const> gpair,
     common::Span<const uint32_t> ridx,
     common::Span<GradientPairPrecise> histogram,
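[Editor's note] A standalone illustration, not part of the patch, of the thread-block budget computed above. The kernel is launched on a dim3(grid_size, num_groups) grid, so the y-dimension always covers every feature group, while grid_size shrinks as the number of groups grows. The baseline of 80 blocks below is an assumed stand-in for n_blocks_per_mp * n_mps; the local DivRoundUp mirrors common::DivRoundUp.

// grid_size_sketch.cc -- shows how the total block count behaves in the three regimes
// described in the comment above (threshold = 4).
#include <cstdio>
#include <initializer_list>

int DivRoundUp(int a, int b) { return (a + b - 1) / b; }

int main() {
  int base_grid_size = 80;       // assumed n_blocks_per_mp * n_mps
  int num_groups_threshold = 4;
  for (int num_groups : {1, 4, 10, 40, 320, 1000}) {
    int grid_x = DivRoundUp(base_grid_size,
                            DivRoundUp(num_groups, num_groups_threshold));
    std::printf("num_groups=%4d -> grid %3d x %4d = %6d blocks\n",
                num_groups, grid_x, num_groups, grid_x * num_groups);
  }
  return 0;
}

For few groups the total is roughly grid_size * num_groups (80, 320), in the middle regime it stays near grid_size * num_groups_threshold (~320), and once num_groups exceeds num_groups_threshold * grid_size the total is simply num_groups (1000), i.e. one block per group.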
diff --git a/src/tree/gpu_hist/histogram.cuh b/src/tree/gpu_hist/histogram.cuh
index d8673a8a5f7a..6b2675ebca05 100644
--- a/src/tree/gpu_hist/histogram.cuh
+++ b/src/tree/gpu_hist/histogram.cuh
@@ -4,6 +4,9 @@
 #ifndef HISTOGRAM_CUH_
 #define HISTOGRAM_CUH_
 #include <thrust/transform.h>
+
+#include "feature_groups.cuh"
+
 #include "../../data/ellpack_page.cuh"
 
 namespace xgboost {
@@ -19,6 +22,7 @@ DEV_INLINE T TruncateWithRoundingFactor(T const rounding_factor, float const x)
 
 template <typename GradientSumT>
 void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
+                            FeatureGroupsAccessor const& feature_groups,
                             common::Span<GradientPair const> gpair,
                             common::Span<const uint32_t> ridx,
                             common::Span<GradientSumT> histogram,
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 25d2645e1032..5cbe75350402 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -26,6 +26,7 @@
 #include "param.h"
 #include "updater_gpu_common.cuh"
 #include "constraints.cuh"
+#include "gpu_hist/feature_groups.cuh"
 #include "gpu_hist/gradient_based_sampler.cuh"
 #include "gpu_hist/row_partitioner.cuh"
 #include "gpu_hist/histogram.cuh"
@@ -203,6 +204,8 @@ struct GPUHistMakerDevice {
 
   std::unique_ptr<GradientBasedSampler> sampler;
 
+  std::unique_ptr<FeatureGroups> feature_groups;
+
   GPUHistMakerDevice(int _device_id,
                      EllpackPageImpl* _page,
                      bst_uint _n_rows,
@@ -229,6 +232,9 @@
     // Init histogram
     hist.Init(device_id, page->Cuts().TotalBins());
     monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id));
+    feature_groups.reset(new FeatureGroups(
+        page->Cuts(), page->is_dense, dh::MaxSharedMemoryOptin(device_id),
+        sizeof(GradientSumT)));
   }
 
   ~GPUHistMakerDevice() {  // NOLINT
@@ -372,8 +378,9 @@ struct GPUHistMakerDevice {
     hist.AllocateHistogram(nidx);
    auto d_node_hist = hist.GetNodeHistogram(nidx);
     auto d_ridx = row_partitioner->GetRows(nidx);
-    BuildGradientHistogram(page->GetDeviceAccessor(device_id), gpair, d_ridx, d_node_hist,
-                           histogram_rounding);
+    BuildGradientHistogram(page->GetDeviceAccessor(device_id),
+                           feature_groups->DeviceAccessor(device_id), gpair,
+                           d_ridx, d_node_hist, histogram_rounding);
   }
 
   void SubtractionTrick(int nidx_parent, int nidx_histogram,
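[Editor's note] A quick arithmetic check, not part of the patch, of the shared-memory budget GPUHistMakerDevice passes to FeatureGroups above. The 48 KiB figure is an assumption matching the limitation note in feature_groups.cuh and the default max_bin of 256; in the real code shm_size comes from dh::MaxSharedMemoryOptin(device_id) and bin_size from sizeof(GradientSumT).

// shm_budget_sketch.cc -- bins and dense features per group under assumed sizes.
#include <cstddef>
#include <cstdio>

int main() {
  std::size_t shm_size = 48 * 1024;            // bytes per thread block (assumed)
  std::size_t bin_size = 2 * sizeof(double);   // GradientPairPrecise: gradient + hessian
  std::size_t max_shmem_bins = shm_size / bin_size;   // 3072 bins per group
  std::size_t default_bins_per_feature = 256;         // xgboost's default max_bin
  std::printf("bins per group: %zu, dense features per group: %zu\n",
              max_shmem_bins, max_shmem_bins / default_bins_per_feature);
  return 0;
}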
"../../../../src/tree/gpu_hist/row_partitioner.cuh" #include "../../../../src/tree/gpu_hist/histogram.cuh" @@ -7,11 +8,12 @@ namespace xgboost { namespace tree { template -void TestDeterminsticHistogram() { - size_t constexpr kBins = 24, kCols = 8, kRows = 32768, kRounds = 16; +void TestDeterministicHistogram(bool is_dense, int shm_size) { + size_t constexpr kBins = 256, kCols = 120, kRows = 16384, kRounds = 16; float constexpr kLower = -1e-2, kUpper = 1e2; - auto matrix = RandomDataGenerator(kRows, kCols, 0.5).GenerateDMatrix(); + float sparsity = is_dense ? 0.0f : 0.5f; + auto matrix = RandomDataGenerator(kRows, kCols, sparsity).GenerateDMatrix(); BatchParam batch_param{0, static_cast(kBins), 0}; for (auto const& batch : matrix->GetBatches(batch_param)) { @@ -20,48 +22,80 @@ void TestDeterminsticHistogram() { tree::RowPartitioner row_partitioner(0, kRows); auto ridx = row_partitioner.GetRows(0); - dh::device_vector histogram(kBins * kCols); + int num_bins = kBins * kCols; + dh::device_vector histogram(num_bins); auto d_histogram = dh::ToSpan(histogram); auto gpair = GenerateRandomGradients(kRows, kLower, kUpper); gpair.SetDevice(0); + FeatureGroups feature_groups(page->Cuts(), page->is_dense, shm_size, + sizeof(Gradient)); + auto rounding = CreateRoundingFactor(gpair.DeviceSpan()); - BuildGradientHistogram(page->GetDeviceAccessor(0), gpair.DeviceSpan(), ridx, - d_histogram, rounding); + BuildGradientHistogram(page->GetDeviceAccessor(0), + feature_groups.DeviceAccessor(0), gpair.DeviceSpan(), + ridx, d_histogram, rounding); + + std::vector histogram_h(num_bins); + dh::safe_cuda(cudaMemcpy(histogram_h.data(), d_histogram.data(), + num_bins * sizeof(Gradient), + cudaMemcpyDeviceToHost)); for (size_t i = 0; i < kRounds; ++i) { - dh::device_vector new_histogram(kBins * kCols); - auto d_histogram = dh::ToSpan(new_histogram); + dh::device_vector new_histogram(num_bins); + auto d_new_histogram = dh::ToSpan(new_histogram); auto rounding = CreateRoundingFactor(gpair.DeviceSpan()); - BuildGradientHistogram(page->GetDeviceAccessor(0), gpair.DeviceSpan(), ridx, - d_histogram, rounding); - - for (size_t j = 0; j < new_histogram.size(); ++j) { - ASSERT_EQ(((Gradient)new_histogram[j]).GetGrad(), - ((Gradient)histogram[j]).GetGrad()); - ASSERT_EQ(((Gradient)new_histogram[j]).GetHess(), - ((Gradient)histogram[j]).GetHess()); + BuildGradientHistogram(page->GetDeviceAccessor(0), + feature_groups.DeviceAccessor(0), + gpair.DeviceSpan(), ridx, d_new_histogram, + rounding); + + std::vector new_histogram_h(num_bins); + dh::safe_cuda(cudaMemcpy(new_histogram_h.data(), d_new_histogram.data(), + num_bins * sizeof(Gradient), + cudaMemcpyDeviceToHost)); + for (size_t j = 0; j < new_histogram_h.size(); ++j) { + ASSERT_EQ(new_histogram_h[j].GetGrad(), histogram_h[j].GetGrad()); + ASSERT_EQ(new_histogram_h[j].GetHess(), histogram_h[j].GetHess()); } } { auto gpair = GenerateRandomGradients(kRows, kLower, kUpper); gpair.SetDevice(0); - dh::device_vector baseline(kBins * kCols); - BuildGradientHistogram(page->GetDeviceAccessor(0), gpair.DeviceSpan(), ridx, - dh::ToSpan(baseline), rounding); + + // Use a single feature group to compute the baseline. 
+ FeatureGroups single_group(page->Cuts()); + + dh::device_vector baseline(num_bins); + BuildGradientHistogram(page->GetDeviceAccessor(0), + single_group.DeviceAccessor(0), + gpair.DeviceSpan(), ridx, dh::ToSpan(baseline), + rounding); + + std::vector baseline_h(num_bins); + dh::safe_cuda(cudaMemcpy(baseline_h.data(), baseline.data().get(), + num_bins * sizeof(Gradient), + cudaMemcpyDeviceToHost)); + for (size_t i = 0; i < baseline.size(); ++i) { - EXPECT_NEAR(((Gradient)baseline[i]).GetGrad(), ((Gradient)histogram[i]).GetGrad(), - ((Gradient)baseline[i]).GetGrad() * 1e-3); + EXPECT_NEAR(baseline_h[i].GetGrad(), histogram_h[i].GetGrad(), + baseline_h[i].GetGrad() * 1e-3); } } } } -TEST(Histogram, GPUDeterminstic) { - TestDeterminsticHistogram(); - TestDeterminsticHistogram(); +TEST(Histogram, GPUDeterministic) { + std::vector is_dense_array{false, true}; + std::vector shm_sizes{48 * 1024, 64 * 1024, 160 * 1024}; + for (bool is_dense : is_dense_array) { + for (int shm_size : shm_sizes) { + TestDeterministicHistogram(is_dense, shm_size); + TestDeterministicHistogram(is_dense, shm_size); + } + } } } // namespace tree } // namespace xgboost