Optimisations for wide datasets and deep trees #165

Open · wants to merge 3 commits into base: main
35 changes: 22 additions & 13 deletions src/models/tree/build_tree.cu
@@ -288,8 +288,9 @@ __global__ static void __launch_bounds__(BLOCK_THREADS)
int scan_node_idx = batch.node_idx_begin + j;
int parent = BinaryTree::Parent(scan_node_idx);
// Exit if we didn't compute this histogram
if (!ComputeHistogramBin(scan_node_idx, node_sums, histogram.ContainsNode(parent))) return;
if (i >= n_features || scan_node_idx >= batch.node_idx_end) return;
if (node_sums[{scan_node_idx, output}].hess <= 0.0) return;
if (!ComputeHistogramBin(scan_node_idx, node_sums, histogram.ContainsNode(parent))) return;

const int feature_idx = i;
auto [feature_begin, feature_end] = split_proposals.FeatureRange(feature_idx);
@@ -373,6 +374,8 @@ __global__ static void __launch_bounds__(BLOCK_THREADS)
{
// using one block per (level) node to have blockwise reductions
int node_id = batch.node_idx_begin + blockIdx.x;
// Early exit if this node has no samples
if (vectorised_load(&node_sums[{node_id, 0}]).hess <= 0) return;

typedef cub::BlockReduce<GainFeaturePair, BLOCK_THREADS> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
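The added guard skips whole nodes that received no rows: with one block per node, every thread returns as soon as the node's hessian sum is seen to be zero, which matters for deep trees where many nodes in a level are empty. A minimal standalone sketch of the pattern, with hypothetical names (`find_best_split_sketch`, `node_hess`) that are not part of this codebase:

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical sketch, not this project's kernel: one block per node, mirroring
// the launch scheme above. If a node's accumulated hessian is zero, the node is
// empty, so the whole block returns before any block-wide reduction work starts.
__global__ void find_best_split_sketch(const double* node_hess, int node_begin)
{
  const int node_id = node_begin + blockIdx.x;
  if (node_hess[node_id] <= 0.0) return;  // empty node: one load, no reduction
  // ... the block-wide gain reduction would follow here ...
  if (threadIdx.x == 0) printf("node %d has samples\n", node_id);
}

int main()
{
  const double h_hess[4] = {0.0, 2.5, 0.0, 0.0};  // only node 1 holds any rows
  double* d_hess = nullptr;
  cudaMalloc(&d_hess, sizeof(h_hess));
  cudaMemcpy(d_hess, h_hess, sizeof(h_hess), cudaMemcpyHostToDevice);
  find_best_split_sketch<<<4, 128>>>(d_hess, 0);  // one block per candidate node
  cudaDeviceSynchronize();
  cudaFree(d_hess);
  return 0;
}
```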
@@ -655,8 +658,11 @@ struct TreeBuilder {
max_batch_size = max_histogram_nodes;
}

template <typename TYPE>
void UpdatePositions(Tree& tree, legate::AccessorRO<TYPE, 3> X, legate::Rect<3> X_shape)
template <typename TYPE, typename ThrustPolicyT>
void UpdatePositions(Tree& tree,
legate::AccessorRO<TYPE, 3> X,
legate::Rect<3> X_shape,
ThrustPolicyT& policy)
{
auto tree_split_value_ptr = tree.split_value.ptr(0);
auto tree_feature_ptr = tree.feature.ptr(0);
@@ -676,6 +682,12 @@ struct TreeBuilder {
sorted_positions[idx] = cuda::std::make_tuple(pos, row);
});
CHECK_CUDA_STREAM(stream);

thrust::sort(
policy,
sorted_positions.ptr(0),
sorted_positions.ptr(num_rows),
[] __device__(auto a, auto b) { return cuda::std::get<0>(a) < cuda::std::get<0>(b); });
}

template <typename TYPE>
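With the extra `ThrustPolicyT` parameter, `UpdatePositions` now finishes by sorting `sorted_positions` itself (the sort previously sat in the batch-construction path further down), so rows leave this method already grouped by their node id. A rough standalone illustration of the same grouping step using `thrust::sort_by_key`; the real code sorts `(position, row)` tuples with a device-lambda comparator, and the vector names below are made up:

```cuda
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/sort.h>
#include <vector>

int main()
{
  // Toy stand-ins for the per-row tree positions and row indices.
  std::vector<int> h_nodes{3, 1, 3, 2, 1, 2};  // keys: node each row landed in
  std::vector<int> h_rows{0, 1, 2, 3, 4, 5};   // values: the rows themselves
  thrust::device_vector<int> node_of_row(h_nodes.begin(), h_nodes.end());
  thrust::device_vector<int> row_index(h_rows.begin(), h_rows.end());

  // Group rows by node so that each node's rows form one contiguous range,
  // which is what lets the later batch construction use binary search.
  thrust::sort_by_key(thrust::device,
                      node_of_row.begin(), node_of_row.end(),
                      row_index.begin());
  // node_of_row is now {1,1,2,2,3,3}; row_index lists node 1's rows first,
  // then node 2's, then node 3's.
  return 0;
}
```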
@@ -696,7 +708,7 @@ struct TreeBuilder {
// splitting the features to ensure better work distribution for large numbers of features
// while larger values also allow better caching of g & h,
// smaller values improve access to the split_proposals
const int features_per_warp = 64;
const int features_per_warp = 32;
const size_t blocks_y = (num_features + features_per_warp - 1) / features_per_warp;
dim3 grid_shape = dim3(blocks_x, blocks_y, 1);
fill_histogram_warp<TYPE, threads_per_block, features_per_warp>
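Halving `features_per_warp` from 64 to 32 roughly doubles the number of blocks along the grid's y dimension for the same feature count, spreading the histogram work over more blocks on wide datasets and, per the comment above, favouring `split_proposals` access over g & h caching. A quick check of the launch arithmetic (the feature count below is just an example):

```cuda
#include <cstdio>

int main()
{
  const size_t num_features = 10000;  // example: a wide dataset
  for (int features_per_warp : {64, 32}) {
    const size_t blocks_y = (num_features + features_per_warp - 1) / features_per_warp;
    std::printf("features_per_warp=%d -> blocks_y=%zu\n", features_per_warp, blocks_y);
  }
  // Prints 157 blocks in y for 64 features per warp and 313 for 32, i.e. roughly
  // twice as many independent blocks to spread across the SMs.
  return 0;
}
```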
@@ -738,7 +750,7 @@ struct TreeBuilder {
double alpha,
NodeBatch batch)
{
const int kBlockThreads = 256;
const int kBlockThreads = 512;
perform_best_split<T, kBlockThreads>
<<<batch.NodesInBatch(), kBlockThreads, 0, stream>>>(histogram,
num_features,
@@ -813,12 +825,6 @@ struct TreeBuilder {
sorted_positions.ptr(0) + num_rows}};
}

thrust::sort(
policy,
sorted_positions.ptr(0),
sorted_positions.ptr(num_rows),
[] __device__(auto a, auto b) { return cuda::std::get<0>(a) < cuda::std::get<0>(b); });

// Launch a kernel where each thread computes the range of instances for a batch using binary
// search
const int num_batches = (BinaryTree::NodesInLevel(depth) + max_batch_size - 1) / max_batch_size;
@@ -846,7 +852,8 @@ struct TreeBuilder {
sorted_positions_ptr + num_rows,
cuda::std::tuple(batch_end - 1, 0),
comp);
batches_ptr[batch_idx] = {batch_begin, batch_end, lower, upper};
batches_ptr[batch_idx] = {
cuda::std::get<0>(*lower), cuda::std::get<0>(*(upper - 1)) + 1, lower, upper};
});

std::vector<NodeBatch> result(num_batches);
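Each `NodeBatch` entry now records the node-id range actually present in the sorted positions, taken from the first and last elements found by the binary search, rather than the nominal `batch_begin`/`batch_end`, so batches whose leading or trailing nodes are empty cover a tighter range. A host-side sketch of the same lower/upper-bound pattern, using std algorithms and made-up data instead of the device-side thrust calls:

```cuda
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

int main()
{
  // (node, row) pairs already sorted by node, as UpdatePositions leaves them.
  std::vector<std::pair<int, int>> sorted_positions = {
    {2, 7}, {2, 9}, {5, 1}, {5, 4}, {5, 8}, {6, 3}};

  const int batch_begin = 0, batch_end = 8;  // nominal node range of this batch
  auto comp = [](const std::pair<int, int>& a, const std::pair<int, int>& b) {
    return a.first < b.first;
  };
  auto lower = std::lower_bound(sorted_positions.begin(), sorted_positions.end(),
                                std::make_pair(batch_begin, 0), comp);
  auto upper = std::upper_bound(sorted_positions.begin(), sorted_positions.end(),
                                std::make_pair(batch_end - 1, 0), comp);

  // The batch only spans nodes that actually hold rows: [2, 7) here, even though
  // the nominal range was [0, 8). (Toy data guarantees a non-empty range.)
  std::printf("node range: [%d, %d), %td rows\n",
              lower->first, (upper - 1)->first + 1, upper - lower);
  return 0;
}
```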
@@ -937,7 +944,9 @@ struct build_tree_fn {
}
// Update position of entire level
// Don't bother updating positions for the last level
if (depth < max_depth - 1) { builder.UpdatePositions(tree, X_accessor, X_shape); }
if (depth < max_depth - 1) {
builder.UpdatePositions(tree, X_accessor, X_shape, thrust_exec_policy);
}
}

tree.WriteTreeOutput(context, thrust_exec_policy, quantiser);