Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split Features into Groups to Compute Histograms in Shared Memory #5795

Merged
merged 25 commits into from
Jul 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
547c4ce
Splitting features into groups to compute histograms in shared memory.
canonizer Jun 6, 2020
0413e50
Deriving max_group_bins from the available shared memory.
canonizer Jun 9, 2020
399ab04
Compute histograms using shared memory for all feature groups in one …
canonizer Jun 10, 2020
29abc5d
Tuning the number of blocks for histograms.
canonizer Jun 10, 2020
e3edd7f
Minor changes.
canonizer Jun 15, 2020
9835250
Fixed tests, removed commented code.
canonizer Jun 15, 2020
8e9ed73
Fixed style errors.
canonizer Jun 16, 2020
d2f7235
Fixed an error.
canonizer Jun 17, 2020
f9b2be9
Merge branch 'master' into split-histo
canonizer Jun 17, 2020
b544a1b
Fix launch kernel.
trivialfis Jun 18, 2020
72faa6a
Linter.
trivialfis Jun 18, 2020
8cda8c6
Merge branch 'master' into split-histo
canonizer Jun 18, 2020
3cc50b0
Refactoring.
canonizer Jul 1, 2020
9c6525c
Tests for feature groups.
canonizer Jul 1, 2020
0841d94
Documentation comments.
canonizer Jul 1, 2020
72a1c30
Fixed formatting errors.
canonizer Jul 1, 2020
85f92f9
Fixed formatting errors.
canonizer Jul 1, 2020
15a7b83
Fixed a bug.
canonizer Jul 1, 2020
7c35750
Fixed formatting errors.
canonizer Jul 1, 2020
ce3bc85
Merge branch 'master' into split-histo
canonizer Jul 1, 2020
263857f
Tests for feature groups.
canonizer Jul 2, 2020
6b6df42
Fixed formatting errors.
canonizer Jul 2, 2020
d64127c
Addressed review comments.
canonizer Jul 2, 2020
3e26676
Fixed style errors.
canonizer Jul 2, 2020
64fcbfd
Fixed an error.
canonizer Jul 2, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions src/tree/gpu_hist/feature_groups.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*!
* Copyright 2020 by XGBoost Contributors
*/

#include <xgboost/base.h>
#include <algorithm>
#include <vector>

#include "feature_groups.cuh"

#include "../../common/device_helpers.cuh"
#include "../../common/hist_util.h"

namespace xgboost {
namespace tree {

FeatureGroups::FeatureGroups(const common::HistogramCuts& cuts, bool is_dense,
size_t shm_size, size_t bin_size) {
// Only use a single feature group for sparse matrices.
bool single_group = !is_dense;
if (single_group) {
InitSingle(cuts);
return;
}

std::vector<int>& feature_segments_h = feature_segments.HostVector();
std::vector<int>& bin_segments_h = bin_segments.HostVector();
feature_segments_h.push_back(0);
bin_segments_h.push_back(0);

const std::vector<uint32_t>& cut_ptrs = cuts.Ptrs();
int max_shmem_bins = shm_size / bin_size;
max_group_bins = 0;

for (size_t i = 2; i < cut_ptrs.size(); ++i) {
int last_start = bin_segments_h.back();
if (cut_ptrs[i] - last_start > max_shmem_bins) {
feature_segments_h.push_back(i - 1);
bin_segments_h.push_back(cut_ptrs[i - 1]);
max_group_bins = std::max(max_group_bins,
bin_segments_h.back() - last_start);
}
}
feature_segments_h.push_back(cut_ptrs.size() - 1);
bin_segments_h.push_back(cut_ptrs.back());
max_group_bins = std::max(max_group_bins,
bin_segments_h.back() -
bin_segments_h[bin_segments_h.size() - 2]);
}

void FeatureGroups::InitSingle(const common::HistogramCuts& cuts) {
std::vector<int>& feature_segments_h = feature_segments.HostVector();
feature_segments_h.push_back(0);
feature_segments_h.push_back(cuts.Ptrs().size() - 1);

std::vector<int>& bin_segments_h = bin_segments.HostVector();
bin_segments_h.push_back(0);
bin_segments_h.push_back(cuts.TotalBins());

max_group_bins = cuts.TotalBins();
}

} // namespace tree
} // namespace xgboost
119 changes: 119 additions & 0 deletions src/tree/gpu_hist/feature_groups.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*!
* Copyright 2020 by XGBoost Contributors
*/
#ifndef FEATURE_GROUPS_CUH_
#define FEATURE_GROUPS_CUH_

#include <xgboost/host_device_vector.h>
#include <xgboost/span.h>

namespace xgboost {

// Forward declarations.
namespace common {
class HistogramCuts;
} // namespace common

namespace tree {

/** \brief FeatureGroup is a feature group. It is defined by a range of
consecutive feature indices, and also contains a range of all bin indices
associated with those features. */
struct FeatureGroup {
__host__ __device__ FeatureGroup(int start_feature_, int num_features_,
int start_bin_, int num_bins_) :
start_feature(start_feature_), num_features(num_features_),
start_bin(start_bin_), num_bins(num_bins_) {}
/** The first feature of the group. */
int start_feature;
/** The number of features in the group. */
int num_features;
/** The first bin in the group. */
int start_bin;
/** The number of bins in the group. */
int num_bins;
};

/** \brief FeatureGroupsAccessor is a non-owning accessor for FeatureGroups. */
struct FeatureGroupsAccessor {
FeatureGroupsAccessor(common::Span<const int> feature_segments_,
common::Span<const int> bin_segments_, int max_group_bins_) :
feature_segments(feature_segments_), bin_segments(bin_segments_),
max_group_bins(max_group_bins_) {}

common::Span<const int> feature_segments;
common::Span<const int> bin_segments;
int max_group_bins;

/** \brief Gets the number of feature groups. */
__host__ __device__ int NumGroups() const {
return feature_segments.size() - 1;
}

/** \brief Gets the information about a feature group with index i. */
__host__ __device__ FeatureGroup operator[](int i) const {
return {feature_segments[i], feature_segments[i + 1] - feature_segments[i],
bin_segments[i], bin_segments[i + 1] - bin_segments[i]};
}
};

/** \brief FeatureGroups contains information that defines a split of features
into groups. Bins of a single feature group typically fit into shared
memory, so the histogram for the features of a single group can be computed
faster.

\notes Known limitations:

- splitting features into groups currently works only for dense matrices,
where it is easy to get a feature value in a row by its index; for sparse
matrices, the structure contains only a single group containing all
features;

- if a single feature requires more bins than fit into shared memory, the
histogram is computed in global memory even if there are multiple feature
groups; note that this is unlikely to occur in practice, as the default
number of bins per feature is 256, whereas a thread block with 48 KiB
shared memory can contain 3072 bins if each gradient sum component is a
64-bit floating-point value (double)
*/
struct FeatureGroups {
/** Group cuts for features. Size equals to (number of groups + 1). */
HostDeviceVector<int> feature_segments;
/** Group cuts for bins. Size equals to (number of groups + 1) */
HostDeviceVector<int> bin_segments;
/** Maximum number of bins in a group. Useful to compute the amount of dynamic
shared memory when launching a kernel. */
int max_group_bins;

/** Creates feature groups by splitting features into groups.
\param cuts Histogram cuts that given the number of bins per feature.
\param is_dense Whether the data matrix is dense.
\param shm_size Available size of shared memory per thread block (in
bytes) used to compute feature groups.
\param bin_size Size of a single bin of the histogram. */
FeatureGroups(const common::HistogramCuts& cuts, bool is_dense,
size_t shm_size, size_t bin_size);

/** Creates a single feature group containing all features and bins.
\notes This is used as a fallback for sparse matrices, and is also useful
for testing.
*/
explicit FeatureGroups(const common::HistogramCuts& cuts) {
InitSingle(cuts);
}

FeatureGroupsAccessor DeviceAccessor(int device) const {
feature_segments.SetDevice(device);
bin_segments.SetDevice(device);
return {feature_segments.ConstDeviceSpan(), bin_segments.ConstDeviceSpan(),
max_group_bins};
}

private:
void InitSingle(const common::HistogramCuts& cuts);
};

} // namespace tree
} // namespace xgboost

#endif // FEATURE_GROUPS_CUH_
57 changes: 41 additions & 16 deletions src/tree/gpu_hist/histogram.cu
Original file line number Diff line number Diff line change
Expand Up @@ -102,23 +102,26 @@ template GradientPair CreateRoundingFactor(common::Span<GradientPair const> gpai

template <typename GradientSumT>
__global__ void SharedMemHistKernel(EllpackDeviceAccessor matrix,
FeatureGroupsAccessor feature_groups,
common::Span<const RowPartitioner::RowIndexT> d_ridx,
GradientSumT* __restrict__ d_node_hist,
const GradientPair* __restrict__ d_gpair,
size_t n_elements,
GradientSumT const rounding,
bool use_shared_memory_histograms) {
using T = typename GradientSumT::ValueT;
extern __shared__ char smem[];
FeatureGroup group = feature_groups[blockIdx.y];
GradientSumT* smem_arr = reinterpret_cast<GradientSumT*>(smem); // NOLINT
if (use_shared_memory_histograms) {
dh::BlockFill(smem_arr, matrix.NumBins(), GradientSumT());
dh::BlockFill(smem_arr, group.num_bins, GradientSumT());
__syncthreads();
}
int feature_stride = matrix.is_dense ? group.num_features : matrix.row_stride;
size_t n_elements = feature_stride * d_ridx.size();
for (auto idx : dh::GridStrideRange(static_cast<size_t>(0), n_elements)) {
int ridx = d_ridx[idx / matrix.row_stride];
int gidx =
matrix.gidx_iter[ridx * matrix.row_stride + idx % matrix.row_stride];
int ridx = d_ridx[idx / feature_stride];
int gidx = matrix.gidx_iter[ridx * matrix.row_stride + group.start_feature +
idx % feature_stride];
if (gidx != matrix.NumBins()) {
GradientSumT truncated {
TruncateWithRoundingFactor<T>(rounding.GetGrad(), d_gpair[ridx].GetGrad()),
Expand All @@ -127,26 +130,30 @@ __global__ void SharedMemHistKernel(EllpackDeviceAccessor matrix,
// If we are not using shared memory, accumulate the values directly into
// global memory
GradientSumT* atomic_add_ptr =
use_shared_memory_histograms ? smem_arr : d_node_hist;
use_shared_memory_histograms ? smem_arr : d_node_hist;
gidx = use_shared_memory_histograms ? gidx - group.start_bin : gidx;
dh::AtomicAddGpair(atomic_add_ptr + gidx, truncated);
}
}

if (use_shared_memory_histograms) {
// Write shared memory back to global memory
__syncthreads();
for (auto i : dh::BlockStrideRange(static_cast<size_t>(0), matrix.NumBins())) {
GradientSumT truncated {
TruncateWithRoundingFactor<T>(rounding.GetGrad(), smem_arr[i].GetGrad()),
TruncateWithRoundingFactor<T>(rounding.GetHess(), smem_arr[i].GetHess()),
for (auto i : dh::BlockStrideRange(0, group.num_bins)) {
GradientSumT truncated{
TruncateWithRoundingFactor<T>(rounding.GetGrad(),
smem_arr[i].GetGrad()),
TruncateWithRoundingFactor<T>(rounding.GetHess(),
smem_arr[i].GetHess()),
};
dh::AtomicAddGpair(d_node_hist + i, truncated);
dh::AtomicAddGpair(d_node_hist + group.start_bin + i, truncated);
}
}
}

template <typename GradientSumT>
void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
FeatureGroupsAccessor const& feature_groups,
common::Span<GradientPair const> gpair,
common::Span<const uint32_t> d_ridx,
common::Span<GradientSumT> histogram,
Expand All @@ -155,7 +162,7 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
int device = 0;
dh::safe_cuda(cudaGetDevice(&device));
int max_shared_memory = dh::MaxSharedMemoryOptin(device);
size_t smem_size = sizeof(GradientSumT) * matrix.NumBins();
size_t smem_size = sizeof(GradientSumT) * feature_groups.max_group_bins;
bool shared = smem_size <= max_shared_memory;
smem_size = shared ? smem_size : 0;

Expand All @@ -169,29 +176,47 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,

// determine the launch configuration
unsigned block_threads = shared ? 1024 : 256;
int num_groups = feature_groups.NumGroups();
int n_mps = 0;
dh::safe_cuda(cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device));
int n_blocks_per_mp = 0;
dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor
(&n_blocks_per_mp, kernel, block_threads, smem_size));
unsigned grid_size = n_blocks_per_mp * n_mps;

auto n_elements = d_ridx.size() * matrix.row_stride;
dh::LaunchKernel {grid_size, block_threads, smem_size} (
kernel, matrix, d_ridx, histogram.data(), gpair.data(), n_elements,
rounding, shared);
// TODO(canonizer): This is really a hack, find a better way to distribute the
// data among thread blocks.
// The intention is to generate enough thread blocks to fill the GPU, but
// avoid having too many thread blocks, as this is less efficient when the
// number of rows is low. At least one thread block per feature group is
// required.
// The number of thread blocks:
// - for num_groups <= num_groups_threshold, around grid_size * num_groups
// - for num_groups_threshold <= num_groups <= num_groups_threshold * grid_size,
// around grid_size * num_groups_threshold
// - for num_groups_threshold * grid_size <= num_groups, around num_groups
int num_groups_threshold = 4;
grid_size = common::DivRoundUp(grid_size,
common::DivRoundUp(num_groups, num_groups_threshold));

dh::LaunchKernel {dim3(grid_size, num_groups), block_threads, smem_size} (
kernel,
matrix, feature_groups, d_ridx, histogram.data(), gpair.data(), rounding,
shared);
dh::safe_cuda(cudaGetLastError());
}

template void BuildGradientHistogram<GradientPair>(
EllpackDeviceAccessor const& matrix,
FeatureGroupsAccessor const& feature_groups,
common::Span<GradientPair const> gpair,
common::Span<const uint32_t> ridx,
common::Span<GradientPair> histogram,
GradientPair rounding);

template void BuildGradientHistogram<GradientPairPrecise>(
EllpackDeviceAccessor const& matrix,
FeatureGroupsAccessor const& feature_groups,
common::Span<GradientPair const> gpair,
common::Span<const uint32_t> ridx,
common::Span<GradientPairPrecise> histogram,
Expand Down
4 changes: 4 additions & 0 deletions src/tree/gpu_hist/histogram.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
#ifndef HISTOGRAM_CUH_
#define HISTOGRAM_CUH_
#include <thrust/transform.h>

#include "feature_groups.cuh"

#include "../../data/ellpack_page.cuh"

namespace xgboost {
Expand All @@ -19,6 +22,7 @@ DEV_INLINE T TruncateWithRoundingFactor(T const rounding_factor, float const x)

template <typename GradientSumT>
void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
FeatureGroupsAccessor const& feature_groups,
common::Span<GradientPair const> gpair,
common::Span<const uint32_t> ridx,
common::Span<GradientSumT> histogram,
Expand Down
11 changes: 9 additions & 2 deletions src/tree/updater_gpu_hist.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "param.h"
#include "updater_gpu_common.cuh"
#include "constraints.cuh"
#include "gpu_hist/feature_groups.cuh"
#include "gpu_hist/gradient_based_sampler.cuh"
#include "gpu_hist/row_partitioner.cuh"
#include "gpu_hist/histogram.cuh"
Expand Down Expand Up @@ -203,6 +204,8 @@ struct GPUHistMakerDevice {

std::unique_ptr<GradientBasedSampler> sampler;

std::unique_ptr<FeatureGroups> feature_groups;

GPUHistMakerDevice(int _device_id,
EllpackPageImpl* _page,
bst_uint _n_rows,
Expand All @@ -229,6 +232,9 @@ struct GPUHistMakerDevice {
// Init histogram
hist.Init(device_id, page->Cuts().TotalBins());
monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id));
feature_groups.reset(new FeatureGroups(
page->Cuts(), page->is_dense, dh::MaxSharedMemoryOptin(device_id),
sizeof(GradientSumT)));
}

~GPUHistMakerDevice() { // NOLINT
Expand Down Expand Up @@ -372,8 +378,9 @@ struct GPUHistMakerDevice {
hist.AllocateHistogram(nidx);
auto d_node_hist = hist.GetNodeHistogram(nidx);
auto d_ridx = row_partitioner->GetRows(nidx);
BuildGradientHistogram(page->GetDeviceAccessor(device_id), gpair, d_ridx, d_node_hist,
histogram_rounding);
BuildGradientHistogram(page->GetDeviceAccessor(device_id),
feature_groups->DeviceAccessor(device_id), gpair,
d_ridx, d_node_hist, histogram_rounding);
}

void SubtractionTrick(int nidx_parent, int nidx_histogram,
Expand Down
Loading