Skip to content

Commit

Permalink
Clarify the behavior of invalid categorical value handling. (#7529)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Jan 13, 2022
1 parent 20c0d60 commit e5e47c3
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 25 deletions.
12 changes: 12 additions & 0 deletions doc/tutorials/categorical.rst
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,18 @@ feature it's specified as ``"c"``. The Dask module in XGBoost has the same inte
:class:`dask.Array <dask.Array>` can also be used as categorical data.


*************
Miscellaneous
*************

By default, XGBoost assumes input categories are integers starting from 0 till the number
of categories :math:`[0, n_categories)`. However, user might provide inputs with invalid
values due to mistakes or missing values. It can be negative value, floating point value
that can not be represented by 32-bit integer, or values that are larger than actual
number of unique categories. During training this is validated but for prediction it's
treated as the same as missing value for performance reasons. Lastly, missing values are
treated as the same as numerical features.

**********
Next Steps
**********
Expand Down
28 changes: 18 additions & 10 deletions src/common/categorical.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#ifndef XGBOOST_COMMON_CATEGORICAL_H_
#define XGBOOST_COMMON_CATEGORICAL_H_

#include <limits>

#include "bitfield.h"
#include "xgboost/base.h"
#include "xgboost/data.h"
Expand All @@ -30,22 +32,30 @@ inline XGBOOST_DEVICE bool IsCat(Span<FeatureType const> ft, bst_feature_t fidx)
return !ft.empty() && ft[fidx] == FeatureType::kCategorical;
}


inline XGBOOST_DEVICE bool InvalidCat(float cat) {
return cat < 0 || cat > static_cast<float>(std::numeric_limits<bst_cat_t>::max());
}

/* \brief Whether should it traverse to left branch of a tree.
*
* For one hot split, go to left if it's NOT the matching category.
*/
inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, bst_cat_t cat) {
auto pos = CLBitField32::ToBitPos(cat);
if (pos.int_pos >= cats.size()) {
return true;
}
template <bool validate = true>
inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat, bool dft_left) {
CLBitField32 const s_cats(cats);
return !s_cats.Check(cat);
// FIXME: Size() is not accurate since it represents the size of bit set instead of
// actual number of categories.
if (XGBOOST_EXPECT(validate && (InvalidCat(cat) || cat >= s_cats.Size()), false)) {
return dft_left;
}
return !s_cats.Check(AsCat(cat));
}

inline void InvalidCategory() {
LOG(FATAL) << "Invalid categorical value detected. Categorical value "
"should be non-negative.";
"should be non-negative, less than maximum size of int32 and less than total "
"number of categories in training data.";
}

/*!
Expand All @@ -58,9 +68,7 @@ inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot, ObjInfo task)
}

struct IsCatOp {
XGBOOST_DEVICE bool operator()(FeatureType ft) {
return ft == FeatureType::kCategorical;
}
XGBOOST_DEVICE bool operator()(FeatureType ft) { return ft == FeatureType::kCategorical; }
};

using CatBitField = LBitField32;
Expand Down
10 changes: 5 additions & 5 deletions src/common/quantile.cu
Original file line number Diff line number Diff line change
Expand Up @@ -581,14 +581,14 @@ void SketchContainer::AllReduce() {
}

namespace {
struct InvalidCat {
struct InvalidCatOp {
Span<float const> values;
Span<uint32_t const> ptrs;
Span<FeatureType const> ft;

XGBOOST_DEVICE bool operator()(size_t i) {
auto fidx = dh::SegmentId(ptrs, i);
return IsCat(ft, fidx) && values[i] < 0;
return IsCat(ft, fidx) && InvalidCat(values[i]);
}
};
} // anonymous namespace
Expand Down Expand Up @@ -687,10 +687,10 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {
dh::XGBCachingDeviceAllocator<char> alloc;
auto ptrs = p_cuts->cut_ptrs_.ConstDeviceSpan();
auto it = thrust::make_counting_iterator(0ul);

CHECK_EQ(p_cuts->Ptrs().back(), out_cut_values.size());
auto invalid =
thrust::any_of(thrust::cuda::par(alloc), it, it + out_cut_values.size(),
InvalidCat{out_cut_values, ptrs, d_ft});
auto invalid = thrust::any_of(thrust::cuda::par(alloc), it, it + out_cut_values.size(),
InvalidCatOp{out_cut_values, ptrs, d_ft});
if (invalid) {
InvalidCategory();
}
Expand Down
12 changes: 6 additions & 6 deletions src/predictor/predict_fn.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@
namespace xgboost {
namespace predictor {
template <bool has_missing, bool has_categorical>
inline XGBOOST_DEVICE bst_node_t
GetNextNode(const RegTree::Node &node, const bst_node_t nid, float fvalue,
bool is_missing, RegTree::CategoricalSplitMatrix const &cats) {
inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bst_node_t nid,
float fvalue, bool is_missing,
RegTree::CategoricalSplitMatrix const &cats) {
if (has_missing && is_missing) {
return node.DefaultChild();
} else {
if (has_categorical && common::IsCat(cats.split_type, nid)) {
auto node_categories = cats.categories.subspan(cats.node_ptr[nid].beg,
cats.node_ptr[nid].size);
return Decision(node_categories, common::AsCat(fvalue))
auto node_categories =
cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size);
return common::Decision<true>(node_categories, fvalue, node.DefaultLeft())
? node.LeftChild()
: node.RightChild();
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/tree/updater_approx.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class ApproxRowPartitioner {
auto node_cats = categories.subspan(segment.beg, segment.size);
bool go_left = true;
if (is_cat) {
go_left = common::Decision(node_cats, common::AsCat(cut_value));
go_left = common::Decision(node_cats, cut_value, candidate.split.DefaultLeft());
} else {
go_left = cut_value <= candidate.split.split_value;
}
Expand Down
6 changes: 3 additions & 3 deletions src/tree/updater_gpu_hist.cu
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ struct GPUHistMakerDevice {
} else {
bool go_left = true;
if (split_type == FeatureType::kCategorical) {
go_left = common::Decision(node_cats, common::AsCat(cut_value));
go_left = common::Decision<false>(node_cats, cut_value, split_node.DefaultLeft());
} else {
go_left = cut_value <= split_node.SplitCond();
}
Expand Down Expand Up @@ -474,7 +474,7 @@ struct GPUHistMakerDevice {
auto node_cats =
categories.subspan(categories_segments[position].beg,
categories_segments[position].size);
go_left = common::Decision(node_cats, common::AsCat(element));
go_left = common::Decision<false>(node_cats, element, node.DefaultLeft());
} else {
go_left = element <= node.SplitCond();
}
Expand Down Expand Up @@ -573,7 +573,7 @@ struct GPUHistMakerDevice {
CHECK_LT(candidate.split.fvalue, std::numeric_limits<bst_cat_t>::max())
<< "Categorical feature value too large.";
auto cat = common::AsCat(candidate.split.fvalue);
if (cat < 0) {
if (common::InvalidCat(cat)) {
common::InvalidCategory();
}
std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(std::max(cat+1, 1)), 0);
Expand Down
43 changes: 43 additions & 0 deletions tests/cpp/common/test_categorical.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*!
* Copyright 2021 by XGBoost Contributors
*/
#include <gtest/gtest.h>

#include <limits>

#include "../../../src/common/categorical.h"

namespace xgboost {
namespace common {
TEST(Categorical, Decision) {
// inf
float a = std::numeric_limits<float>::infinity();

ASSERT_TRUE(common::InvalidCat(a));
std::vector<uint32_t> cats(256, 0);
ASSERT_TRUE(Decision(cats, a, true));

// larger than size
a = 256;
ASSERT_TRUE(Decision(cats, a, true));

// negative
a = -1;
ASSERT_TRUE(Decision(cats, a, true));

CatBitField bits{cats};
bits.Set(0);
a = -0.5;
ASSERT_TRUE(Decision(cats, a, true));

// round toward 0
a = 0.5;
ASSERT_FALSE(Decision(cats, a, true));

// valid
a = 13;
bits.Set(a);
ASSERT_FALSE(Decision(bits.Bits(), a, true));
}
} // namespace common
} // namespace xgboost

0 comments on commit e5e47c3

Please sign in to comment.