From e5e47c3c998e9bb264ab8d690694f0371cbef459 Mon Sep 17 00:00:00 2001
From: Jiaming Yuan <jm.yuan@outlook.com>
Date: Thu, 13 Jan 2022 16:11:52 +0800
Subject: [PATCH] Clarify the behavior of invalid categorical value handling.
 (#7529)

---
 doc/tutorials/categorical.rst        | 12 ++++++++
 src/common/categorical.h             | 28 +++++++++++-------
 src/common/quantile.cu               | 10 +++----
 src/predictor/predict_fn.h           | 12 ++++----
 src/tree/updater_approx.h            |  2 +-
 src/tree/updater_gpu_hist.cu         |  6 ++--
 tests/cpp/common/test_categorical.cc | 43 ++++++++++++++++++++++++++++
 7 files changed, 88 insertions(+), 25 deletions(-)
 create mode 100644 tests/cpp/common/test_categorical.cc

diff --git a/doc/tutorials/categorical.rst b/doc/tutorials/categorical.rst
index a56b946476ba..dd30a6ec4397 100644
--- a/doc/tutorials/categorical.rst
+++ b/doc/tutorials/categorical.rst
@@ -108,6 +108,18 @@ feature it's specified as ``"c"``.  The Dask module in XGBoost has the same inte
 :class:`dask.Array <dask.Array>` can also be used as categorical data.
 
 
+*************
+Miscellaneous
+*************
+
+By default, XGBoost assumes input categories are integers starting from 0 till the number
+of categories :math:`[0, n_categories)`. However, user might provide inputs with invalid
+values due to mistakes or missing values. It can be negative value, floating point value
+that can not be represented by 32-bit integer, or values that are larger than actual
+number of unique categories.  During training this is validated but for prediction it's
+treated as the same as missing value for performance reasons.  Lastly, missing values are
+treated as the same as numerical features.
+
 **********
 Next Steps
 **********
diff --git a/src/common/categorical.h b/src/common/categorical.h
index 4cbbbf72ba60..e1d4d2c2a44c 100644
--- a/src/common/categorical.h
+++ b/src/common/categorical.h
@@ -5,6 +5,8 @@
 #ifndef XGBOOST_COMMON_CATEGORICAL_H_
 #define XGBOOST_COMMON_CATEGORICAL_H_
 
+#include <limits>
+
 #include "bitfield.h"
 #include "xgboost/base.h"
 #include "xgboost/data.h"
@@ -30,22 +32,30 @@ inline XGBOOST_DEVICE bool IsCat(Span<FeatureType const> ft, bst_feature_t fidx)
   return !ft.empty() && ft[fidx] == FeatureType::kCategorical;
 }
 
+
+inline XGBOOST_DEVICE bool InvalidCat(float cat) {
+  return cat < 0 || cat > static_cast<float>(std::numeric_limits<bst_cat_t>::max());
+}
+
 /* \brief Whether should it traverse to left branch of a tree.
  *
  *  For one hot split, go to left if it's NOT the matching category.
  */
-inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, bst_cat_t cat) {
-  auto pos = CLBitField32::ToBitPos(cat);
-  if (pos.int_pos >= cats.size()) {
-    return true;
-  }
+template <bool validate = true>
+inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat, bool dft_left) {
   CLBitField32 const s_cats(cats);
-  return !s_cats.Check(cat);
+  // FIXME: Size() is not accurate since it represents the size of bit set instead of
+  // actual number of categories.
+  if (XGBOOST_EXPECT(validate && (InvalidCat(cat) || cat >= s_cats.Size()), false)) {
+    return dft_left;
+  }
+  return !s_cats.Check(AsCat(cat));
 }
 
 inline void InvalidCategory() {
   LOG(FATAL) << "Invalid categorical value detected.  Categorical value "
-                "should be non-negative.";
+                "should be non-negative, less than maximum size of int32 and less than total "
+                "number of categories in training data.";
 }
 
 /*!
@@ -58,9 +68,7 @@ inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot, ObjInfo task)
 }
 
 struct IsCatOp {
-  XGBOOST_DEVICE bool operator()(FeatureType ft) {
-    return ft == FeatureType::kCategorical;
-  }
+  XGBOOST_DEVICE bool operator()(FeatureType ft) { return ft == FeatureType::kCategorical; }
 };
 
 using CatBitField = LBitField32;
diff --git a/src/common/quantile.cu b/src/common/quantile.cu
index d89951915d4f..d15d310c0516 100644
--- a/src/common/quantile.cu
+++ b/src/common/quantile.cu
@@ -581,14 +581,14 @@ void SketchContainer::AllReduce() {
 }
 
 namespace {
-struct InvalidCat {
+struct InvalidCatOp {
   Span<float const> values;
   Span<uint32_t const> ptrs;
   Span<FeatureType const> ft;
 
   XGBOOST_DEVICE bool operator()(size_t i) {
     auto fidx = dh::SegmentId(ptrs, i);
-    return IsCat(ft, fidx) && values[i] < 0;
+    return IsCat(ft, fidx) && InvalidCat(values[i]);
   }
 };
 }  // anonymous namespace
@@ -687,10 +687,10 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {
     dh::XGBCachingDeviceAllocator<char> alloc;
     auto ptrs = p_cuts->cut_ptrs_.ConstDeviceSpan();
     auto it = thrust::make_counting_iterator(0ul);
+
     CHECK_EQ(p_cuts->Ptrs().back(), out_cut_values.size());
-    auto invalid =
-        thrust::any_of(thrust::cuda::par(alloc), it, it + out_cut_values.size(),
-                       InvalidCat{out_cut_values, ptrs, d_ft});
+    auto invalid = thrust::any_of(thrust::cuda::par(alloc), it, it + out_cut_values.size(),
+                                  InvalidCatOp{out_cut_values, ptrs, d_ft});
     if (invalid) {
       InvalidCategory();
     }
diff --git a/src/predictor/predict_fn.h b/src/predictor/predict_fn.h
index 1547d6e774ae..7ce474023e8a 100644
--- a/src/predictor/predict_fn.h
+++ b/src/predictor/predict_fn.h
@@ -9,16 +9,16 @@
 namespace xgboost {
 namespace predictor {
 template <bool has_missing, bool has_categorical>
-inline XGBOOST_DEVICE bst_node_t
-GetNextNode(const RegTree::Node &node, const bst_node_t nid, float fvalue,
-            bool is_missing, RegTree::CategoricalSplitMatrix const &cats) {
+inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bst_node_t nid,
+                                             float fvalue, bool is_missing,
+                                             RegTree::CategoricalSplitMatrix const &cats) {
   if (has_missing && is_missing) {
     return node.DefaultChild();
   } else {
     if (has_categorical && common::IsCat(cats.split_type, nid)) {
-      auto node_categories = cats.categories.subspan(cats.node_ptr[nid].beg,
-                                                     cats.node_ptr[nid].size);
-      return Decision(node_categories, common::AsCat(fvalue))
+      auto node_categories =
+          cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size);
+      return common::Decision<true>(node_categories, fvalue, node.DefaultLeft())
                  ? node.LeftChild()
                  : node.RightChild();
     } else {
diff --git a/src/tree/updater_approx.h b/src/tree/updater_approx.h
index 5e16f568f80b..158ab2b2c12a 100644
--- a/src/tree/updater_approx.h
+++ b/src/tree/updater_approx.h
@@ -95,7 +95,7 @@ class ApproxRowPartitioner {
             auto node_cats = categories.subspan(segment.beg, segment.size);
             bool go_left = true;
             if (is_cat) {
-              go_left = common::Decision(node_cats, common::AsCat(cut_value));
+              go_left = common::Decision(node_cats, cut_value, candidate.split.DefaultLeft());
             } else {
               go_left = cut_value <= candidate.split.split_value;
             }
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 48d58074ef19..199be0a4c5d1 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -396,7 +396,7 @@ struct GPUHistMakerDevice {
           } else {
             bool go_left = true;
             if (split_type == FeatureType::kCategorical) {
-              go_left = common::Decision(node_cats, common::AsCat(cut_value));
+              go_left = common::Decision<false>(node_cats, cut_value, split_node.DefaultLeft());
             } else {
               go_left = cut_value <= split_node.SplitCond();
             }
@@ -474,7 +474,7 @@ struct GPUHistMakerDevice {
                 auto node_cats =
                     categories.subspan(categories_segments[position].beg,
                                        categories_segments[position].size);
-                go_left = common::Decision(node_cats, common::AsCat(element));
+                go_left = common::Decision<false>(node_cats, element, node.DefaultLeft());
               } else {
                 go_left = element <= node.SplitCond();
               }
@@ -573,7 +573,7 @@ struct GPUHistMakerDevice {
       CHECK_LT(candidate.split.fvalue, std::numeric_limits<bst_cat_t>::max())
           << "Categorical feature value too large.";
       auto cat = common::AsCat(candidate.split.fvalue);
-      if (cat < 0) {
+      if (common::InvalidCat(cat)) {
         common::InvalidCategory();
       }
       std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(std::max(cat+1, 1)), 0);
diff --git a/tests/cpp/common/test_categorical.cc b/tests/cpp/common/test_categorical.cc
new file mode 100644
index 000000000000..cc8eb0f7e6c4
--- /dev/null
+++ b/tests/cpp/common/test_categorical.cc
@@ -0,0 +1,43 @@
+/*!
+ * Copyright 2021 by XGBoost Contributors
+ */
+#include <gtest/gtest.h>
+
+#include <limits>
+
+#include "../../../src/common/categorical.h"
+
+namespace xgboost {
+namespace common {
+TEST(Categorical, Decision) {
+  // inf
+  float a = std::numeric_limits<float>::infinity();
+
+  ASSERT_TRUE(common::InvalidCat(a));
+  std::vector<uint32_t> cats(256, 0);
+  ASSERT_TRUE(Decision(cats, a, true));
+
+  // larger than size
+  a = 256;
+  ASSERT_TRUE(Decision(cats, a, true));
+
+  // negative
+  a = -1;
+  ASSERT_TRUE(Decision(cats, a, true));
+
+  CatBitField bits{cats};
+  bits.Set(0);
+  a = -0.5;
+  ASSERT_TRUE(Decision(cats, a, true));
+
+  // round toward 0
+  a = 0.5;
+  ASSERT_FALSE(Decision(cats, a, true));
+
+  // valid
+  a = 13;
+  bits.Set(a);
+  ASSERT_FALSE(Decision(bits.Bits(), a, true));
+}
+}  // namespace common
+}  // namespace xgboost