diff --git a/python-package/xgboost/callback.py b/python-package/xgboost/callback.py
index 76350d839dd1..5be6a058ac8e 100644
--- a/python-package/xgboost/callback.py
+++ b/python-package/xgboost/callback.py
@@ -23,7 +23,13 @@
 import numpy
 
 from . import collective
-from .core import Booster, DMatrix, XGBoostError, _get_booster_layer_trees
+from .core import (
+    Booster,
+    DMatrix,
+    XGBoostError,
+    _get_booster_layer_trees,
+    _parse_eval_str,
+)
 
 __all__ = [
     "TrainingCallback",
@@ -250,11 +256,7 @@ def after_iteration(
         for _, name in evals:
             assert name.find("-") == -1, "Dataset name should not contain `-`"
         score: str = model.eval_set(evals, epoch, self.metric, self._output_margin)
-        splited = score.split()[1:]  # into datasets
-        # split up `test-error:0.1234`
-        metric_score_str = [tuple(s.split(":")) for s in splited]
-        # convert to float
-        metric_score = [(n, float(s)) for n, s in metric_score_str]
+        metric_score = _parse_eval_str(score)
         self._update_history(metric_score, epoch)
         ret = any(c.after_iteration(model, epoch, self.history) for c in self.callbacks)
         return ret
diff --git a/python-package/xgboost/collective.py b/python-package/xgboost/collective.py
index 7c586cba71d3..4c67ccbfcad7 100644
--- a/python-package/xgboost/collective.py
+++ b/python-package/xgboost/collective.py
@@ -231,7 +231,7 @@ def allreduce(data: np.ndarray, op: Op) -> np.ndarray:  # pylint:disable=invalid
     if buf.base is data.base:
         buf = buf.copy()
     if buf.dtype not in DTYPE_ENUM__:
-        raise Exception(f"data type {buf.dtype} not supported")
+        raise TypeError(f"data type {buf.dtype} not supported")
     _check_call(
         _LIB.XGCommunicatorAllreduce(
             buf.ctypes.data_as(ctypes.c_void_p),
diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index a186dc3963dc..5a0cfb3a2ece 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -111,6 +111,16 @@ def make_jcargs(**kwargs: Any) -> bytes:
     return from_pystr_to_cstr(json.dumps(kwargs))
 
 
+def _parse_eval_str(result: str) -> List[Tuple[str, float]]:
+    """Parse an eval result string from the booster."""
+    segments = result.split()[1:]  # into datasets
+    # split up `test-error:0.1234`
+    metric_score_str = [tuple(s.split(":")) for s in segments]
+    # convert to float
+    metric_score = [(n, float(s)) for n, s in metric_score_str]
+    return metric_score
+
+
 IterRange = TypeVar("IterRange", Optional[Tuple[int, int]], Tuple[int, int])
diff --git a/python-package/xgboost/rabit.py b/python-package/xgboost/rabit.py
index 0b8f143ecd35..132d721787b1 100644
--- a/python-package/xgboost/rabit.py
+++ b/python-package/xgboost/rabit.py
@@ -136,7 +136,7 @@ def allreduce(  # pylint:disable=invalid-name
     """
     if prepare_fun is None:
         return collective.allreduce(data, collective.Op(op))
-    raise Exception("preprocessing function is no longer supported")
+    raise ValueError("preprocessing function is no longer supported")
 
 
 def version_number() -> int:
diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index 69bcac38d01a..3204f5a2a61e 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -43,8 +43,9 @@
     XGBoostError,
     _convert_ntree_limit,
     _deprecate_positional_args,
+    _parse_eval_str,
 )
-from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array
+from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array, _is_pandas_df
 from .training import train
 
 
@@ -1812,32 +1813,43 @@ def fit(
         return self
 
 
+def _get_qid(
+    X: ArrayLike, qid: Optional[ArrayLike]
+) -> Tuple[ArrayLike, Optional[ArrayLike]]:
+    """Get the special qid column from X if it exists."""
+    if (_is_pandas_df(X) or _is_cudf_df(X)) and hasattr(X, "qid"):
+        if qid is not None:
+            raise ValueError(
+                "Found both the special column `qid` in `X` and the `qid` from the "
+                "`fit` method. Please remove one of them."
+            )
+        q_x = X.qid
+        X = X.drop("qid", axis=1)
+        return X, q_x
+    return X, qid
+
+
 @xgboost_model_doc(
-    "Implementation of the Scikit-Learn API for XGBoost Ranking.",
+    """Implementation of the Scikit-Learn API for XGBoost Ranking.""",
     ["estimators", "model"],
     end_note="""
-        .. note::
-
-            The default objective for XGBRanker is "rank:pairwise"
-
         .. note::
 
             A custom objective function is currently not supported by XGBRanker.
-            Likewise, a custom metric function is not supported either.
 
         .. note::
 
-            Query group information is required for ranking tasks by either using the
-            `group` parameter or `qid` parameter in `fit` method. This information is
-            not required in 'predict' method and multiple groups can be predicted on
-            a single call to `predict`.
+            Query group information is required for ranking training but not for
+            prediction. Multiple groups can be predicted on a single call to
+            :py:meth:`predict`.
 
             When fitting the model with the `group` parameter, your data need to be sorted
-            by query group first. `group` must be an array that contains the size of each
+            by the query group first. `group` is an array that contains the size of each
             query group.
-            When fitting the model with the `qid` parameter, your data does not need
-            sorting. `qid` must be an array that contains the group of each training
-            sample.
+
+            Similarly, when fitting the model with the `qid` parameter, the data should
+            be sorted according to the query index and `qid` is an array that contains
+            the query index for each training sample.
 
             For example, if your original data look like:
@@ -1859,9 +1871,10 @@ def fit(
             | 2     | 1         | x_7           |
             +-------+-----------+---------------+
 
-            then `fit` method can be called with either `group` array as ``[3, 4]``
-            or with `qid` as ``[`1, 1, 1, 2, 2, 2, 2]``, that is the qid column.
-""",
+            then the :py:meth:`fit` method can be called with either the `group` array
+            as ``[3, 4]`` or with `qid` as ``[1, 1, 1, 2, 2, 2, 2]``, that is, the qid
+            column. Also, the `qid` can be a special column of the input `X` instead of
+            a separate parameter, see :py:meth:`fit` for more info.""",
 )
 class XGBRanker(XGBModel, XGBRankerMixIn):
     # pylint: disable=missing-docstring,too-many-arguments,invalid-name
@@ -1873,6 +1886,16 @@ def __init__(self, *, objective: str = "rank:pairwise", **kwargs: Any):
         if "rank:" not in objective:
             raise ValueError("please use XGBRanker for ranking task")
 
+    def _create_ltr_dmatrix(
+        self, ref: Optional[DMatrix], data: ArrayLike, qid: ArrayLike, **kwargs: Any
+    ) -> DMatrix:
+        data, qid = _get_qid(data, qid)
+
+        if kwargs.get("group", None) is None and qid is None:
+            raise ValueError("Either `group` or `qid` is required for ranking task")
+
+        return super()._create_dmatrix(ref=ref, data=data, qid=qid, **kwargs)
+
     @_deprecate_positional_args
     def fit(
         self,
@@ -1907,6 +1930,23 @@ def fit(
         X :
             Feature matrix. See :ref:`py-data` for a list of supported types.
 
+            When this is a :py:class:`pandas.DataFrame` or a :py:class:`cudf.DataFrame`,
+            it may contain a special column called ``qid`` for specifying the query
+            index. Using a special column is the same as using the `qid` parameter,
+            except for being compatible with sklearn utility functions like
+            :py:func:`sklearn.model_selection.cross_validate`. The same convention
The same convention + applies to the :py:meth:`XGBRanker.score` and :py:meth:`XGBRanker.predict`. + + +-----+----------------+----------------+ + | qid | feat_0 | feat_1 | + +-----+----------------+----------------+ + | 0 | :math:`x_{00}` | :math:`x_{01}` | + +-----+----------------+----------------+ + | 1 | :math:`x_{10}` | :math:`x_{11}` | + +-----+----------------+----------------+ + | 1 | :math:`x_{20}` | :math:`x_{21}` | + +-----+----------------+----------------+ + When the ``tree_method`` is set to ``hist`` or ``gpu_hist``, internally, the :py:class:`QuantileDMatrix` will be used instead of the :py:class:`DMatrix` for conserving memory. However, this has performance implications when the @@ -1916,12 +1956,12 @@ def fit( y : Labels group : - Size of each query group of training data. Should have as many elements as the - query groups in the training data. If this is set to None, then user must - provide qid. + Size of each query group of training data. Should have as many elements as + the query groups in the training data. If this is set to None, then user + must provide qid. qid : Query ID for each training sample. Should have the size of n_samples. If - this is set to None, then user must provide group. + this is set to None, then user must provide group or a special column in X. sample_weight : Query group weights @@ -1929,8 +1969,9 @@ def fit( In ranking task, one weight is assigned to each query group/id (not each data point). This is because we only care about the relative ordering of - data points within each group, so it doesn't make sense to assign weights - to individual data points. + data points within each group, so it doesn't make sense to assign + weights to individual data points. + base_margin : Global bias for each instance. eval_set : @@ -1942,7 +1983,8 @@ def fit( query groups in the ``i``-th pair in **eval_set**. eval_qid : A list in which ``eval_qid[i]`` is the array containing query ID of ``i``-th - pair in **eval_set**. + pair in **eval_set**. The special column convention in `X` applies to + validation datasets as well. eval_metric : str, list of str, optional .. deprecated:: 1.6.0 @@ -1985,16 +2027,7 @@ def fit( Use `callbacks` in :py:meth:`__init__` or :py:meth:`set_params` instead. 
""" - # check if group information is provided with config_context(verbosity=self.verbosity): - if group is None and qid is None: - raise ValueError("group or qid is required for ranking task") - - if eval_set is not None: - if eval_group is None and eval_qid is None: - raise ValueError( - "eval_group or eval_qid is required if eval_set is not None" - ) train_dmatrix, evals = _wrap_evaluation_matrices( missing=self.missing, X=X, @@ -2009,7 +2042,7 @@ def fit( base_margin_eval_set=base_margin_eval_set, eval_group=eval_group, eval_qid=eval_qid, - create_dmatrix=self._create_dmatrix, + create_dmatrix=self._create_ltr_dmatrix, enable_categorical=self.enable_categorical, feature_types=self.feature_types, ) @@ -2044,3 +2077,59 @@ def fit( self._set_evaluation_result(evals_result) return self + + def predict( + self, + X: ArrayLike, + output_margin: bool = False, + ntree_limit: Optional[int] = None, + validate_features: bool = True, + base_margin: Optional[ArrayLike] = None, + iteration_range: Optional[Tuple[int, int]] = None, + ) -> ArrayLike: + X, _ = _get_qid(X, None) + return super().predict( + X, + output_margin, + ntree_limit, + validate_features, + base_margin, + iteration_range, + ) + + def apply( + self, + X: ArrayLike, + ntree_limit: int = 0, + iteration_range: Optional[Tuple[int, int]] = None, + ) -> ArrayLike: + X, _ = _get_qid(X, None) + return super().apply(X, ntree_limit, iteration_range) + + def score(self, X: ArrayLike, y: ArrayLike) -> float: + """Evaluate score for data using the last evaluation metric. + + Parameters + ---------- + X : pd.DataFrame|cudf.DataFrame + Feature matrix. A DataFrame with a special `qid` column. + + y : + Labels + + Returns + ------- + score : + The result of the first evaluation metric for the ranker. + + """ + X, qid = _get_qid(X, None) + Xyq = DMatrix(X, y, qid=qid) + if callable(self.eval_metric): + metric = ltr_metric_decorator(self.eval_metric, self.n_jobs) + result_str = self.get_booster().eval_set([(Xyq, "eval")], feval=metric) + else: + result_str = self.get_booster().eval(Xyq) + + metric_score = _parse_eval_str(result_str) + return metric_score[-1][1] diff --git a/python-package/xgboost/testing/ranking.py b/python-package/xgboost/testing/ranking.py new file mode 100644 index 000000000000..fe4fc8404567 --- /dev/null +++ b/python-package/xgboost/testing/ranking.py @@ -0,0 +1,72 @@ +# pylint: disable=too-many-locals +"""Tests for learning to rank.""" +from types import ModuleType +from typing import Any + +import numpy as np +import pytest + +import xgboost as xgb +from xgboost import testing as tm + + +def run_ranking_qid_df(impl: ModuleType, tree_method: str) -> None: + """Test ranking with qid packed into X.""" + import scipy.sparse + from sklearn.metrics import mean_squared_error + from sklearn.model_selection import StratifiedGroupKFold, cross_val_score + + X, y, q, _ = tm.make_ltr(n_samples=128, n_features=2, n_query_groups=8, max_rel=3) + + # pack qid into x using dataframe + df = impl.DataFrame(X) + df["qid"] = q + ranker = xgb.XGBRanker(n_estimators=3, eval_metric="ndcg", tree_method=tree_method) + ranker.fit(df, y) + s = ranker.score(df, y) + assert s > 0.7 + + # works with validation datasets as well + valid_df = df.copy() + valid_df.iloc[0, 0] = 3.0 + ranker.fit(df, y, eval_set=[(valid_df, y)]) + + # same as passing qid directly + ranker = xgb.XGBRanker(n_estimators=3, eval_metric="ndcg", tree_method=tree_method) + ranker.fit(X, y, qid=q) + s1 = ranker.score(df, y) + assert np.isclose(s, s1) + + # Works with standard sklearn cv + if 
tree_method != "gpu_hist": + # we need cuML for this. + kfold = StratifiedGroupKFold(shuffle=False) + results = cross_val_score(ranker, df, y, cv=kfold, groups=df.qid) + assert len(results) == 5 + + # Works with custom metric + def neg_mse(*args: Any, **kwargs: Any) -> float: + return -float(mean_squared_error(*args, **kwargs)) + + ranker = xgb.XGBRanker(n_estimators=3, eval_metric=neg_mse, tree_method=tree_method) + ranker.fit(df, y, eval_set=[(valid_df, y)]) + score = ranker.score(valid_df, y) + assert np.isclose(score, ranker.evals_result()["validation_0"]["neg_mse"][-1]) + + # Works with sparse data + if tree_method != "gpu_hist": + # no sparse with cuDF + X_csr = scipy.sparse.csr_matrix(X) + df = impl.DataFrame.sparse.from_spmatrix( + X_csr, columns=[str(i) for i in range(X.shape[1])] + ) + df["qid"] = q + ranker = xgb.XGBRanker( + n_estimators=3, eval_metric="ndcg", tree_method=tree_method + ) + ranker.fit(df, y) + s2 = ranker.score(df, y) + assert np.isclose(s2, s) + + with pytest.raises(ValueError, match="Either `group` or `qid`."): + ranker.fit(df, y, eval_set=[(X, y)]) diff --git a/tests/python-gpu/test_gpu_with_sklearn.py b/tests/python-gpu/test_gpu_with_sklearn.py index 8ecb4bdc77cc..c9d3ab4ebff7 100644 --- a/tests/python-gpu/test_gpu_with_sklearn.py +++ b/tests/python-gpu/test_gpu_with_sklearn.py @@ -8,6 +8,7 @@ import xgboost as xgb from xgboost import testing as tm +from xgboost.testing.ranking import run_ranking_qid_df sys.path.append("tests/python") import test_with_sklearn as twskl # noqa @@ -153,3 +154,10 @@ def test_classififer(): y *= 10 with pytest.raises(ValueError, match=r"Invalid classes.*"): clf.fit(X, y) + + +@pytest.mark.skipif(**tm.no_pandas()) +def test_ranking_qid_df(): + import cudf + + run_ranking_qid_df(cudf, "gpu_hist") diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index bc7a3e94e437..baef690ee32e 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -11,6 +11,7 @@ import xgboost as xgb from xgboost import testing as tm +from xgboost.testing.ranking import run_ranking_qid_df from xgboost.testing.shared import get_feature_weights, validate_data_initialization from xgboost.testing.updater import get_basescore @@ -180,6 +181,13 @@ def test_ranking_metric() -> None: assert results["validation_0"]["roc_auc_score"][-1] > 0.6 +@pytest.mark.skipif(**tm.no_pandas()) +def test_ranking_qid_df(): + import pandas as pd + + run_ranking_qid_df(pd, "hist") + + def test_stacking_regression(): from sklearn.datasets import load_diabetes from sklearn.ensemble import RandomForestRegressor, StackingRegressor