Skip to content

Commit

Permalink
Fix approx test.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Dec 29, 2021
1 parent f4b8f31 commit d690afb
Showing 1 changed file with 22 additions and 19 deletions.
41 changes: 22 additions & 19 deletions tests/cpp/tree/test_approx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,19 @@

namespace xgboost {
namespace tree {
namespace {
void GetSplit(RegTree *tree, float split_value, std::vector<CPUExpandEntry> *candidates) {
tree->ExpandNode(
/*nid=*/RegTree::kRoot, /*split_index=*/0, /*split_value=*/split_value,
/*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
/*left_sum=*/0.0f,
/*right_sum=*/0.0f);
candidates->front().split.split_value = split_value;
candidates->front().split.sindex = 0;
candidates->front().split.sindex |= (1U << 31);
}
} // anonymous namespace

TEST(Approx, Partitioner) {
size_t n_samples = 1024, n_features = 1, base_rowid = 0;
ApproxRowPartitioner partitioner{n_samples, base_rowid};
Expand All @@ -20,20 +33,18 @@ TEST(Approx, Partitioner) {
ctx.InitAllowUnknown(Args{});
std::vector<CPUExpandEntry> candidates{{0, 0, 0.4}};

for (auto const &page : Xy->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, 64})) {
bst_feature_t split_ind = 0;
auto grad = GenerateRandomGradients(n_samples);
std::vector<float> hess(grad.Size());
std::transform(grad.HostVector().cbegin(), grad.HostVector().cend(), hess.begin(),
[](auto gpair) { return gpair.GetHess(); });

for (auto const &page : Xy->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, 64, hess})) {
bst_feature_t const split_ind = 0;
{
auto min_value = page.cut.MinValues()[split_ind];
RegTree tree;
tree.ExpandNode(
/*nid=*/0, /*split_index=*/0, /*split_value=*/min_value,
/*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
/*left_sum=*/0.0f,
/*right_sum=*/0.0f);
ApproxRowPartitioner partitioner{n_samples, base_rowid};
candidates.front().split.split_value = min_value;
candidates.front().split.sindex = 0;
candidates.front().split.sindex |= (1U << 31);
GetSplit(&tree, min_value, &candidates);
partitioner.UpdatePosition(&ctx, page, candidates, &tree);
ASSERT_EQ(partitioner.Size(), 3);
ASSERT_EQ(partitioner[1].Size(), 0);
Expand All @@ -44,16 +55,8 @@ TEST(Approx, Partitioner) {
auto ptr = page.cut.Ptrs()[split_ind + 1];
float split_value = page.cut.Values().at(ptr / 2);
RegTree tree;
tree.ExpandNode(
/*nid=*/RegTree::kRoot, /*split_index=*/split_ind,
/*split_value=*/split_value,
/*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
/*left_sum=*/0.0f,
/*right_sum=*/0.0f);
GetSplit(&tree, split_value, &candidates);
auto left_nidx = tree[RegTree::kRoot].LeftChild();
candidates.front().split.split_value = split_value;
candidates.front().split.sindex = 0;
candidates.front().split.sindex |= (1U << 31);
partitioner.UpdatePosition(&ctx, page, candidates, &tree);

auto elem = partitioner[left_nidx];
Expand Down

0 comments on commit d690afb

Please sign in to comment.