Commit
Split Features into Groups to Compute Histograms in Shared Memory (#5795)
canonizer authored Jul 7, 2020
1 parent 93c44a9 commit ac3f0e7
Showing 6 changed files with 295 additions and 42 deletions.
64 changes: 64 additions & 0 deletions src/tree/gpu_hist/feature_groups.cu
@@ -0,0 +1,64 @@
/*!
* Copyright 2020 by XGBoost Contributors
*/

#include <xgboost/base.h>
#include <algorithm>
#include <vector>

#include "feature_groups.cuh"

#include "../../common/device_helpers.cuh"
#include "../../common/hist_util.h"

namespace xgboost {
namespace tree {

FeatureGroups::FeatureGroups(const common::HistogramCuts& cuts, bool is_dense,
                             size_t shm_size, size_t bin_size) {
  // Only use a single feature group for sparse matrices.
  bool single_group = !is_dense;
  if (single_group) {
    InitSingle(cuts);
    return;
  }

  std::vector<int>& feature_segments_h = feature_segments.HostVector();
  std::vector<int>& bin_segments_h = bin_segments.HostVector();
  feature_segments_h.push_back(0);
  bin_segments_h.push_back(0);

  const std::vector<uint32_t>& cut_ptrs = cuts.Ptrs();
  int max_shmem_bins = shm_size / bin_size;
  max_group_bins = 0;

  for (size_t i = 2; i < cut_ptrs.size(); ++i) {
    int last_start = bin_segments_h.back();
    if (cut_ptrs[i] - last_start > max_shmem_bins) {
      feature_segments_h.push_back(i - 1);
      bin_segments_h.push_back(cut_ptrs[i - 1]);
      max_group_bins = std::max(max_group_bins,
                                bin_segments_h.back() - last_start);
    }
  }
  feature_segments_h.push_back(cut_ptrs.size() - 1);
  bin_segments_h.push_back(cut_ptrs.back());
  max_group_bins = std::max(max_group_bins,
                            bin_segments_h.back() -
                            bin_segments_h[bin_segments_h.size() - 2]);
}

void FeatureGroups::InitSingle(const common::HistogramCuts& cuts) {
  std::vector<int>& feature_segments_h = feature_segments.HostVector();
  feature_segments_h.push_back(0);
  feature_segments_h.push_back(cuts.Ptrs().size() - 1);

  std::vector<int>& bin_segments_h = bin_segments.HostVector();
  bin_segments_h.push_back(0);
  bin_segments_h.push_back(cuts.TotalBins());

  max_group_bins = cuts.TotalBins();
}

} // namespace tree
} // namespace xgboost
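
Editor's note, not part of the commit: a minimal standalone host-side sketch of the grouping rule implemented by the constructor above, run on hypothetical cuts (4 features with 256 bins each and a 512-bin shared-memory budget). All values and the program itself are illustrative assumptions, intended only to show what feature_segments, bin_segments, and max_group_bins end up looking like.

// Standalone sketch (not part of the commit): the same grouping rule as the
// constructor above, applied to hypothetical cut pointers. cut_ptrs[i] is the
// first bin of feature i, so cut_ptrs has (num_features + 1) entries.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical input: 4 features with 256 bins each, and room for 512 bins
  // in shared memory (e.g. 8 KiB / 16 bytes per GradientPairPrecise bin).
  std::vector<unsigned> cut_ptrs = {0, 256, 512, 768, 1024};
  int max_shmem_bins = 512;

  std::vector<int> feature_segments = {0};
  std::vector<int> bin_segments = {0};
  int max_group_bins = 0;
  for (size_t i = 2; i < cut_ptrs.size(); ++i) {
    int last_start = bin_segments.back();
    if (static_cast<int>(cut_ptrs[i]) - last_start > max_shmem_bins) {
      feature_segments.push_back(static_cast<int>(i) - 1);
      bin_segments.push_back(static_cast<int>(cut_ptrs[i - 1]));
      max_group_bins = std::max(max_group_bins,
                                bin_segments.back() - last_start);
    }
  }
  feature_segments.push_back(static_cast<int>(cut_ptrs.size()) - 1);
  bin_segments.push_back(static_cast<int>(cut_ptrs.back()));
  max_group_bins = std::max(
      max_group_bins, bin_segments.back() - bin_segments[bin_segments.size() - 2]);

  // Expected output: group 0 covers features [0, 2) and bins [0, 512),
  // group 1 covers features [2, 4) and bins [512, 1024); max_group_bins = 512.
  for (size_t g = 0; g + 1 < feature_segments.size(); ++g) {
    std::printf("group %zu: features [%d, %d), bins [%d, %d)\n", g,
                feature_segments[g], feature_segments[g + 1],
                bin_segments[g], bin_segments[g + 1]);
  }
  std::printf("max_group_bins = %d\n", max_group_bins);
  return 0;
}
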
119 changes: 119 additions & 0 deletions src/tree/gpu_hist/feature_groups.cuh
@@ -0,0 +1,119 @@
/*!
* Copyright 2020 by XGBoost Contributors
*/
#ifndef FEATURE_GROUPS_CUH_
#define FEATURE_GROUPS_CUH_

#include <xgboost/host_device_vector.h>
#include <xgboost/span.h>

namespace xgboost {

// Forward declarations.
namespace common {
class HistogramCuts;
} // namespace common

namespace tree {

/** \brief FeatureGroup describes a single feature group: a range of
    consecutive feature indices, together with the range of bin indices
    associated with those features. */
struct FeatureGroup {
  __host__ __device__ FeatureGroup(int start_feature_, int num_features_,
                                   int start_bin_, int num_bins_) :
    start_feature(start_feature_), num_features(num_features_),
    start_bin(start_bin_), num_bins(num_bins_) {}
  /** The first feature of the group. */
  int start_feature;
  /** The number of features in the group. */
  int num_features;
  /** The first bin in the group. */
  int start_bin;
  /** The number of bins in the group. */
  int num_bins;
};

/** \brief FeatureGroupsAccessor is a non-owning accessor for FeatureGroups. */
struct FeatureGroupsAccessor {
  FeatureGroupsAccessor(common::Span<const int> feature_segments_,
                        common::Span<const int> bin_segments_, int max_group_bins_) :
    feature_segments(feature_segments_), bin_segments(bin_segments_),
    max_group_bins(max_group_bins_) {}

  common::Span<const int> feature_segments;
  common::Span<const int> bin_segments;
  int max_group_bins;

  /** \brief Gets the number of feature groups. */
  __host__ __device__ int NumGroups() const {
    return feature_segments.size() - 1;
  }

  /** \brief Gets the information about a feature group with index i. */
  __host__ __device__ FeatureGroup operator[](int i) const {
    return {feature_segments[i], feature_segments[i + 1] - feature_segments[i],
            bin_segments[i], bin_segments[i + 1] - bin_segments[i]};
  }
};

/** \brief FeatureGroups contains information that defines a split of features
into groups. Bins of a single feature group typically fit into shared
memory, so the histogram for the features of a single group can be computed
faster.
\notes Known limitations:
- splitting features into groups currently works only for dense matrices,
where it is easy to get a feature value in a row by its index; for sparse
matrices, the structure contains only a single group containing all
features;
- if a single feature requires more bins than fit into shared memory, the
histogram is computed in global memory even if there are multiple feature
groups; note that this is unlikely to occur in practice, as the default
number of bins per feature is 256, whereas a thread block with 48 KiB
shared memory can contain 3072 bins if each gradient sum component is a
64-bit floating-point value (double)
*/
struct FeatureGroups {
  /** Group cuts for features. Size equals the number of groups + 1. */
  HostDeviceVector<int> feature_segments;
  /** Group cuts for bins. Size equals the number of groups + 1. */
  HostDeviceVector<int> bin_segments;
  /** Maximum number of bins in a group. Useful to compute the amount of dynamic
      shared memory when launching a kernel. */
  int max_group_bins;

  /** Creates feature groups by splitting features into groups.
      \param cuts Histogram cuts that give the number of bins per feature.
      \param is_dense Whether the data matrix is dense.
      \param shm_size Available size of shared memory per thread block (in
      bytes) used to compute feature groups.
      \param bin_size Size of a single bin of the histogram. */
  FeatureGroups(const common::HistogramCuts& cuts, bool is_dense,
                size_t shm_size, size_t bin_size);

  /** Creates a single feature group containing all features and bins.
      \notes This is used as a fallback for sparse matrices, and is also useful
      for testing.
  */
  explicit FeatureGroups(const common::HistogramCuts& cuts) {
    InitSingle(cuts);
  }

  FeatureGroupsAccessor DeviceAccessor(int device) const {
    feature_segments.SetDevice(device);
    bin_segments.SetDevice(device);
    return {feature_segments.ConstDeviceSpan(), bin_segments.ConstDeviceSpan(),
            max_group_bins};
  }

 private:
  void InitSingle(const common::HistogramCuts& cuts);
};

} // namespace tree
} // namespace xgboost

#endif // FEATURE_GROUPS_CUH_
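
Editor's note, not part of the commit: a standalone sketch of the shared-memory arithmetic cited in the limitations note above (a 48 KiB per-block budget and 16-byte bins for GradientPairPrecise). The concrete numbers are assumptions for illustration; actual limits depend on the GPU and the opt-in shared-memory configuration.

// Standalone sketch (not part of the commit): with an assumed 48 KiB of
// shared memory per block and 16-byte bins (two doubles per gradient sum),
// a block holds 3072 bins, i.e. 12 features at the default 256 bins each.
#include <cstddef>
#include <cstdio>

int main() {
  std::size_t shm_size = 48 * 1024;           // assumed per-block shared memory, bytes
  std::size_t bin_size = 2 * sizeof(double);  // one gradient sum + one hessian sum
  std::size_t max_shmem_bins = shm_size / bin_size;  // 3072
  std::size_t default_bins_per_feature = 256;
  std::printf("bins per block: %zu, features per group: %zu\n",
              max_shmem_bins, max_shmem_bins / default_bins_per_feature);
  return 0;
}
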
57 changes: 41 additions & 16 deletions src/tree/gpu_hist/histogram.cu
@@ -102,23 +102,26 @@ template GradientPair CreateRoundingFactor(common::Span<GradientPair const> gpai

template <typename GradientSumT>
__global__ void SharedMemHistKernel(EllpackDeviceAccessor matrix,
FeatureGroupsAccessor feature_groups,
common::Span<const RowPartitioner::RowIndexT> d_ridx,
GradientSumT* __restrict__ d_node_hist,
const GradientPair* __restrict__ d_gpair,
size_t n_elements,
GradientSumT const rounding,
bool use_shared_memory_histograms) {
using T = typename GradientSumT::ValueT;
extern __shared__ char smem[];
FeatureGroup group = feature_groups[blockIdx.y];
GradientSumT* smem_arr = reinterpret_cast<GradientSumT*>(smem); // NOLINT
if (use_shared_memory_histograms) {
dh::BlockFill(smem_arr, matrix.NumBins(), GradientSumT());
dh::BlockFill(smem_arr, group.num_bins, GradientSumT());
__syncthreads();
}
int feature_stride = matrix.is_dense ? group.num_features : matrix.row_stride;
size_t n_elements = feature_stride * d_ridx.size();
for (auto idx : dh::GridStrideRange(static_cast<size_t>(0), n_elements)) {
int ridx = d_ridx[idx / matrix.row_stride];
int gidx =
matrix.gidx_iter[ridx * matrix.row_stride + idx % matrix.row_stride];
int ridx = d_ridx[idx / feature_stride];
int gidx = matrix.gidx_iter[ridx * matrix.row_stride + group.start_feature +
idx % feature_stride];
if (gidx != matrix.NumBins()) {
GradientSumT truncated {
TruncateWithRoundingFactor<T>(rounding.GetGrad(), d_gpair[ridx].GetGrad()),
@@ -127,26 +130,30 @@ __global__ void SharedMemHistKernel(EllpackDeviceAccessor matrix,
// If we are not using shared memory, accumulate the values directly into
// global memory
GradientSumT* atomic_add_ptr =
use_shared_memory_histograms ? smem_arr : d_node_hist;
use_shared_memory_histograms ? smem_arr : d_node_hist;
gidx = use_shared_memory_histograms ? gidx - group.start_bin : gidx;
dh::AtomicAddGpair(atomic_add_ptr + gidx, truncated);
}
}

if (use_shared_memory_histograms) {
// Write shared memory back to global memory
__syncthreads();
for (auto i : dh::BlockStrideRange(static_cast<size_t>(0), matrix.NumBins())) {
GradientSumT truncated {
TruncateWithRoundingFactor<T>(rounding.GetGrad(), smem_arr[i].GetGrad()),
TruncateWithRoundingFactor<T>(rounding.GetHess(), smem_arr[i].GetHess()),
for (auto i : dh::BlockStrideRange(0, group.num_bins)) {
GradientSumT truncated{
TruncateWithRoundingFactor<T>(rounding.GetGrad(),
smem_arr[i].GetGrad()),
TruncateWithRoundingFactor<T>(rounding.GetHess(),
smem_arr[i].GetHess()),
};
dh::AtomicAddGpair(d_node_hist + i, truncated);
dh::AtomicAddGpair(d_node_hist + group.start_bin + i, truncated);
}
}
}

template <typename GradientSumT>
void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
FeatureGroupsAccessor const& feature_groups,
common::Span<GradientPair const> gpair,
common::Span<const uint32_t> d_ridx,
common::Span<GradientSumT> histogram,
@@ -155,7 +162,7 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
int device = 0;
dh::safe_cuda(cudaGetDevice(&device));
int max_shared_memory = dh::MaxSharedMemoryOptin(device);
size_t smem_size = sizeof(GradientSumT) * matrix.NumBins();
size_t smem_size = sizeof(GradientSumT) * feature_groups.max_group_bins;
bool shared = smem_size <= max_shared_memory;
smem_size = shared ? smem_size : 0;


// determine the launch configuration
unsigned block_threads = shared ? 1024 : 256;
int num_groups = feature_groups.NumGroups();
int n_mps = 0;
dh::safe_cuda(cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device));
int n_blocks_per_mp = 0;
dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor
(&n_blocks_per_mp, kernel, block_threads, smem_size));
unsigned grid_size = n_blocks_per_mp * n_mps;

auto n_elements = d_ridx.size() * matrix.row_stride;
dh::LaunchKernel {grid_size, block_threads, smem_size} (
kernel, matrix, d_ridx, histogram.data(), gpair.data(), n_elements,
rounding, shared);
// TODO(canonizer): This is really a hack, find a better way to distribute the
// data among thread blocks.
// The intention is to generate enough thread blocks to fill the GPU, but
// avoid having too many thread blocks, as this is less efficient when the
// number of rows is low. At least one thread block per feature group is
// required.
// The number of thread blocks:
// - for num_groups <= num_groups_threshold, around grid_size * num_groups
// - for num_groups_threshold <= num_groups <= num_groups_threshold * grid_size,
// around grid_size * num_groups_threshold
// - for num_groups_threshold * grid_size <= num_groups, around num_groups
int num_groups_threshold = 4;
grid_size = common::DivRoundUp(grid_size,
common::DivRoundUp(num_groups, num_groups_threshold));

dh::LaunchKernel {dim3(grid_size, num_groups), block_threads, smem_size} (
kernel,
matrix, feature_groups, d_ridx, histogram.data(), gpair.data(), rounding,
shared);
dh::safe_cuda(cudaGetLastError());
}
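
Editor's note, not part of the commit: a standalone sketch of the launch-grid heuristic in BuildGradientHistogram above, evaluated for a hypothetical GPU whose occupancy calculation yields 160 resident thread blocks (n_blocks_per_mp * n_mps). The figures are assumptions for illustration; the total number of blocks launched is grid dim.x times num_groups, which reproduces the three regimes described in the TODO comment.

// Standalone sketch (not part of the commit): how grid_size shrinks as the
// number of feature groups grows, for an assumed 160 resident thread blocks.
#include <cstdio>

int DivRoundUp(int a, int b) { return (a + b - 1) / b; }

int main() {
  int resident_blocks = 160;       // assumed n_blocks_per_mp * n_mps
  int num_groups_threshold = 4;
  int group_counts[] = {1, 4, 32, 1024};
  for (int num_groups : group_counts) {
    int grid_size = DivRoundUp(resident_blocks,
                               DivRoundUp(num_groups, num_groups_threshold));
    // num_groups=1    -> 160 blocks (~grid_size * num_groups)
    // num_groups=4,32 -> 640 blocks (~grid_size * num_groups_threshold)
    // num_groups=1024 -> 1024 blocks (~num_groups)
    std::printf("num_groups=%4d -> grid dim.x=%3d, total blocks=%d\n",
                num_groups, grid_size, grid_size * num_groups);
  }
  return 0;
}
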

template void BuildGradientHistogram<GradientPair>(
EllpackDeviceAccessor const& matrix,
FeatureGroupsAccessor const& feature_groups,
common::Span<GradientPair const> gpair,
common::Span<const uint32_t> ridx,
common::Span<GradientPair> histogram,
GradientPair rounding);

template void BuildGradientHistogram<GradientPairPrecise>(
EllpackDeviceAccessor const& matrix,
FeatureGroupsAccessor const& feature_groups,
common::Span<GradientPair const> gpair,
common::Span<const uint32_t> ridx,
common::Span<GradientPairPrecise> histogram,
4 changes: 4 additions & 0 deletions src/tree/gpu_hist/histogram.cuh
@@ -4,6 +4,9 @@
#ifndef HISTOGRAM_CUH_
#define HISTOGRAM_CUH_
#include <thrust/transform.h>

#include "feature_groups.cuh"

#include "../../data/ellpack_page.cuh"

namespace xgboost {
@@ -19,6 +22,7 @@ DEV_INLINE T TruncateWithRoundingFactor(T const rounding_factor, float const x)

template <typename GradientSumT>
void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
FeatureGroupsAccessor const& feature_groups,
common::Span<GradientPair const> gpair,
common::Span<const uint32_t> ridx,
common::Span<GradientSumT> histogram,
11 changes: 9 additions & 2 deletions src/tree/updater_gpu_hist.cu
@@ -26,6 +26,7 @@
#include "param.h"
#include "updater_gpu_common.cuh"
#include "constraints.cuh"
#include "gpu_hist/feature_groups.cuh"
#include "gpu_hist/gradient_based_sampler.cuh"
#include "gpu_hist/row_partitioner.cuh"
#include "gpu_hist/histogram.cuh"
@@ -203,6 +204,8 @@ struct GPUHistMakerDevice {

std::unique_ptr<GradientBasedSampler> sampler;

std::unique_ptr<FeatureGroups> feature_groups;

GPUHistMakerDevice(int _device_id,
EllpackPageImpl* _page,
bst_uint _n_rows,
@@ -229,6 +232,9 @@
// Init histogram
hist.Init(device_id, page->Cuts().TotalBins());
monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id));
feature_groups.reset(new FeatureGroups(
page->Cuts(), page->is_dense, dh::MaxSharedMemoryOptin(device_id),
sizeof(GradientSumT)));
}

~GPUHistMakerDevice() { // NOLINT
Expand Down Expand Up @@ -372,8 +378,9 @@ struct GPUHistMakerDevice {
hist.AllocateHistogram(nidx);
auto d_node_hist = hist.GetNodeHistogram(nidx);
auto d_ridx = row_partitioner->GetRows(nidx);
BuildGradientHistogram(page->GetDeviceAccessor(device_id), gpair, d_ridx, d_node_hist,
histogram_rounding);
BuildGradientHistogram(page->GetDeviceAccessor(device_id),
feature_groups->DeviceAccessor(device_id), gpair,
d_ridx, d_node_hist, histogram_rounding);
}

void SubtractionTrick(int nidx_parent, int nidx_histogram,

