Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clarify the behavior of invalid categorical value handling. #7529

Merged
merged 5 commits into from
Jan 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions doc/tutorials/categorical.rst
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,18 @@ feature it's specified as ``"c"``. The Dask module in XGBoost has the same inte
:class:`dask.Array <dask.Array>` can also be used as categorical data.


*************
Miscellaneous
*************

By default, XGBoost assumes input categories are integers starting from 0 till the number
of categories :math:`[0, n_categories)`. However, user might provide inputs with invalid
values due to mistakes or missing values. It can be negative value, floating point value
that can not be represented by 32-bit integer, or values that are larger than actual
number of unique categories. During training this is validated but for prediction it's
treated as the same as missing value for performance reasons. Lastly, missing values are
treated as the same as numerical features.

**********
Next Steps
**********
Expand Down
28 changes: 18 additions & 10 deletions src/common/categorical.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#ifndef XGBOOST_COMMON_CATEGORICAL_H_
#define XGBOOST_COMMON_CATEGORICAL_H_

#include <limits>

#include "bitfield.h"
#include "xgboost/base.h"
#include "xgboost/data.h"
Expand All @@ -30,22 +32,30 @@ inline XGBOOST_DEVICE bool IsCat(Span<FeatureType const> ft, bst_feature_t fidx)
return !ft.empty() && ft[fidx] == FeatureType::kCategorical;
}


inline XGBOOST_DEVICE bool InvalidCat(float cat) {
return cat < 0 || cat > static_cast<float>(std::numeric_limits<bst_cat_t>::max());
}

/* \brief Whether should it traverse to left branch of a tree.
*
* For one hot split, go to left if it's NOT the matching category.
*/
inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, bst_cat_t cat) {
auto pos = CLBitField32::ToBitPos(cat);
if (pos.int_pos >= cats.size()) {
return true;
}
template <bool validate = true>
inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat, bool dft_left) {
CLBitField32 const s_cats(cats);
return !s_cats.Check(cat);
// FIXME: Size() is not accurate since it represents the size of bit set instead of
// actual number of categories.
if (XGBOOST_EXPECT(validate && (InvalidCat(cat) || cat >= s_cats.Size()), false)) {
return dft_left;
}
return !s_cats.Check(AsCat(cat));
}

inline void InvalidCategory() {
LOG(FATAL) << "Invalid categorical value detected. Categorical value "
"should be non-negative.";
"should be non-negative, less than maximum size of int32 and less than total "
"number of categories in training data.";
}

/*!
Expand All @@ -58,9 +68,7 @@ inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot, ObjInfo task)
}

struct IsCatOp {
XGBOOST_DEVICE bool operator()(FeatureType ft) {
return ft == FeatureType::kCategorical;
}
XGBOOST_DEVICE bool operator()(FeatureType ft) { return ft == FeatureType::kCategorical; }
};

using CatBitField = LBitField32;
Expand Down
10 changes: 5 additions & 5 deletions src/common/quantile.cu
Original file line number Diff line number Diff line change
Expand Up @@ -581,14 +581,14 @@ void SketchContainer::AllReduce() {
}

namespace {
struct InvalidCat {
struct InvalidCatOp {
Span<float const> values;
Span<uint32_t const> ptrs;
Span<FeatureType const> ft;

XGBOOST_DEVICE bool operator()(size_t i) {
auto fidx = dh::SegmentId(ptrs, i);
return IsCat(ft, fidx) && values[i] < 0;
return IsCat(ft, fidx) && InvalidCat(values[i]);
}
};
} // anonymous namespace
Expand Down Expand Up @@ -687,10 +687,10 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {
dh::XGBCachingDeviceAllocator<char> alloc;
auto ptrs = p_cuts->cut_ptrs_.ConstDeviceSpan();
auto it = thrust::make_counting_iterator(0ul);

CHECK_EQ(p_cuts->Ptrs().back(), out_cut_values.size());
auto invalid =
thrust::any_of(thrust::cuda::par(alloc), it, it + out_cut_values.size(),
InvalidCat{out_cut_values, ptrs, d_ft});
auto invalid = thrust::any_of(thrust::cuda::par(alloc), it, it + out_cut_values.size(),
InvalidCatOp{out_cut_values, ptrs, d_ft});
if (invalid) {
InvalidCategory();
}
Expand Down
12 changes: 6 additions & 6 deletions src/predictor/predict_fn.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@
namespace xgboost {
namespace predictor {
template <bool has_missing, bool has_categorical>
inline XGBOOST_DEVICE bst_node_t
GetNextNode(const RegTree::Node &node, const bst_node_t nid, float fvalue,
bool is_missing, RegTree::CategoricalSplitMatrix const &cats) {
inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bst_node_t nid,
float fvalue, bool is_missing,
RegTree::CategoricalSplitMatrix const &cats) {
if (has_missing && is_missing) {
return node.DefaultChild();
} else {
if (has_categorical && common::IsCat(cats.split_type, nid)) {
auto node_categories = cats.categories.subspan(cats.node_ptr[nid].beg,
cats.node_ptr[nid].size);
return Decision(node_categories, common::AsCat(fvalue))
auto node_categories =
cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size);
return common::Decision<true>(node_categories, fvalue, node.DefaultLeft())
? node.LeftChild()
: node.RightChild();
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/tree/updater_approx.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class ApproxRowPartitioner {
auto node_cats = categories.subspan(segment.beg, segment.size);
bool go_left = true;
if (is_cat) {
go_left = common::Decision(node_cats, common::AsCat(cut_value));
go_left = common::Decision(node_cats, cut_value, candidate.split.DefaultLeft());
} else {
go_left = cut_value <= candidate.split.split_value;
}
Expand Down
6 changes: 3 additions & 3 deletions src/tree/updater_gpu_hist.cu
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ struct GPUHistMakerDevice {
} else {
bool go_left = true;
if (split_type == FeatureType::kCategorical) {
go_left = common::Decision(node_cats, common::AsCat(cut_value));
go_left = common::Decision<false>(node_cats, cut_value, split_node.DefaultLeft());
} else {
go_left = cut_value <= split_node.SplitCond();
}
Expand Down Expand Up @@ -474,7 +474,7 @@ struct GPUHistMakerDevice {
auto node_cats =
categories.subspan(categories_segments[position].beg,
categories_segments[position].size);
go_left = common::Decision(node_cats, common::AsCat(element));
go_left = common::Decision<false>(node_cats, element, node.DefaultLeft());
} else {
go_left = element <= node.SplitCond();
}
Expand Down Expand Up @@ -573,7 +573,7 @@ struct GPUHistMakerDevice {
CHECK_LT(candidate.split.fvalue, std::numeric_limits<bst_cat_t>::max())
<< "Categorical feature value too large.";
auto cat = common::AsCat(candidate.split.fvalue);
if (cat < 0) {
if (common::InvalidCat(cat)) {
common::InvalidCategory();
}
std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(std::max(cat+1, 1)), 0);
Expand Down
43 changes: 43 additions & 0 deletions tests/cpp/common/test_categorical.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*!
* Copyright 2021 by XGBoost Contributors
*/
#include <gtest/gtest.h>

#include <limits>

#include "../../../src/common/categorical.h"

namespace xgboost {
namespace common {
TEST(Categorical, Decision) {
// inf
float a = std::numeric_limits<float>::infinity();

ASSERT_TRUE(common::InvalidCat(a));
std::vector<uint32_t> cats(256, 0);
ASSERT_TRUE(Decision(cats, a, true));

// larger than size
a = 256;
ASSERT_TRUE(Decision(cats, a, true));

// negative
a = -1;
ASSERT_TRUE(Decision(cats, a, true));

CatBitField bits{cats};
bits.Set(0);
a = -0.5;
ASSERT_TRUE(Decision(cats, a, true));

// round toward 0
a = 0.5;
ASSERT_FALSE(Decision(cats, a, true));

// valid
a = 13;
bits.Set(a);
ASSERT_FALSE(Decision(bits.Bits(), a, true));
}
} // namespace common
} // namespace xgboost