Use sparse histograms (#114)
RAMitchell authored Jul 26, 2024
1 parent 0c4321c commit 157b427
Showing 6 changed files with 333 additions and 189 deletions.
21 changes: 4 additions & 17 deletions legateboost/models/tree.py
@@ -1,12 +1,11 @@
-import math
from enum import IntEnum
from typing import Any

import cunumeric as cn
from legate.core import TaskTarget, get_legate_runtime, types

from ..library import user_context, user_lib
-from ..utils import gather, get_store
+from ..utils import get_store
from .base_model import BaseModel


@@ -51,11 +50,6 @@ def __eq__(self, other: object) -> bool:
        eq.append(cn.all(self.hessian == other.hessian))
        return all(eq)

-    def num_procs_to_use(self, num_rows: int) -> int:
-        min_rows_per_worker = 10
-        available_procs = len(get_legate_runtime().machine)
-        return min(available_procs, int(math.ceil(num_rows / min_rows_per_worker)))
-
    def __init__(
        self,
        max_depth: int = 8,
@@ -72,14 +66,6 @@ def fit(
        g: cn.ndarray,
        h: cn.ndarray,
    ) -> "Tree":
-        # dont let legate create a future - make sure at least 2 sample rows
-        sample_rows_np = self.random_state.randint(
-            0, X.shape[0], max(2, self.split_samples)
-        )
-        split_proposals = gather(X, tuple(sample_rows_np))
-        split_proposals.sort(axis=0)
-        split_proposals = split_proposals.T
-
        num_outputs = g.shape[1]

        task = get_legate_runtime().create_auto_task(
@@ -95,15 +81,16 @@
        max_nodes = 2 ** (self.max_depth + 1)
        task.add_scalar_arg(max_nodes, types.int32)
        task.add_scalar_arg(self.alpha, types.float64)
+        task.add_scalar_arg(self.split_samples, types.int32)
+        task.add_scalar_arg(self.random_state.randint(0, 2**31), types.int32)
+        task.add_scalar_arg(X.shape[0], types.int64)

        task.add_input(X_)
        task.add_broadcast(X_, 1)
        task.add_input(g_)
        task.add_input(h_)
        task.add_alignment(g_, h_)
        task.add_alignment(g_, X_)
-        task.add_input(get_store(split_proposals))
-        task.add_broadcast(get_store(split_proposals))

        # outputs
        leaf_value = get_legate_runtime().create_store(
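Note on the Python-side change above: split proposal selection moves out of Python entirely. Instead of gathering sampled rows on the host and broadcasting a dense proposals store, fit() now passes three scalars — the number of candidate samples, an RNG seed (randint(0, 2**31), so the value fits the int32 scalar type), and the total dataset row count — and the C++ task reconstructs identical proposals on every worker. Scalar arguments are positional, so the task must read them back in the order they were added; a sketch of the receiving side, mirroring the build_tree.cc hunk later in this commit:

  // Scalars arrive in the order the Python side called task.add_scalar_arg(...)
  auto max_depth     = context.scalars().at(0).value<int>();      // types.int32
  auto max_nodes     = context.scalars().at(1).value<int>();      // types.int32
  auto alpha         = context.scalars().at(2).value<double>();   // types.float64
  auto split_samples = context.scalars().at(3).value<int>();      // types.int32
  auto seed          = context.scalars().at(4).value<int>();      // types.int32
  auto dataset_rows  = context.scalars().at(5).value<int64_t>();  // types.int64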
13 changes: 13 additions & 0 deletions legateboost/test/models/test_tree.py
@@ -13,6 +13,19 @@ def test_determinism(max_depth):
    check_determinism(lb.models.Tree(max_depth=max_depth))


+def test_basic():
+    # tree loss can go to zero
+    X = cn.array([[0.0], [1.0]])
+    g = cn.array([[0.0], [-1.0]])
+    h = cn.array([[1.0], [1.0]])
+    model = (
+        lb.models.Tree(max_depth=1, alpha=0.0)
+        .set_random_state(np.random.RandomState(2))
+        .fit(X, g, h)
+    )
+    assert np.allclose(model.predict(X), np.array([[0.0], [1.0]]))
+
+
@pytest.mark.parametrize("num_outputs", [1, 5])
def test_improving_with_depth(num_outputs):
    rs = cn.random.RandomState(0)
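A note on why test_basic expects predictions of exactly [[0.0], [1.0]]: assuming legateboost uses the standard second-order (XGBoost-style) leaf weight — consistent with the gain formula in build_tree.cc below, though the leaf-value code itself is not shown in this commit — a leaf's optimal value is

  w* = -G / (H + alpha)

where G and H sum the gradients and hessians of the rows reaching that leaf. With max_depth=1 the two rows land in separate leaves, and with alpha=0 the leaf values are -0/1 = 0 and -(-1)/1 = 1, so the model fits the targets implied by (g, h) exactly and the loss can reach zero, as the comment in the test says.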
34 changes: 8 additions & 26 deletions src/mapper.cc
@@ -38,37 +38,19 @@ std::vector<legate::mapping::StoreMapping> LegateboostMapper::store_mappings(
  const legate::mapping::Task& task, const std::vector<legate::mapping::StoreTarget>& options)
{
  auto task_id = task.task_id();
-  switch (task_id) {
-    case GATHER:
-    case PREDICT: {
-      std::vector<legate::mapping::StoreMapping> mappings;
-      auto input_x = task.input(0);
+  // Enforce c-ordering for these tasks
+  std::set<LegateBoostOpCode> row_major_only = {BUILD_TREE};
+  std::vector<legate::mapping::StoreMapping> mappings;
+  if (row_major_only.count(static_cast<LegateBoostOpCode>(task_id))) {
+    for (auto input : task.inputs()) {
      mappings.push_back(
-        legate::mapping::StoreMapping::default_mapping(input_x.data(), options.front()));
+        legate::mapping::StoreMapping::default_mapping(input.data(), options.front()));
      mappings.back().policy().ordering.set_c_order();
      mappings.back().policy().exact = true;
-      return std::move(mappings);
    }
-    case BUILD_TREE: {
-      std::vector<legate::mapping::StoreMapping> mappings;
-      auto input_x = task.input(0);
-      mappings.push_back(
-        legate::mapping::StoreMapping::default_mapping(input_x.data(), options.front()));
-      mappings.back().policy().ordering.set_c_order();
-      mappings.back().policy().exact = true;
-      auto input_splits = task.input(3);
-      mappings.push_back(
-        legate::mapping::StoreMapping::default_mapping(input_splits.data(), options.front()));
-      mappings.back().policy().ordering.set_c_order();
-      mappings.back().policy().exact = true;
-      return std::move(mappings);
-    }
-    default: {
-      return {};
-    }
+    return mappings;
  }
-  assert(false);
-  return {};
+  return mappings;
}

} // namespace legateboost
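The mapper rewrite above replaces per-opcode switch cases with a set of opcodes whose inputs require exact, C-ordered (row-major) mappings; GATHER and PREDICT no longer get special treatment, and all of BUILD_TREE's inputs are forced row-major in one loop. Opting a future task into the same policy means adding its opcode to the set rather than writing a new case. A self-contained sketch of the dispatch pattern (the opcode names besides BUILD_TREE are illustrative, not the real enum):

  #include <set>

  enum LegateBoostOpCode { BUILD_TREE, PREDICT, SOME_FUTURE_TASK };  // illustrative

  // Tasks listed here get exact, C-ordered mappings for all inputs;
  // anything else falls through to the runtime's default mapping.
  bool NeedsRowMajor(LegateBoostOpCode op)
  {
    static const std::set<LegateBoostOpCode> row_major_only = {BUILD_TREE};
    return row_major_only.count(op) > 0;
  }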
146 changes: 93 additions & 53 deletions src/models/tree/build_tree.cc
@@ -18,6 +18,7 @@
#include "legateboost.h"
#include "../../cpp_utils/cpp_utils.h"
#include "build_tree.h"
+#include <random>

namespace legateboost {

@@ -102,24 +103,73 @@ void WriteTreeOutput(legate::TaskContext context, const Tree& tree)
  WriteOutput(context.output(4).data(), tree.hessian);
}

+// Randomly sample split_samples rows from X
+// Share the samples with all workers
+// Remove any duplicates
+// Return sparse matrix of split samples for each feature
+template <typename T>
+SparseSplitProposals<T> SelectSplitSamples(legate::TaskContext context,
+                                           legate::AccessorRO<T, 3> X,
+                                           legate::Rect<3> X_shape,
+                                           int split_samples,
+                                           int seed,
+                                           int64_t dataset_rows)
+{
+  std::vector<int64_t> row_samples(split_samples);
+
+  std::default_random_engine eng(seed);
+  std::uniform_int_distribution<int64_t> dist(0, dataset_rows - 1);
+  std::transform(row_samples.begin(), row_samples.end(), row_samples.begin(), [&dist, &eng](int) {
+    return dist(eng);
+  });
+
+  int num_features = X_shape.hi[1] - X_shape.lo[1] + 1;
+  auto draft_proposals = legate::create_buffer<T, 2>({num_features, split_samples});
+  for (int i = 0; i < split_samples; i++) {
+    auto row = row_samples[i];
+    bool has_data = row >= X_shape.lo[0] && row <= X_shape.hi[0];
+    for (int j = 0; j < num_features; j++) {
+      draft_proposals[{j, i}] = has_data ? X[{row, j, 0}] : T(0);
+    }
+  }
+  SumAllReduce(context, draft_proposals.ptr({0, 0}), num_features * split_samples);
+
+  // Sort samples
+  std::vector<T> split_proposals_tmp;
+  split_proposals_tmp.reserve(num_features * split_samples);
+  auto row_pointers = legate::create_buffer<int32_t, 1>({num_features + 1});
+  row_pointers[0] = 0;
+  for (int j = 0; j < num_features; j++) {
+    auto ptr = draft_proposals.ptr({j, 0});
+    std::set<T> unique(ptr, ptr + split_samples);
+    row_pointers[j + 1] = row_pointers[j] + unique.size();
+    split_proposals_tmp.insert(split_proposals_tmp.end(), unique.begin(), unique.end());
+  }
+
+  auto split_proposals = legate::create_buffer<T, 1>(split_proposals_tmp.size());
+  std::copy(split_proposals_tmp.begin(), split_proposals_tmp.end(), split_proposals.ptr(0));
+  return SparseSplitProposals<T>(
+    split_proposals, row_pointers, num_features, split_proposals_tmp.size());
+}

template <typename T>
struct TreeBuilder {
  TreeBuilder(int32_t num_rows,
              int32_t num_features,
              int32_t num_outputs,
              int32_t max_nodes,
-              int32_t samples_per_feature)
+              SparseSplitProposals<T> split_proposals)
    : num_rows(num_rows),
      num_features(num_features),
      num_outputs(num_outputs),
      max_nodes(max_nodes),
-      samples_per_feature(samples_per_feature),
-      histogram_buffer(legate::create_buffer<GPair, 4>(
-        {max_nodes, num_features, num_outputs, samples_per_feature})),
+      split_proposals(split_proposals),
+      histogram_buffer(
+        legate::create_buffer<GPair, 3>({max_nodes, split_proposals.histogram_size, num_outputs})),
      positions(num_rows, 0)
  {
-    auto ptr = histogram_buffer.ptr({0, 0, 0, 0});
-    std::fill(
-      ptr, ptr + max_nodes * num_features * num_outputs * samples_per_feature, GPair{0.0, 0.0});
+    auto ptr = histogram_buffer.ptr({0, 0, 0});
+    std::fill(ptr, ptr + max_nodes * split_proposals.histogram_size * num_outputs, GPair{0.0, 0.0});
  }
  ~TreeBuilder() { histogram_buffer.destroy(); }
  template <typename TYPE>
@@ -128,7 +178,6 @@
                        Tree& tree,
                        legate::AccessorRO<TYPE, 3> X,
                        legate::Rect<3> X_shape,
-                        legate::AccessorRO<TYPE, 2> split_proposal,
                        legate::AccessorRO<double, 3> g,
                        legate::AccessorRO<double, 3> h)
  {
@@ -140,35 +189,33 @@
      if (position < 0 || !compute) continue;
      for (int64_t j = 0; j < num_features; j++) {
        auto x_value = X[{i, j, 0}];
-        int bin_idx =
-          std::lower_bound(
-            split_proposal.ptr({j, 0}), split_proposal.ptr({j, samples_per_feature}), x_value) -
-          split_proposal.ptr({j, 0});
+        int bin_idx = split_proposals.FindBin(x_value, j);

-        if (bin_idx < samples_per_feature) {
+        if (bin_idx != SparseSplitProposals<T>::NOT_FOUND) {
          for (int64_t k = 0; k < num_outputs; ++k) {
-            histogram_buffer[{position, j, k, bin_idx}] += GPair{g[{i, 0, k}], h[{i, 0, k}]};
+            histogram_buffer[{position, bin_idx, k}] += GPair{g[{i, 0, k}], h[{i, 0, k}]};
          }
        }
      }
    }

    SumAllReduce(
      context,
-      reinterpret_cast<double*>(histogram_buffer.ptr({BinaryTree::LevelBegin(depth), 0, 0, 0})),
-      BinaryTree::NodesInLevel(depth) * num_features * samples_per_feature * num_outputs * 2);
+      reinterpret_cast<double*>(histogram_buffer.ptr({BinaryTree::LevelBegin(depth), 0, 0})),
+      BinaryTree::NodesInLevel(depth) * split_proposals.histogram_size * num_outputs * 2);
    this->Scan(depth, tree);
  }

  void Scan(int depth, Tree& tree)
  {
    auto scan_node_histogram = [&](int node_idx) {
      for (int feature = 0; feature < num_features; feature++) {
+        auto [feature_begin, feature_end] = split_proposals.FeatureRange(feature);
        for (int output = 0; output < num_outputs; output++) {
          GPair sum = {0.0, 0.0};
-          for (int bin_idx = 0; bin_idx < samples_per_feature; bin_idx++) {
-            sum += histogram_buffer[{node_idx, feature, output, bin_idx}];
-            histogram_buffer[{node_idx, feature, output, bin_idx}] = sum;
+          for (int bin_idx = feature_begin; bin_idx < feature_end; bin_idx++) {
+            sum += histogram_buffer[{node_idx, bin_idx, output}];
+            histogram_buffer[{node_idx, bin_idx, output}] = sum;
          }
        }
      }
@@ -177,12 +224,12 @@ struct TreeBuilder {
    auto subtract_node_histogram =
      [&](int subtract_node_idx, int scanned_node_idx, int parent_node_idx) {
        for (int feature = 0; feature < num_features; feature++) {
+          auto [feature_begin, feature_end] = split_proposals.FeatureRange(feature);
          for (int output = 0; output < num_outputs; output++) {
-            for (int bin_idx = 0; bin_idx < samples_per_feature; bin_idx++) {
-              auto scanned_sum = histogram_buffer[{scanned_node_idx, feature, output, bin_idx}];
-              auto parent_sum = histogram_buffer[{parent_node_idx, feature, output, bin_idx}];
-              histogram_buffer[{subtract_node_idx, feature, output, bin_idx}] =
-                parent_sum - scanned_sum;
+            for (int bin_idx = feature_begin; bin_idx < feature_end; bin_idx++) {
+              auto scanned_sum = histogram_buffer[{scanned_node_idx, bin_idx, output}];
+              auto parent_sum = histogram_buffer[{parent_node_idx, bin_idx, output}];
+              histogram_buffer[{subtract_node_idx, bin_idx, output}] = parent_sum - scanned_sum;
            }
          }
        }
@@ -201,22 +248,19 @@ struct TreeBuilder {
      subtract_node_histogram(subtract_node_idx, histogram_node_idx, parent_id);
    }
  }
-  template <typename TYPE>
-  void PerformBestSplit(int depth,
-                        Tree& tree,
-                        legate::AccessorRO<TYPE, 2> split_proposal,
-                        double alpha)
+  void PerformBestSplit(int depth, Tree& tree, double alpha)
  {
    for (int node_id = BinaryTree::LevelBegin(depth); node_id < BinaryTree::LevelBegin(depth + 1);
         node_id++) {
      double best_gain = 0;
      int best_feature = -1;
      int best_bin = -1;
      for (int feature = 0; feature < num_features; feature++) {
-        for (int bin_idx = 0; bin_idx < samples_per_feature; bin_idx++) {
+        auto [feature_begin, feature_end] = split_proposals.FeatureRange(feature);
+        for (int bin_idx = feature_begin; bin_idx < feature_end; bin_idx++) {
          double gain = 0;
          for (int output = 0; output < num_outputs; ++output) {
-            auto [G_L, H_L] = histogram_buffer[{node_id, feature, output, bin_idx}];
+            auto [G_L, H_L] = histogram_buffer[{node_id, bin_idx, output}];
            auto G = tree.gradient[{node_id, output}];
            auto H = tree.hessian[{node_id, output}];
            auto G_R = G - G_L;
@@ -240,7 +284,7 @@ struct TreeBuilder {
      std::vector<double> hessian_left(num_outputs);
      std::vector<double> hessian_right(num_outputs);
      for (int output = 0; output < num_outputs; ++output) {
-        auto [G_L, H_L] = histogram_buffer[{node_id, best_feature, output, best_bin}];
+        auto [G_L, H_L] = histogram_buffer[{node_id, best_bin, output}];
        auto G = tree.gradient[{node_id, output}];
        auto H = tree.hessian[{node_id, output}];
        auto G_R = G - G_L;
@@ -255,7 +299,7 @@ struct TreeBuilder {
      if (hessian_left[0] <= 0.0 || hessian_right[0] <= 0.0) continue;
      tree.AddSplit(node_id,
                    best_feature,
-                    split_proposal[{best_feature, best_bin}],
+                    split_proposals.split_proposals[{best_bin}],
                    left_leaf,
                    right_leaf,
                    best_gain,
@@ -314,8 +358,8 @@ struct TreeBuilder {
  const int32_t num_features;
  const int32_t num_outputs;
  const int32_t max_nodes;
-  const int32_t samples_per_feature;
-  legate::Buffer<GPair, 4> histogram_buffer;
+  SparseSplitProposals<T> split_proposals;
+  legate::Buffer<GPair, 3> histogram_buffer;
};

struct build_tree_fn {
@@ -325,41 +369,37 @@
    auto [X, X_shape, X_accessor] = GetInputStore<T, 3>(context.input(0).data());
    auto [g, g_shape, g_accessor] = GetInputStore<double, 3>(context.input(1).data());
    auto [h, h_shape, h_accessor] = GetInputStore<double, 3>(context.input(2).data());
-    auto [split_proposals, split_proposals_shape, split_proposals_accessor] =
-      GetInputStore<T, 2>(context.input(3).data());
    EXPECT_DENSE_ROW_MAJOR(X_accessor.accessor, X_shape);
-    EXPECT_DENSE_ROW_MAJOR(split_proposals_accessor.accessor, split_proposals_shape);
    auto num_features = X_shape.hi[1] - X_shape.lo[1] + 1;
    auto num_rows = std::max<int64_t>(X_shape.hi[0] - X_shape.lo[0] + 1, 0);
    EXPECT_AXIS_ALIGNED(0, X_shape, g_shape);
    EXPECT_AXIS_ALIGNED(0, g_shape, h_shape);
    EXPECT_AXIS_ALIGNED(1, g_shape, h_shape);
    auto num_outputs = g.shape<3>().hi[2] - g.shape<3>().lo[2] + 1;
-    EXPECT_IS_BROADCAST(split_proposals_shape);
-    auto samples_per_feature = split_proposals_shape.hi[1] - split_proposals_shape.lo[1] + 1;
    EXPECT(g_shape.lo[2] == 0, "Expect all outputs to be present");

    // Scalars
    auto max_depth = context.scalars().at(0).value<int>();
    auto max_nodes = context.scalars().at(1).value<int>();
    auto alpha = context.scalars().at(2).value<double>();
+    auto split_samples = context.scalars().at(3).value<int>();
+    auto seed = context.scalars().at(4).value<int>();
+    auto dataset_rows = context.scalars().at(5).value<int64_t>();

    Tree tree(max_nodes, num_outputs);
+    SparseSplitProposals<T> split_proposals =
+      SelectSplitSamples(context, X_accessor, X_shape, split_samples, seed, dataset_rows);

    // Begin building the tree
-    TreeBuilder tree_builder(num_rows, num_features, num_outputs, max_nodes, samples_per_feature);
+    TreeBuilder<T> tree_builder(num_rows, num_features, num_outputs, max_nodes, split_proposals);

    tree_builder.InitialiseRoot(context, tree, g_accessor, h_accessor, g_shape, alpha);
    for (int64_t depth = 0; depth < max_depth; ++depth) {
      tree_builder.UpdatePositions(depth, tree, X_accessor, X_shape);

-      tree_builder.ComputeHistogram(depth,
-                                    context,
-                                    tree,
-                                    X_accessor,
-                                    X_shape,
-                                    split_proposals_accessor,
-                                    g_accessor,
-                                    h_accessor);
-      tree_builder.PerformBestSplit(depth, tree, split_proposals_accessor, alpha);
+      tree_builder.ComputeHistogram(
+        depth, context, tree, X_accessor, X_shape, g_accessor, h_accessor);
+      tree_builder.PerformBestSplit(depth, tree, alpha);
    }

    WriteTreeOutput(context, tree);
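Why the commit title calls these sparse histograms: the old TreeBuilder allocated a dense GPair buffer of shape {max_nodes, num_features, num_outputs, samples_per_feature}, paying for samples_per_feature bins per feature even when de-duplication would leave far fewer distinct candidates. The new buffer is {max_nodes, histogram_size, num_outputs}, where histogram_size sums each feature's unique candidate count — at most num_features * split_samples, and often much smaller. An illustrative (not measured) calculation: with max_nodes = 512, 100 features, one output, and 256 candidate samples per feature, the dense buffer holds 512 * 100 * 256 = 13,107,200 GPair entries; if de-duplication leaves an average of 64 unique values per feature (plausible for low-cardinality or categorical-like columns), the sparse buffer holds 512 * 6,400 = 3,276,800 entries, a 4x reduction. The SumAllReduce in ComputeHistogram shrinks by the same factor, which matters because each level's whole histogram is exchanged between workers.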
(The remaining two changed files in this commit are not shown in this view.)
