Skip to content

Commit

Permalink
Support categorical data for hist. (#7695)
Browse files Browse the repository at this point in the history
* Extract partitioner from hist.
* Implement categorical data support by passing the gradient index directly into the partitioner.
* Organize/update document.
* Remove code for negative hessian.
  • Loading branch information
trivialfis authored Feb 24, 2022
1 parent f60d95b commit 83a66b4
Show file tree
Hide file tree
Showing 15 changed files with 403 additions and 499 deletions.
7 changes: 2 additions & 5 deletions doc/parameter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,6 @@ Additional parameters for ``hist``, ``gpu_hist`` and ``approx`` tree method

- Use single precision to build histograms instead of double precision.

Additional parameters for ``approx`` and ``gpu_hist`` tree method
=================================================================

* ``max_cat_to_onehot``

.. versionadded:: 1.6
Expand All @@ -256,8 +253,8 @@ Additional parameters for ``approx`` and ``gpu_hist`` tree method
- A threshold for deciding whether XGBoost should use one-hot encoding based split for
categorical data. When number of categories is lesser than the threshold then one-hot
encoding is chosen, otherwise the categories will be partitioned into children nodes.
Only relevant for regression and binary classification. Also, `approx` or `gpu_hist`
tree method is required.
Only relevant for regression and binary classification. Also, ``exact`` tree method is
not supported

Additional parameters for Dart Booster (``booster=dart``)
=========================================================
Expand Down
83 changes: 43 additions & 40 deletions doc/tutorials/categorical.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@ Categorical Data

.. note::

As of XGBoost 1.6, the feature is highly experimental and has limited features
As of XGBoost 1.6, the feature is experimental and has limited features

Starting from version 1.5, XGBoost has experimental support for categorical data available
for public testing. At the moment, the support is implemented as one-hot encoding based
categorical tree splits. For numerical data, the split condition is defined as
:math:`value < threshold`, while for categorical data the split is defined as :math:`value
== category` and ``category`` is a discrete value. More advanced categorical split
strategy is planned for future releases and this tutorial details how to inform XGBoost
about the data type. Also, the current support for training is limited to ``gpu_hist``
tree method.
for public testing. For numerical data, the split condition is defined as :math:`value <
threshold`, while for categorical data the split is defined depending on whether
partitioning or onehot encoding is used. For partition-based splits, the splits are
specified as :math:`value \in categories`, where ``categories`` is the set of categories
in one feature. If onehot encoding is used instead, then the split is defined as
:math:`value == category`. More advanced categorical split strategy is planned for future
releases and this tutorial details how to inform XGBoost about the data type.

************************************
Training with scikit-learn Interface
Expand All @@ -35,13 +35,13 @@ parameter ``enable_categorical``:

.. code:: python
# Only gpu_hist is supported for categorical data as mentioned previously
# Supported tree methods are `gpu_hist`, `approx`, and `hist`.
clf = xgb.XGBClassifier(
tree_method="gpu_hist", enable_categorical=True, use_label_encoder=False
)
# X is the dataframe we created in previous snippet
clf.fit(X, y)
# Must use JSON for serialization, otherwise the information is lost
# Must use JSON/UBJSON for serialization, otherwise the information is lost.
clf.save_model("categorical-model.json")
Expand All @@ -60,11 +60,37 @@ can plot the model and calculate the global feature importance:
The ``scikit-learn`` interface from dask is similar to single node version. The basic
idea is create dataframe with category feature type, and tell XGBoost to use ``gpu_hist``
with parameter ``enable_categorical``. See :ref:`sphx_glr_python_examples_categorical.py`
for a worked example of using categorical data with ``scikit-learn`` interface. A
comparison between using one-hot encoded data and XGBoost's categorical data support can
be found :ref:`sphx_glr_python_examples_cat_in_the_dat.py`.
idea is create dataframe with category feature type, and tell XGBoost to use it by setting
the ``enable_categorical`` parameter. See :ref:`sphx_glr_python_examples_categorical.py`
for a worked example of using categorical data with ``scikit-learn`` interface with
one-hot encoding. A comparison between using one-hot encoded data and XGBoost's
categorical data support can be found :ref:`sphx_glr_python_examples_cat_in_the_dat.py`.


********************
Optimal Partitioning
********************

.. versionadded:: 1.6

Optimal partitioning is a technique for partitioning the categorical predictors for each
node split, the proof of optimality for numerical objectives like ``RMSE`` was first
introduced by `[1] <#references>`__. The algorithm is used in decision trees for handling
regression and binary classification tasks `[2] <#references>`__, later LightGBM `[3]
<#references>`__ brought it to the context of gradient boosting trees and now is also
adopted in XGBoost as an optional feature for handling categorical splits. More
specifically, the proof by Fisher `[1] <#references>`__ states that, when trying to
partition a set of discrete values into groups based on the distances between a measure of
these values, one only needs to look at sorted partitions instead of enumerating all
possible permutations. In the context of decision trees, the discrete values are
categories, and the measure is the output leaf value. Intuitively, we want to group the
categories that output similar leaf values. During split finding, we first sort the
gradient histogram to prepare the contiguous partitions then enumerate the splits
according to these sorted values. One of the related parameters for XGBoost is
``max_cat_to_one_hot``, which controls whether one-hot encoding or partitioning should be
used for each feature, see :doc:`/parameter` for details. When objective is not
regression or binary classification, XGBoost will fallback to using onehot encoding
instead.


**********************
Expand All @@ -82,7 +108,7 @@ categorical data, we need to pass the similar parameter to :class:`DMatrix
# X is a dataframe we created in previous snippet
Xy = xgb.DMatrix(X, y, enable_categorical=True)
booster = xgb.train({"tree_method": "gpu_hist"}, Xy)
booster = xgb.train({"tree_method": "hist", "max_cat_to_onehot": 5}, Xy)
# Must use JSON for serialization, otherwise the information is lost
booster.save_model("categorical-model.json")
Expand All @@ -109,30 +135,7 @@ types by using the ``feature_types`` parameter in :class:`DMatrix <xgboost.DMatr
For numerical data, the feature type can be ``"q"`` or ``"float"``, while for categorical
feature it's specified as ``"c"``. The Dask module in XGBoost has the same interface so
:class:`dask.Array <dask.Array>` can also be used as categorical data.

********************
Optimal Partitioning
********************

.. versionadded:: 1.6

Optimal partitioning is a technique for partitioning the categorical predictors for each
node split, the proof of optimality for numerical objectives like ``RMSE`` was first
introduced by `[1] <#references>`__. The algorithm is used in decision trees for handling
regression and binary classification tasks `[2] <#references>`__, later LightGBM `[3]
<#references>`__ brought it to the context of gradient boosting trees and now is also
adopted in XGBoost as an optional feature for handling categorical splits. More
specifically, the proof by Fisher `[1] <#references>`__ states that, when trying to
partition a set of discrete values into groups based on the distances between a measure of
these values, one only needs to look at sorted partitions instead of enumerating all
possible permutations. In the context of decision trees, the discrete values are
categories, and the measure is the output leaf value. Intuitively, we want to group the
categories that output similar leaf values. During split finding, we first sort the
gradient histogram to prepare the contiguous partitions then enumerate the splits
according to these sorted values. One of the related parameters for XGBoost is
``max_cat_to_one_hot``, which controls whether one-hot encoding or partitioning should be
used for each feature, see :doc:`/parameter` for details.
:class:`dask.Array <dask.Array>` can also be used for categorical data.

*************
Miscellaneous
Expand Down
10 changes: 10 additions & 0 deletions include/xgboost/tree_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,16 @@ class RegTree : public Model {
*/
std::vector<FeatureType> const &GetSplitTypes() const { return split_types_; }
common::Span<uint32_t const> GetSplitCategories() const { return split_categories_; }
/*!
* \brief Get the bit storage for categories
*/
common::Span<uint32_t const> NodeCats(bst_node_t nidx) const {
auto node_ptr = GetCategoriesMatrix().node_ptr;
auto categories = GetCategoriesMatrix().categories;
auto segment = node_ptr[nidx];
auto node_cats = categories.subspan(segment.beg, segment.size);
return node_cats;
}
auto const& GetSplitCategoriesPtr() const { return split_categories_segments_; }

// The fields of split_categories_segments_[i] are set such that
Expand Down
7 changes: 4 additions & 3 deletions python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,10 +582,11 @@ def __init__(
.. versionadded:: 1.3.0
.. note:: This parameter is experimental
Experimental support of specializing for categorical features. Do not set
to True unless you are interested in development. Currently it's only
available for `gpu_hist` and `approx` tree methods. Also, JSON/UBJSON
serialization format is required. (XGBoost 1.6 for approx)
to True unless you are interested in development. Also, JSON/UBJSON
serialization format is required.
"""
if group is not None and qid is not None:
Expand Down
17 changes: 9 additions & 8 deletions python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,11 @@ def inner(y_score: np.ndarray, dmatrix: DMatrix) -> Tuple[str, float]:
.. versionadded:: 1.5.0
Experimental support for categorical data. Do not set to true unless you are
interested in development. Only valid when `gpu_hist` or `approx` is used along
with dataframe as input. Also, JSON/UBJSON serialization format is
required. (XGBoost 1.6 for approx)
.. note:: This parameter is experimental
Experimental support for categorical data. When enabled, cudf/pandas.DataFrame
should be used to specify categorical data type. Also, JSON/UBJSON
serialization format is required.
max_cat_to_onehot : Optional[int]
Expand All @@ -220,9 +221,8 @@ def inner(y_score: np.ndarray, dmatrix: DMatrix) -> Tuple[str, float]:
A threshold for deciding whether XGBoost should use one-hot encoding based split
for categorical data. When number of categories is lesser than the threshold
then one-hot encoding is chosen, otherwise the categories will be partitioned
into children nodes. Only relevant for regression and binary
classification. Also, ``approx`` or ``gpu_hist`` tree method is required. See
:doc:`Categorical Data </tutorials/categorical>` for details.
into children nodes. Only relevant for regression and binary classification.
See :doc:`Categorical Data </tutorials/categorical>` for details.
eval_metric : Optional[Union[str, List[str], Callable]]
Expand Down Expand Up @@ -846,7 +846,8 @@ def _duplicated(parameter: str) -> None:
callbacks = self.callbacks if self.callbacks is not None else callbacks

tree_method = params.get("tree_method", None)
if self.enable_categorical and tree_method not in ("gpu_hist", "approx"):
cat_support = {"gpu_hist", "approx", "hist"}
if self.enable_categorical and tree_method not in cat_support:
raise ValueError(
"Experimental support for categorical data is not implemented for"
" current tree method yet."
Expand Down
88 changes: 61 additions & 27 deletions src/common/partition_builder.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright 2021 by Contributors
* Copyright 2021-2022 by Contributors
* \file row_set.h
* \brief Quick Utility to compute subset of rows
* \author Philip Cho, Tianqi Chen
Expand All @@ -8,12 +8,15 @@
#define XGBOOST_COMMON_PARTITION_BUILDER_H_

#include <xgboost/data.h>

#include <algorithm>
#include <vector>
#include <utility>
#include <memory>
#include <utility>
#include <vector>

#include "categorical.h"
#include "column_matrix.h"
#include "xgboost/tree_model.h"
#include "../common/column_matrix.h"

namespace xgboost {
namespace common {
Expand Down Expand Up @@ -46,26 +49,32 @@ class PartitionBuilder {
// on comparison of indexes values (idx_span) and split point (split_cond)
// Handle dense columns
// Analog of std::stable_partition, but in no-inplace manner
template <bool default_left, bool any_missing, typename ColumnType>
template <bool default_left, bool any_missing, typename ColumnType, typename Predicate>
inline std::pair<size_t, size_t> PartitionKernel(const ColumnType& column,
common::Span<const size_t> rid_span, const int32_t split_cond,
common::Span<size_t> left_part, common::Span<size_t> right_part) {
common::Span<const size_t> row_indices,
common::Span<size_t> left_part,
common::Span<size_t> right_part,
size_t base_rowid, Predicate&& pred) {
size_t* p_left_part = left_part.data();
size_t* p_right_part = right_part.data();
size_t nleft_elems = 0;
size_t nright_elems = 0;
auto state = column.GetInitialState(rid_span.front());
auto state = column.GetInitialState(row_indices.front() - base_rowid);

auto p_row_indices = row_indices.data();
auto n_samples = row_indices.size();

for (auto rid : rid_span) {
const int32_t bin_id = column.GetBinIdx(rid, &state);
for (size_t i = 0; i < n_samples; ++i) {
auto rid = p_row_indices[i];
const int32_t bin_id = column.GetBinIdx(rid - base_rowid, &state);
if (any_missing && bin_id == ColumnType::kMissingId) {
if (default_left) {
p_left_part[nleft_elems++] = rid;
} else {
p_right_part[nright_elems++] = rid;
}
} else {
if (bin_id <= split_cond) {
if (pred(rid, bin_id)) {
p_left_part[nleft_elems++] = rid;
} else {
p_right_part[nright_elems++] = rid;
Expand Down Expand Up @@ -95,41 +104,66 @@ class PartitionBuilder {
return {nleft_elems, nright_elems};
}

template <typename BinIdxType, bool any_missing>
template <typename BinIdxType, bool any_missing, bool any_cat>
void Partition(const size_t node_in_set, const size_t nid, const common::Range1d range,
const int32_t split_cond,
const ColumnMatrix& column_matrix, const RegTree& tree, const size_t* rid) {
const int32_t split_cond, GHistIndexMatrix const& gmat,
const ColumnMatrix& column_matrix, const RegTree& tree, const size_t* rid) {
common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
common::Span<size_t> left = GetLeftBuffer(node_in_set,
range.begin(), range.end());
common::Span<size_t> right = GetRightBuffer(node_in_set,
range.begin(), range.end());
common::Span<size_t> left = GetLeftBuffer(node_in_set, range.begin(), range.end());
common::Span<size_t> right = GetRightBuffer(node_in_set, range.begin(), range.end());
const bst_uint fid = tree[nid].SplitIndex();
const bool default_left = tree[nid].DefaultLeft();
const auto column_ptr = column_matrix.GetColumn<BinIdxType, any_missing>(fid);

std::pair<size_t, size_t> child_nodes_sizes;
bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;
auto node_cats = tree.NodeCats(nid);

auto const& index = gmat.index;
auto const& cut_values = gmat.cut.Values();
auto const& cut_ptrs = gmat.cut.Ptrs();

auto pred = [&](auto ridx, auto bin_id) {
if (any_cat && is_cat) {
auto begin = gmat.RowIdx(ridx);
auto end = gmat.RowIdx(ridx + 1);
auto f_begin = cut_ptrs[fid];
auto f_end = cut_ptrs[fid + 1];
// bypassing the column matrix as we need the cut value instead of bin idx for categorical
// features.
auto gidx = BinarySearchBin(begin, end, index, f_begin, f_end);
bool go_left;
if (gidx == -1) {
go_left = default_left;
} else {
go_left = Decision(node_cats, cut_values[gidx], default_left);
}
return go_left;
} else {
return bin_id <= split_cond;
}
};

std::pair<size_t, size_t> child_nodes_sizes;
if (column_ptr->GetType() == xgboost::common::kDenseColumn) {
const common::DenseColumn<BinIdxType, any_missing>& column =
static_cast<const common::DenseColumn<BinIdxType, any_missing>& >(*(column_ptr.get()));
if (default_left) {
child_nodes_sizes = PartitionKernel<true, any_missing>(column, rid_span,
split_cond, left, right);
child_nodes_sizes = PartitionKernel<true, any_missing>(column, rid_span, left, right,
gmat.base_rowid, pred);
} else {
child_nodes_sizes = PartitionKernel<false, any_missing>(column, rid_span,
split_cond, left, right);
child_nodes_sizes = PartitionKernel<false, any_missing>(column, rid_span, left, right,
gmat.base_rowid, pred);
}
} else {
CHECK_EQ(any_missing, true);
const common::SparseColumn<BinIdxType>& column
= static_cast<const common::SparseColumn<BinIdxType>& >(*(column_ptr.get()));
if (default_left) {
child_nodes_sizes = PartitionKernel<true, any_missing>(column, rid_span,
split_cond, left, right);
child_nodes_sizes = PartitionKernel<true, any_missing>(column, rid_span, left, right,
gmat.base_rowid, pred);
} else {
child_nodes_sizes = PartitionKernel<false, any_missing>(column, rid_span,
split_cond, left, right);
child_nodes_sizes = PartitionKernel<false, any_missing>(column, rid_span, left, right,
gmat.base_rowid, pred);
}
}

Expand Down
3 changes: 0 additions & 3 deletions src/common/threading_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,6 @@ class MemStackAllocator {
T& operator[](size_t i) { return ptr_[i]; }
T const& operator[](size_t i) const { return ptr_[i]; }

// FIXME(jiamingy): Remove this once we merge partitioner cleanup for hist.
auto Get() { return ptr_; }

private:
T* ptr_ = nullptr;
size_t required_size_;
Expand Down
Loading

0 comments on commit 83a66b4

Please sign in to comment.