Skip to content

Commit

Permalink
Merge pull request #128 from RAMitchell/memory
Browse files Browse the repository at this point in the history
Batch node construction in deeper trees
  • Loading branch information
mfoerste4 authored Aug 6, 2024
2 parents 0c12b06 + a06aa21 commit b0ec587
Show file tree
Hide file tree
Showing 4 changed files with 455 additions and 200 deletions.
16 changes: 16 additions & 0 deletions legateboost/test/models/test_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,22 @@ def test_improving_with_depth(num_outputs):
assert metrics[-1] < metrics[0]


def test_max_depth():
# we should be able to run deep trees with OOM
max_depth = 20
X = cn.random.random((2, 1))
y = cn.array([500.0, 500.0])
model = lb.LBRegressor(
init=None,
base_models=(lb.models.Tree(max_depth=max_depth),),
learning_rate=1.0,
n_estimators=1,
random_state=0,
)

model.fit(X, y)


def test_alpha():
X = cn.random.random((2, 1))
y = cn.array([500.0, 500.0])
Expand Down
205 changes: 143 additions & 62 deletions src/models/tree/build_tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,22 @@
#include <random>

namespace legateboost {

namespace {

struct NodeBatch {
int32_t node_idx_begin;
int32_t node_idx_end;
std::tuple<int32_t, int32_t>* instances_begin;
std::tuple<int32_t, int32_t>* instances_end;
auto begin() const { return instances_begin; }
auto end() const { return instances_end; }
__host__ __device__ std::size_t InstancesInBatch() const
{
return instances_end - instances_begin;
}
__host__ __device__ std::size_t NodesInBatch() const { return node_idx_end - node_idx_begin; }
};

struct Tree {
Tree(int max_nodes, int num_outputs) : num_outputs(num_outputs)
{
Expand Down Expand Up @@ -158,64 +172,72 @@ struct TreeBuilder {
int32_t num_features,
int32_t num_outputs,
int32_t max_nodes,
int32_t max_depth,
SparseSplitProposals<T> split_proposals)
: num_rows(num_rows),
num_features(num_features),
num_outputs(num_outputs),
max_nodes(max_nodes),
split_proposals(split_proposals),
histogram_buffer(
legate::create_buffer<GPair, 3>({max_nodes, split_proposals.histogram_size, num_outputs})),
positions(num_rows, 0)
split_proposals(split_proposals)
{
auto ptr = histogram_buffer.ptr({0, 0, 0});
std::fill(ptr, ptr + max_nodes * split_proposals.histogram_size * num_outputs, GPair{0.0, 0.0});
sorted_positions = legate::create_buffer<std::tuple<int32_t, int32_t>>(num_rows);
for (auto i = 0; i < num_rows; ++i) { sorted_positions[i] = {0, i}; }
const std::size_t max_bytes = std::pow(10, 9); // 1 GB
const std::size_t bytes_per_node = num_outputs * split_proposals.histogram_size * sizeof(GPair);
const std::size_t max_histogram_nodes = std::max(1ul, max_bytes / bytes_per_node);
int depth = 0;
while (BinaryTree::LevelEnd(depth + 1) <= max_histogram_nodes && depth <= max_depth) depth++;
histogram = Histogram(BinaryTree::LevelBegin(0),
BinaryTree::LevelEnd(depth),
num_outputs,
split_proposals.histogram_size);
max_batch_size = max_histogram_nodes;
}
~TreeBuilder() { histogram_buffer.destroy(); }
template <typename TYPE>
void ComputeHistogram(int depth,
void ComputeHistogram(Histogram histogram,
legate::TaskContext context,
Tree& tree,
legate::AccessorRO<TYPE, 3> X,
legate::Rect<3> X_shape,
legate::AccessorRO<double, 3> g,
legate::AccessorRO<double, 3> h)
legate::AccessorRO<double, 3> h,
NodeBatch batch)
{
// Build the histogram
for (int64_t i = X_shape.lo[0]; i <= X_shape.hi[0]; i++) {
auto index_local = i - X_shape.lo[0];
auto position = positions[index_local];
bool compute = ComputeHistogramBin(position, depth, tree.hessian);
for (auto [position, index_local] : batch) {
auto index_global = index_local + X_shape.lo[0];
bool compute = ComputeHistogramBin(
position, tree.hessian, histogram.ContainsNode(BinaryTree::Parent(position)));
if (position < 0 || !compute) continue;
for (int64_t j = 0; j < num_features; j++) {
auto x_value = X[{i, j, 0}];
auto x_value = X[{index_global, j, 0}];
int bin_idx = split_proposals.FindBin(x_value, j);

if (bin_idx != SparseSplitProposals<T>::NOT_FOUND) {
for (int64_t k = 0; k < num_outputs; ++k) {
histogram_buffer[{position, bin_idx, k}] += GPair{g[{i, 0, k}], h[{i, 0, k}]};
histogram[{position, k, bin_idx}] +=
GPair{g[{index_global, 0, k}], h[{index_global, 0, k}]};
}
}
}
}

SumAllReduce(
context,
reinterpret_cast<double*>(histogram_buffer.ptr({BinaryTree::LevelBegin(depth), 0, 0})),
BinaryTree::NodesInLevel(depth) * split_proposals.histogram_size * num_outputs * 2);
this->Scan(depth, tree);
SumAllReduce(context,
reinterpret_cast<double*>(histogram.Ptr(batch.node_idx_begin)),
batch.NodesInBatch() * num_outputs * split_proposals.histogram_size * 2);
this->Scan(histogram, batch, tree);
}

void Scan(int depth, Tree& tree)
void Scan(Histogram histogram, NodeBatch batch, Tree& tree)
{
auto scan_node_histogram = [&](int node_idx) {
for (int feature = 0; feature < num_features; feature++) {
auto [feature_begin, feature_end] = split_proposals.FeatureRange(feature);
for (int output = 0; output < num_outputs; output++) {
GPair sum = {0.0, 0.0};
for (int bin_idx = feature_begin; bin_idx < feature_end; bin_idx++) {
sum += histogram_buffer[{node_idx, bin_idx, output}];
histogram_buffer[{node_idx, bin_idx, output}] = sum;
sum += histogram[{node_idx, output, bin_idx}];
histogram[{node_idx, output, bin_idx}] = sum;
}
}
}
Expand All @@ -227,31 +249,37 @@ struct TreeBuilder {
auto [feature_begin, feature_end] = split_proposals.FeatureRange(feature);
for (int output = 0; output < num_outputs; output++) {
for (int bin_idx = feature_begin; bin_idx < feature_end; bin_idx++) {
auto scanned_sum = histogram_buffer[{scanned_node_idx, bin_idx, output}];
auto parent_sum = histogram_buffer[{parent_node_idx, bin_idx, output}];
histogram_buffer[{subtract_node_idx, bin_idx, output}] = parent_sum - scanned_sum;
auto scanned_sum = histogram[{scanned_node_idx, output, bin_idx}];
auto parent_sum = histogram[{parent_node_idx, output, bin_idx}];
histogram[{subtract_node_idx, output, bin_idx}] = parent_sum - scanned_sum;
}
}
}
};

if (depth == 0) {
if (batch.node_idx_begin == 0 && batch.node_idx_end == 1) {
scan_node_histogram(0);
return;
}

for (int parent_id = BinaryTree::LevelBegin(depth - 1);
parent_id < BinaryTree::LevelBegin(depth - 1) + BinaryTree::NodesInLevel(depth - 1);
parent_id++) {
auto [histogram_node_idx, subtract_node_idx] = SelectHistogramNode(parent_id, tree.hessian);
scan_node_histogram(histogram_node_idx);
subtract_node_histogram(subtract_node_idx, histogram_node_idx, parent_id);
for (int node_idx = batch.node_idx_begin; node_idx < batch.node_idx_end; node_idx++) {
auto parent = BinaryTree::Parent(node_idx);
if (!ComputeHistogramBin(node_idx, tree.hessian, histogram.ContainsNode(parent))) continue;
scan_node_histogram(node_idx);
// This node has no sibling we are finished
if (node_idx == 0) continue;

auto sibling_node_idx = BinaryTree::Sibling(node_idx);
// The sibling did not compute a histogram
// Do the subtraction trick using the histogram we just computed in the previous step
if (!ComputeHistogramBin(sibling_node_idx, tree.hessian, histogram.ContainsNode(parent))) {
subtract_node_histogram(sibling_node_idx, node_idx, parent);
}
}
}
void PerformBestSplit(int depth, Tree& tree, double alpha)
void PerformBestSplit(Tree& tree, Histogram histogram, double alpha, NodeBatch batch)
{
for (int node_id = BinaryTree::LevelBegin(depth); node_id < BinaryTree::LevelBegin(depth + 1);
node_id++) {
for (int node_id = batch.node_idx_begin; node_id < batch.node_idx_end; node_id++) {
double best_gain = 0;
int best_feature = -1;
int best_bin = -1;
Expand All @@ -260,7 +288,7 @@ struct TreeBuilder {
for (int bin_idx = feature_begin; bin_idx < feature_end; bin_idx++) {
double gain = 0;
for (int output = 0; output < num_outputs; ++output) {
auto [G_L, H_L] = histogram_buffer[{node_id, bin_idx, output}];
auto [G_L, H_L] = histogram[{node_id, output, bin_idx}];
auto G = tree.gradient[{node_id, output}];
auto H = tree.hessian[{node_id, output}];
auto G_R = G - G_L;
Expand All @@ -284,7 +312,7 @@ struct TreeBuilder {
std::vector<double> hessian_left(num_outputs);
std::vector<double> hessian_right(num_outputs);
for (int output = 0; output < num_outputs; ++output) {
auto [G_L, H_L] = histogram_buffer[{node_id, best_bin, output}];
auto [G_L, H_L] = histogram[{node_id, output, best_bin}];
auto G = tree.gradient[{node_id, output}];
auto H = tree.hessian[{node_id, output}];
auto G_R = G - G_L;
Expand All @@ -311,26 +339,71 @@ struct TreeBuilder {
}
}
template <typename TYPE>
void UpdatePositions(int depth,
Tree& tree,
legate::AccessorRO<TYPE, 3> X,
legate::Rect<3> X_shape)
void UpdatePositions(Tree& tree, legate::AccessorRO<TYPE, 3> X, legate::Rect<3> X_shape)
{
if (depth == 0) return;
// Update the positions
for (int64_t i = X_shape.lo[0]; i <= X_shape.hi[0]; i++) {
auto index_local = i - X_shape.lo[0];
int& pos = positions[index_local];
if (pos < 0 || tree.IsLeaf(pos)) {
pos = -1;
for (int i = 0; i < num_rows; i++) {
auto [pos, index_local] = sorted_positions[i];
if (pos < 0 || pos >= max_nodes || tree.IsLeaf(pos)) {
sorted_positions[i] = {-1, index_local};
continue;
}
auto x = X[{i, tree.feature[pos], 0}];
bool left = x <= tree.split_value[pos];
pos = left ? BinaryTree::LeftChild(pos) : BinaryTree::RightChild(pos);
auto x = X[{X_shape.lo[0] + index_local, tree.feature[pos], 0}];
bool left = x <= tree.split_value[pos];
pos = left ? BinaryTree::LeftChild(pos) : BinaryTree::RightChild(pos);
sorted_positions[i] = {pos, index_local};
}
}

// Create a new histogram for this batch if we need to
// Destroy the old one
Histogram GetHistogram(NodeBatch batch)
{
if (histogram.ContainsBatch(batch.node_idx_begin, batch.node_idx_end)) { return histogram; }

histogram.Destroy();
histogram = Histogram(
batch.node_idx_begin, batch.node_idx_end, num_outputs, split_proposals.histogram_size);
return histogram;
}

std::vector<NodeBatch> PrepareBatches(int depth)
{
// Shortcut if we have 1 batch
if (BinaryTree::NodesInLevel(depth) <= max_batch_size) {
// All instances are in batch
return {NodeBatch{BinaryTree::LevelBegin(depth),
BinaryTree::LevelEnd(depth),
sorted_positions.ptr(0),
sorted_positions.ptr(0) + num_rows}};
}

std::sort(sorted_positions.ptr(0),
sorted_positions.ptr(0) + num_rows,
[] __device__(auto a, auto b) { return std::get<0>(a) < std::get<0>(b); });

const int num_batches = (BinaryTree::NodesInLevel(depth) + max_batch_size - 1) / max_batch_size;
std::vector<NodeBatch> batches(num_batches);

for (auto batch_idx = 0; batch_idx < num_batches; ++batch_idx) {
int batch_begin = BinaryTree::LevelBegin(depth) + batch_idx * max_batch_size;
int batch_end = std::min(batch_begin + max_batch_size, BinaryTree::LevelEnd(depth));
auto comp = [] __device__(auto a, auto b) { return std::get<0>(a) < std::get<0>(b); };

auto lower = std::lower_bound(sorted_positions.ptr(0),
sorted_positions.ptr(0) + num_rows,
std::tuple(batch_begin, 0),
comp);
auto upper = std::upper_bound(
lower, sorted_positions.ptr(0) + num_rows, std::tuple(batch_end - 1, 0), comp);
batches[batch_idx] = {batch_begin, batch_end, lower, upper};
}
batches.erase(std::remove_if(batches.begin(),
batches.end(),
[](const NodeBatch& b) { return b.InstancesInBatch() == 0; }),
batches.end());
return batches;
}
void InitialiseRoot(legate::TaskContext context,
Tree& tree,
legate::AccessorRO<double, 3> g_accessor,
Expand All @@ -353,13 +426,14 @@ struct TreeBuilder {
}
}

std::vector<int32_t> positions;
legate::Buffer<std::tuple<int32_t, int32_t>, 1> sorted_positions;
const int32_t num_rows;
const int32_t num_features;
const int32_t num_outputs;
const int32_t max_nodes;
int max_batch_size;
SparseSplitProposals<T> split_proposals;
legate::Buffer<GPair, 3> histogram_buffer;
Histogram histogram;
};

struct build_tree_fn {
Expand Down Expand Up @@ -391,17 +465,24 @@ struct build_tree_fn {
SelectSplitSamples(context, X_accessor, X_shape, split_samples, seed, dataset_rows);

// Begin building the tree
TreeBuilder<T> tree_builder(num_rows, num_features, num_outputs, max_nodes, split_proposals);
TreeBuilder<T> builder(
num_rows, num_features, num_outputs, max_nodes, max_depth, split_proposals);

tree_builder.InitialiseRoot(context, tree, g_accessor, h_accessor, g_shape, alpha);
for (int64_t depth = 0; depth < max_depth; ++depth) {
tree_builder.UpdatePositions(depth, tree, X_accessor, X_shape);
builder.InitialiseRoot(context, tree, g_accessor, h_accessor, g_shape, alpha);
for (int depth = 0; depth < max_depth; ++depth) {
auto batches = builder.PrepareBatches(depth);
for (auto batch : batches) {
auto histogram = builder.GetHistogram(batch);

tree_builder.ComputeHistogram(
depth, context, tree, X_accessor, X_shape, g_accessor, h_accessor);
tree_builder.PerformBestSplit(depth, tree, alpha);
}
builder.ComputeHistogram(
histogram, context, tree, X_accessor, X_shape, g_accessor, h_accessor, batch);

builder.PerformBestSplit(tree, histogram, alpha, batch);
}
// Update position of entire level
// Don't bother updating positions for the last level
if (depth < max_depth - 1) { builder.UpdatePositions(tree, X_accessor, X_shape); }
}
WriteTreeOutput(context, tree);
}
};
Expand Down
Loading

0 comments on commit b0ec587

Please sign in to comment.