Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor split evaluator using device-wide Scan primitive #7197

Open
wants to merge 29 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
b100e8e
Use device-wise Scan to evaluate splits
hcho3 Aug 14, 2021
26244f6
Remove dead code
hcho3 Aug 14, 2021
ac530a5
Add __noinline__ to reduce build time
hcho3 Aug 24, 2021
b32a634
Add debug logging
hcho3 Sep 7, 2021
09880b3
Fix bug
hcho3 Sep 8, 2021
11c8bcd
Revert "Add debug logging"
hcho3 Sep 8, 2021
115b649
Don't use zip iterator / tuple
hcho3 Sep 9, 2021
71a5fb7
Revert "Don't use zip iterator / tuple"
hcho3 Sep 10, 2021
0931023
Compute partial_sum before computing loss_chg
hcho3 Sep 10, 2021
41a9a38
Fix memory error
hcho3 Sep 11, 2021
18d0daf
Fix bug with computing partial_sum
hcho3 Sep 14, 2021
6cfac2f
Enforce min_child_weight
hcho3 Sep 14, 2021
bb904be
Remove __noinline__ directive
hcho3 Sep 14, 2021
1242a7c
Fix lint
hcho3 Sep 14, 2021
d26414e
Set threshold (fvalue) correctly when performing backward scan
hcho3 Sep 15, 2021
ce86fdc
Correctly set left_sum / right_sum for categoricals
hcho3 Sep 15, 2021
7b8bf95
Better tie-breaking: favor lower featureID, kLeftDir
hcho3 Sep 15, 2021
e7ac53e
Set threshold (fvalue) correctly for categoricals
hcho3 Sep 15, 2021
ac9c79b
Remove superfluous DoIt() functions
hcho3 Sep 15, 2021
cd98c39
Merge remote-tracking branch 'origin/master' into rework-evaluation
hcho3 Sep 15, 2021
c96bb6a
Set threshold (fvalue) correctly when performing backward scan
hcho3 Sep 15, 2021
eadb3ce
Refine tie-breaking
hcho3 Sep 21, 2021
7edef3c
Revert "Refine tie-breaking"
hcho3 Sep 21, 2021
447503d
Merge remote-tracking branch 'origin/master' into rework-evaluation
hcho3 Sep 21, 2021
1077d5d
Relax tolerance in tests
hcho3 Sep 22, 2021
b954d21
Relax MultiClassesSerializationTest.GpuHist
hcho3 Sep 23, 2021
013c733
Relax GpuHist.EvaluateCategoricalSplit
hcho3 Sep 23, 2021
035ed20
Merge remote-tracking branch 'origin/master' into rework-evaluation
hcho3 Sep 23, 2021
d8b7d75
Explicitly specifiy eval_metric in TestDistributedGPU.test_early_stop…
hcho3 Sep 23, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion demo/guide-python/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def main() -> None:
reg_enc_results = np.array(reg_enc.evals_result()["validation_0"]["rmse"])

# Check that they have same results
np.testing.assert_allclose(reg_results, reg_enc_results)
np.testing.assert_allclose(reg_results, reg_enc_results, rtol=1e-3)

# Convert to DMatrix for SHAP value
booster: xgb.Booster = reg.get_booster()
Expand Down
521 changes: 237 additions & 284 deletions src/tree/gpu_hist/evaluate_splits.cu

Large diffs are not rendered by default.

77 changes: 77 additions & 0 deletions src/tree/gpu_hist/evaluate_splits.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#define EVALUATE_SPLITS_CUH_
#include <xgboost/span.h>
#include "../../data/ellpack_page.cuh"
#include "../../common/device_helpers.cuh"
#include "../split_evaluator.h"
#include "../constraints.cuh"
#include "../updater_gpu_common.cuh"
Expand All @@ -24,11 +25,87 @@ struct EvaluateSplitInputs {
common::Span<const float> min_fvalue;
common::Span<const GradientSumT> gradient_histogram;
};

struct EvaluateSplitsHistEntry {
uint32_t node_idx;
uint32_t hist_idx;
bool forward;
};

template <typename GradientSumT>
struct ScanElem {
uint32_t node_idx;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bst_node_t.

uint32_t hist_idx;
int32_t findex{-1};
float fvalue{std::numeric_limits<float>::quiet_NaN()};
bool is_cat{false};
bool forward{true};
GradientSumT gpair{0.0, 0.0};
GradientSumT partial_sum{0.0, 0.0};
GradientSumT parent_sum{0.0, 0.0};
};

template <typename GradientSumT>
struct ScanValueOp {
EvaluateSplitInputs<GradientSumT> left;
EvaluateSplitInputs<GradientSumT> right;
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator;

using ScanElemT = ScanElem<GradientSumT>;

__device__ ScanElemT MapEvaluateSplitsHistEntryToScanElem(
EvaluateSplitsHistEntry entry,
EvaluateSplitInputs<GradientSumT> split_input);
__device__ ScanElemT
operator() (EvaluateSplitsHistEntry entry);
};

template <typename GradientSumT>
struct ScanOp {
EvaluateSplitInputs<GradientSumT> left, right;
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator;

using ScanElemT = ScanElem<GradientSumT>;

__device__ ScanElemT operator() (ScanElemT lhs, ScanElemT rhs);
};

template <typename GradientSumT>
struct ReduceElem {
GradientSumT partial_sum{0.0, 0.0};
GradientSumT parent_sum{0.0, 0.0};
float loss_chg{std::numeric_limits<float>::lowest()};
int32_t findex{-1};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bst_note_t.

uint32_t node_idx{0};
float fvalue{std::numeric_limits<float>::quiet_NaN()};
bool is_cat{false};
DefaultDirection direction{DefaultDirection::kLeftDir};
};

template <typename GradientSumT>
struct ReduceValueOp {
EvaluateSplitInputs<GradientSumT> left, right;
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator;

using ScanElemT = ScanElem<GradientSumT>;
using ReduceElemT = ReduceElem<GradientSumT>;

__device__ ReduceElemT operator() (ScanElemT e);
};

template <typename GradientSumT>
dh::device_vector<ReduceElem<GradientSumT>>
EvaluateSplitsGenerateSplitCandidatesViaScan(
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
EvaluateSplitInputs<GradientSumT> left,
EvaluateSplitInputs<GradientSumT> right);

template <typename GradientSumT>
void EvaluateSplits(common::Span<DeviceSplitCandidate> out_splits,
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
EvaluateSplitInputs<GradientSumT> left,
EvaluateSplitInputs<GradientSumT> right);

template <typename GradientSumT>
void EvaluateSingleSplit(common::Span<DeviceSplitCandidate> out_split,
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
Expand Down
12 changes: 9 additions & 3 deletions src/tree/updater_gpu_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ static const bst_node_t kUnusedNode = -1;
* @enum DefaultDirection node.cuh
* @brief Default direction to be followed in case of missing values
*/
enum DefaultDirection {
enum class DefaultDirection : uint8_t {
/** move to left child */
kLeftDir = 0,
/** move to right child */
Expand All @@ -56,7 +56,7 @@ enum DefaultDirection {

struct DeviceSplitCandidate {
float loss_chg {-FLT_MAX};
DefaultDirection dir {kLeftDir};
DefaultDirection dir {DefaultDirection::kLeftDir};
int findex {-1};
float fvalue {0};
bool is_cat { false };
Expand All @@ -66,6 +66,12 @@ struct DeviceSplitCandidate {

XGBOOST_DEVICE DeviceSplitCandidate() {} // NOLINT

XGBOOST_DEVICE DeviceSplitCandidate(float loss_chg, DefaultDirection dir, int findex,
float fvalue, bool is_cat, GradientPair left_sum,
GradientPair right_sum)
: loss_chg(loss_chg), dir(dir), findex(findex), fvalue(fvalue), is_cat(is_cat),
left_sum(left_sum), right_sum(right_sum) {}

template <typename ParamT>
XGBOOST_DEVICE void Update(const DeviceSplitCandidate& other,
const ParamT& param) {
Expand Down Expand Up @@ -98,7 +104,7 @@ struct DeviceSplitCandidate {

friend std::ostream& operator<<(std::ostream& os, DeviceSplitCandidate const& c) {
os << "loss_chg:" << c.loss_chg << ", "
<< "dir: " << c.dir << ", "
<< "dir: " << (c.dir == DefaultDirection::kLeftDir ? "left" : "right") << ", "
<< "findex: " << c.findex << ", "
<< "fvalue: " << c.fvalue << ", "
<< "is_cat: " << c.is_cat << ", "
Expand Down
4 changes: 2 additions & 2 deletions src/tree/updater_gpu_hist.cu
Original file line number Diff line number Diff line change
Expand Up @@ -591,13 +591,13 @@ struct GPUHistMakerDevice {
dh::CopyToD(split_cats, &node_categories);
tree.ExpandCategorical(
candidate.nid, candidate.split.findex, split_cats,
candidate.split.dir == kLeftDir, base_weight, left_weight,
candidate.split.dir == DefaultDirection::kLeftDir, base_weight, left_weight,
right_weight, candidate.split.loss_chg, parent_sum.GetHess(),
candidate.split.left_sum.GetHess(),
candidate.split.right_sum.GetHess());
} else {
tree.ExpandNode(candidate.nid, candidate.split.findex,
candidate.split.fvalue, candidate.split.dir == kLeftDir,
candidate.split.fvalue, candidate.split.dir == DefaultDirection::kLeftDir,
base_weight, left_weight, right_weight,
candidate.split.loss_chg, parent_sum.GetHess(),
candidate.split.left_sum.GetHess(),
Expand Down
8 changes: 7 additions & 1 deletion tests/cpp/test_serialization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@ void CompareJSON(Json l, Json r) {

for (auto const& kv : l_obj) {
ASSERT_NE(r_obj.find(kv.first), r_obj.cend());
CompareJSON(l_obj.at(kv.first), r_obj.at(kv.first));
if (kv.first == "default_left" || kv.first == "split_conditions") {
auto const& l_arr = get<Array const>(l_obj.at(kv.first));
auto const& r_arr = get<Array const>(r_obj.at(kv.first));
ASSERT_EQ(l_arr.size(), r_arr.size());
} else {
CompareJSON(l_obj.at(kv.first), r_obj.at(kv.first));
}
Comment on lines +35 to +41
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change allows two models to differ in default_left and split_conditions but still require them to produce idential loss_chg. The relaxation is necessary because the test produced two split candidates with identical loss_chg and findex but different default_left and split_conditions. The given scenario occurred because there was no missing values for that particular feature.

}
break;
}
Expand Down
8 changes: 6 additions & 2 deletions tests/cpp/tree/gpu_hist/test_evaluate_splits.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@ void TestEvaluateSingleSplit(bool is_categorical) {

DeviceSplitCandidate result = out_splits[0];
EXPECT_EQ(result.findex, 1);
EXPECT_EQ(result.fvalue, 11.0);
if (is_categorical) {
EXPECT_TRUE(result.fvalue == 11.0 || result.fvalue == 12.0);
} else {
EXPECT_EQ(result.fvalue, 11.0);
}
EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(),
parent_sum.GetGrad());
EXPECT_FLOAT_EQ(result.left_sum.GetHess() + result.right_sum.GetHess(),
Expand Down Expand Up @@ -103,7 +107,7 @@ TEST(GpuHist, EvaluateSingleSplitMissing) {
DeviceSplitCandidate result = out_splits[0];
EXPECT_EQ(result.findex, 0);
EXPECT_EQ(result.fvalue, 1.0);
EXPECT_EQ(result.dir, kRightDir);
EXPECT_EQ(result.dir, DefaultDirection::kRightDir);
EXPECT_EQ(result.left_sum, GradientPair(-0.5, 0.5));
EXPECT_EQ(result.right_sum, GradientPair(1.5, 1.0));
}
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/tree/test_tree_stat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class UpdaterTreeStatTest : public ::testing::Test {
tree.WalkTree([&tree](bst_node_t nidx) {
if (tree[nidx].IsLeaf()) {
// 1.0 is the default `min_child_weight`.
CHECK_GE(tree.Stat(nidx).sum_hess, 1.0);
EXPECT_GE(tree.Stat(nidx).sum_hess, 1.0);
}
return true;
});
Expand Down
1 change: 1 addition & 0 deletions tests/python-gpu/test_gpu_with_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ def test_early_stopping(self, local_cuda_cluster: LocalCUDACluster) -> None:
n_estimators=100)
cls.client = client
cls.fit(X, y, early_stopping_rounds=early_stopping_rounds,
eval_metric='error',
eval_set=[(valid_X, valid_y)])
booster = cls.get_booster()
dump = booster.get_dump(dump_format='json')
Expand Down