From c4b9f4f6229ca4bc5f9d42695fc41ea2f177b6bf Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 4 Jun 2021 02:29:14 +0800 Subject: [PATCH] Add `enable_categorical` to sklearn. (#7011) --- python-package/xgboost/dask.py | 3 ++ python-package/xgboost/sklearn.py | 20 ++++++++++++- tests/python-gpu/test_gpu_with_sklearn.py | 36 +++++++++++++++++++++++ 3 files changed, 58 insertions(+), 1 deletion(-) diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index d38ebbc50040..a07a61224d34 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -1642,6 +1642,7 @@ async def _fit_async( eval_group=None, eval_qid=None, missing=self.missing, + enable_categorical=self.enable_categorical, ) if callable(self.objective): @@ -1730,6 +1731,7 @@ async def _fit_async( eval_group=None, eval_qid=None, missing=self.missing, + enable_categorical=self.enable_categorical, ) # pylint: disable=attribute-defined-outside-init @@ -1927,6 +1929,7 @@ async def _fit_async( eval_group=None, eval_qid=eval_qid, missing=self.missing, + enable_categorical=self.enable_categorical, ) if eval_metric is not None: if callable(eval_metric): diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 4b06d92f121e..d9b928551c48 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -164,6 +164,14 @@ def inner(preds: np.ndarray, dmatrix: DMatrix) -> Tuple[np.ndarray, np.ndarray]: validate_parameters : Optional[bool] Give warnings for unknown parameter. + enable_categorical : bool + + .. versionadded:: 1.5.0 + + Experimental support for categorical data. Do not set to true unless you are + interested in development. Only valid when `gpu_hist` and pandas dataframe are + used. + kwargs : dict, optional Keyword arguments for XGBoost Booster object. Full documentation of parameters can be found here: @@ -257,6 +265,7 @@ def _wrap_evaluation_matrices( eval_group: Optional[List[Any]], eval_qid: Optional[List[Any]], create_dmatrix: Callable, + enable_categorical: bool, label_transform: Callable = lambda x: x, ) -> Tuple[Any, Optional[List[Tuple[Any, str]]]]: """Convert array_like evaluation matrices into DMatrix. Perform validation on the way. @@ -271,6 +280,7 @@ def _wrap_evaluation_matrices( base_margin=base_margin, feature_weights=feature_weights, missing=missing, + enable_categorical=enable_categorical, ) n_validation = 0 if eval_set is None else len(eval_set) @@ -317,6 +327,7 @@ def validate_or_none(meta: Optional[List], name: str) -> List: qid=eval_qid[i], base_margin=base_margin_eval_set[i], missing=missing, + enable_categorical=enable_categorical, ) evals.append(m) nevals = len(evals) @@ -375,6 +386,7 @@ def __init__( gpu_id: Optional[int] = None, validate_parameters: Optional[bool] = None, predictor: Optional[str] = None, + enable_categorical: bool = False, **kwargs: Any ) -> None: if not SKLEARN_INSTALLED: @@ -411,6 +423,7 @@ def __init__( self.gpu_id = gpu_id self.validate_parameters = validate_parameters self.predictor = predictor + self.enable_categorical = enable_categorical def _more_tags(self) -> Dict[str, bool]: '''Tags used for scikit-learn data validation.''' @@ -514,7 +527,9 @@ def get_xgb_params(self) -> Dict[str, Any]: params = self.get_params() # Parameters that should not go into native learner. wrapper_specific = { - 'importance_type', 'kwargs', 'missing', 'n_estimators', 'use_label_encoder'} + 'importance_type', 'kwargs', 'missing', 'n_estimators', 'use_label_encoder', + "enable_categorical" + } filtered = dict() for k, v in params.items(): if k not in wrapper_specific and not callable(v): @@ -735,6 +750,7 @@ def fit( eval_group=None, eval_qid=None, create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs), + enable_categorical=self.enable_categorical, ) params = self.get_xgb_params() @@ -1202,6 +1218,7 @@ def fit( eval_group=None, eval_qid=None, create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs), + enable_categorical=self.enable_categorical, label_transform=label_transform, ) @@ -1628,6 +1645,7 @@ def fit( eval_group=eval_group, eval_qid=eval_qid, create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs), + enable_categorical=self.enable_categorical, ) evals_result: TrainingCallback.EvalsLog = {} diff --git a/tests/python-gpu/test_gpu_with_sklearn.py b/tests/python-gpu/test_gpu_with_sklearn.py index 8f99cba97919..a38001a51610 100644 --- a/tests/python-gpu/test_gpu_with_sklearn.py +++ b/tests/python-gpu/test_gpu_with_sklearn.py @@ -1,7 +1,10 @@ +import json import xgboost as xgb import pytest +import tempfile import sys import numpy as np +import os sys.path.append("tests/python") import testing as tm # noqa @@ -38,3 +41,36 @@ def test_boost_from_prediction_gpu_hist(): def test_num_parallel_tree(): twskl.run_boston_housing_rf_regression("gpu_hist") + + +@pytest.mark.skipif(**tm.no_pandas()) +@pytest.mark.skipif(**tm.no_sklearn()) +def test_categorical(): + import pandas as pd + from sklearn.datasets import load_svmlight_file + + data_dir = os.path.join(tm.PROJECT_ROOT, "demo", "data") + X, y = load_svmlight_file(os.path.join(data_dir, "agaricus.txt.train")) + clf = xgb.XGBClassifier( + tree_method="gpu_hist", + use_label_encoder=False, + enable_categorical=True, + predictor="gpu_predictor", + n_estimators=10, + ) + X = pd.DataFrame(X.todense()).astype("category") + clf.fit(X, y) + + with tempfile.TemporaryDirectory() as tempdir: + model = os.path.join(tempdir, "categorial.json") + clf.save_model(model) + + with open(model) as fd: + categorical = json.load(fd) + categories_sizes = np.array( + categorical["learner"]["gradient_booster"]["model"]["trees"][0][ + "categories_sizes" + ] + ) + assert categories_sizes.shape[0] != 0 + np.testing.assert_allclose(categories_sizes, 1)