From edd8b94501db67556a9185949d439e427746f133 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 2 Mar 2023 03:45:02 +0800 Subject: [PATCH 1/8] Support sklearn cross validation for ranker. - Add a convention for X to include a special `qid` column. sklearn utilities consider only `X`, `y` and `sample_weight` for supervised learning algorithms, but we need an additional qid array for ranking. It's important to be able to support the cross validation function in sklearn since all other tuning functions like grid search are based on cross validation. --- python-package/xgboost/callback.py | 14 +-- python-package/xgboost/core.py | 10 ++ python-package/xgboost/sklearn.py | 155 ++++++++++++++++++++++------- tests/python/test_with_sklearn.py | 46 +++++++++ 4 files changed, 184 insertions(+), 41 deletions(-) diff --git a/python-package/xgboost/callback.py b/python-package/xgboost/callback.py index 76350d839dd1..5be6a058ac8e 100644 --- a/python-package/xgboost/callback.py +++ b/python-package/xgboost/callback.py @@ -23,7 +23,13 @@ import numpy from . import collective -from .core import Booster, DMatrix, XGBoostError, _get_booster_layer_trees +from .core import ( + Booster, + DMatrix, + XGBoostError, + _get_booster_layer_trees, + _parse_eval_str, +) __all__ = [ "TrainingCallback", @@ -250,11 +256,7 @@ def after_iteration( for _, name in evals: assert name.find("-") == -1, "Dataset name should not contain `-`" score: str = model.eval_set(evals, epoch, self.metric, self._output_margin) - splited = score.split()[1:] # into datasets - # split up `test-error:0.1234` - metric_score_str = [tuple(s.split(":")) for s in splited] - # convert to float - metric_score = [(n, float(s)) for n, s in metric_score_str] + metric_score = _parse_eval_str(score) self._update_history(metric_score, epoch) ret = any(c.after_iteration(model, epoch, self.history) for c in self.callbacks) return ret diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index a186dc3963dc..5a0cfb3a2ece 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -111,6 +111,16 @@ def make_jcargs(**kwargs: Any) -> bytes: return from_pystr_to_cstr(json.dumps(kwargs)) +def _parse_eval_str(result: str) -> List[Tuple[str, float]]: + """Parse an eval result string from the booster.""" + splited = result.split()[1:] + # split up `test-error:0.1234` + metric_score_str = [tuple(s.split(":")) for s in splited] + # convert to float + metric_score = [(n, float(s)) for n, s in metric_score_str] + return metric_score + + IterRange = TypeVar("IterRange", Optional[Tuple[int, int]], Tuple[int, int]) diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 69bcac38d01a..3d947de70677 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -43,8 +43,9 @@ XGBoostError, _convert_ntree_limit, _deprecate_positional_args, + _parse_eval_str, ) -from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array +from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array, _is_pandas_df from .training import train @@ -1812,32 +1813,44 @@ def fit( return self +def _get_qid( + X: ArrayLike, qid: Optional[ArrayLike] +) -> Tuple[ArrayLike, Optional[ArrayLike]]: + """Get the special qid column from X if exists.""" + if (_is_pandas_df(X) or _is_cudf_df(X)) and hasattr(X, "qid"): + if qid is not None: + raise ValueError( + "Found both the special column `qid` in `X` and the `qid` from the" + "`fit` method. Please remove one of them." + ) + assert qid is None + q_x = X.qid + X = X.drop("qid", axis=1) + return X, q_x + return X, qid + + @xgboost_model_doc( - "Implementation of the Scikit-Learn API for XGBoost Ranking.", + """Implementation of the Scikit-Learn API for XGBoost Ranking.""", ["estimators", "model"], end_note=""" - .. note:: - - The default objective for XGBRanker is "rank:pairwise" - .. note:: A custom objective function is currently not supported by XGBRanker. - Likewise, a custom metric function is not supported either. .. note:: - Query group information is required for ranking tasks by either using the - `group` parameter or `qid` parameter in `fit` method. This information is - not required in 'predict' method and multiple groups can be predicted on - a single call to `predict`. + Query group information is only required for ranking training but not + prediction. Multiple groups can be predicted on a single call to + :py:meth:`predict`. When fitting the model with the `group` parameter, your data need to be sorted - by query group first. `group` must be an array that contains the size of each + by the query group first. `group` is an array that contains the size of each query group. - When fitting the model with the `qid` parameter, your data does not need - sorting. `qid` must be an array that contains the group of each training - sample. + + Similarly, when fitting the model with the `qid` parameter, the data should be + sorted according to query index and `qid` is an array that contains the query + index for each training sample. For example, if your original data look like: @@ -1859,9 +1872,10 @@ def fit( | 2 | 1 | x_7 | +-------+-----------+---------------+ - then `fit` method can be called with either `group` array as ``[3, 4]`` - or with `qid` as ``[`1, 1, 1, 2, 2, 2, 2]``, that is the qid column. -""", + then :py:meth:`fit` method can be called with either `group` array as ``[3, 4]`` + or with `qid` as ``[1, 1, 1, 2, 2, 2, 2]``, that is the qid column. Also, the + `qid` can be a special column of input `X` instead of a separated parameter, see + :py:meth:`fit` for more info.""", ) class XGBRanker(XGBModel, XGBRankerMixIn): # pylint: disable=missing-docstring,too-many-arguments,invalid-name @@ -1873,6 +1887,16 @@ def __init__(self, *, objective: str = "rank:pairwise", **kwargs: Any): if "rank:" not in objective: raise ValueError("please use XGBRanker for ranking task") + def _create_ltr_dmatrix( + self, ref: Optional[DMatrix], data: ArrayLike, qid: ArrayLike, **kwargs: Any + ) -> DMatrix: + data, qid = _get_qid(data, qid) + + if kwargs.get("group", None) is None and qid is None: + raise ValueError("Either `group` or `qid` is required for ranking task") + + return super()._create_dmatrix(ref=ref, data=data, qid=qid, **kwargs) + @_deprecate_positional_args def fit( self, @@ -1907,6 +1931,23 @@ def fit( X : Feature matrix. See :ref:`py-data` for a list of supported types. + When this is a :py:class:`pandas.DataFrame` or a :py:class:`cudf.DataFrame`, + it may contain a special column called ``qid`` for specifying the query + index. Using a special column is the same as using the `qid` parameter, + except for being compatible with sklearn utility functions like + :py:func:`sklearn.model_selection.cross_validation`. The same convention + applies to the :py:meth:`XGBRanker.score` and :py:meth:`XGBRanker.predict`. + + +-----+----------------+----------------+ + | qid | feat_0 | feat_1 | + +-----+----------------+----------------+ + | 0 | :math:`x_{00}` | :math:`x_{01}` | + +-----+----------------+----------------+ + | 1 | :math:`x_{10}` | :math:`x_{11}` | + +-----+----------------+----------------+ + | 1 | :math:`x_{20}` | :math:`x_{21}` | + +-----+----------------+----------------+ + When the ``tree_method`` is set to ``hist`` or ``gpu_hist``, internally, the :py:class:`QuantileDMatrix` will be used instead of the :py:class:`DMatrix` for conserving memory. However, this has performance implications when the @@ -1916,12 +1957,12 @@ def fit( y : Labels group : - Size of each query group of training data. Should have as many elements as the - query groups in the training data. If this is set to None, then user must - provide qid. + Size of each query group of training data. Should have as many elements as + the query groups in the training data. If this is set to None, then user + must provide qid. qid : Query ID for each training sample. Should have the size of n_samples. If - this is set to None, then user must provide group. + this is set to None, then user must provide group or a special column in X. sample_weight : Query group weights @@ -1929,8 +1970,9 @@ def fit( In ranking task, one weight is assigned to each query group/id (not each data point). This is because we only care about the relative ordering of - data points within each group, so it doesn't make sense to assign weights - to individual data points. + data points within each group, so it doesn't make sense to assign + weights to individual data points. + base_margin : Global bias for each instance. eval_set : @@ -1942,7 +1984,8 @@ def fit( query groups in the ``i``-th pair in **eval_set**. eval_qid : A list in which ``eval_qid[i]`` is the array containing query ID of ``i``-th - pair in **eval_set**. + pair in **eval_set**. The special column convention in `X` applies to + validation datasets as well. eval_metric : str, list of str, optional .. deprecated:: 1.6.0 @@ -1985,16 +2028,7 @@ def fit( Use `callbacks` in :py:meth:`__init__` or :py:meth:`set_params` instead. """ - # check if group information is provided with config_context(verbosity=self.verbosity): - if group is None and qid is None: - raise ValueError("group or qid is required for ranking task") - - if eval_set is not None: - if eval_group is None and eval_qid is None: - raise ValueError( - "eval_group or eval_qid is required if eval_set is not None" - ) train_dmatrix, evals = _wrap_evaluation_matrices( missing=self.missing, X=X, @@ -2009,7 +2043,7 @@ def fit( base_margin_eval_set=base_margin_eval_set, eval_group=eval_group, eval_qid=eval_qid, - create_dmatrix=self._create_dmatrix, + create_dmatrix=self._create_ltr_dmatrix, enable_categorical=self.enable_categorical, feature_types=self.feature_types, ) @@ -2044,3 +2078,54 @@ def fit( self._set_evaluation_result(evals_result) return self + + def predict( + self, + X: ArrayLike, + output_margin: bool = False, + ntree_limit: Optional[int] = None, + validate_features: bool = True, + base_margin: Optional[ArrayLike] = None, + iteration_range: Optional[Tuple[int, int]] = None, + ) -> ArrayLike: + X, _ = _get_qid(X, None) + return super().predict( + X, + output_margin, + ntree_limit, + validate_features, + base_margin, + iteration_range, + ) + + def apply( + self, + X: ArrayLike, + ntree_limit: int = 0, + iteration_range: Optional[Tuple[int, int]] = None, + ) -> ArrayLike: + X, _ = _get_qid(X, None) + return super().apply(X, ntree_limit, iteration_range) + + def score(self, X: ArrayLike, y: ArrayLike) -> float: + """Evaluate score for data using the first evaluation metric. + + Parameters + ---------- + X : pd.DataFrame|cudf.DataFrame + A DataFrame with a special `qid` column. + + y : ArrayLike + Dependent variable. + + Returns + ------- + score : + The result of the first evaluation metric for the ranker. + + """ + X, qid = _get_qid(X, None) + Xyq = DMatrix(X, y, qid=qid) + result_str = self.get_booster().eval(Xyq) + metric_score = _parse_eval_str(result_str) + return metric_score[0][1] diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index bc7a3e94e437..ec1763a97788 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -180,6 +180,52 @@ def test_ranking_metric() -> None: assert results["validation_0"]["roc_auc_score"][-1] > 0.6 +@pytest.mark.skipif(**tm.no_pandas()) +def test_ranking_qid_df(): + import pandas as pd + import scipy.sparse + from sklearn.model_selection import cross_val_score + + X, y, q, w = tm.make_ltr(n_samples=128, n_features=2, n_query_groups=3, max_rel=3) + + # pack qid into x using dataframe + df = pd.DataFrame(X) + df["qid"] = q + ranker = xgb.XGBRanker(n_estimators=3, eval_metric="ndcg") + ranker.fit(df, y) + s = ranker.score(df, y) + assert s > 0.7 + + # works with validation datasets as well + valid_df = df.copy() + valid_df.iloc[0, 0] = 3.0 + ranker.fit(df, y, eval_set=[(valid_df, y)]) + + # same as passing qid directly + ranker = xgb.XGBRanker(n_estimators=3, eval_metric="ndcg") + ranker.fit(X, y, qid=q) + s1 = ranker.score(df, y) + assert np.isclose(s, s1) + + # Works with sparse data + X_csr = scipy.sparse.csr_matrix(X) + df = pd.DataFrame.sparse.from_spmatrix( + X_csr, columns=[str(i) for i in range(X.shape[1])] + ) + df["qid"] = q + ranker = xgb.XGBRanker(n_estimators=3) + ranker.fit(df, y) + s2 = ranker.score(df, y) + assert np.isclose(s2, s) + + # Works with standard sklearn cv + results = cross_val_score(ranker, df, y) + assert len(results) == 5 + + with pytest.raises(ValueError, match="Either `group` or `qid`."): + ranker.fit(df, y, eval_set=[(X, y)]) + + def test_stacking_regression(): from sklearn.datasets import load_diabetes from sklearn.ensemble import RandomForestRegressor, StackingRegressor From 593cecb2ee75ea9104a10731c6e94363c046477d Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 2 Mar 2023 03:53:21 +0800 Subject: [PATCH 2/8] consistent document. --- python-package/xgboost/sklearn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 3d947de70677..92ae43509367 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -2113,10 +2113,10 @@ def score(self, X: ArrayLike, y: ArrayLike) -> float: Parameters ---------- X : pd.DataFrame|cudf.DataFrame - A DataFrame with a special `qid` column. + Feature matrix. A DataFrame with a special `qid` column. - y : ArrayLike - Dependent variable. + y : + Labels Returns ------- From 2905cd1811ee1e611140d509c61c0c27b21f4ac2 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 2 Mar 2023 04:31:06 +0800 Subject: [PATCH 3/8] Update. --- tests/python/test_with_sklearn.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index ec1763a97788..6312437cd25b 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -184,9 +184,9 @@ def test_ranking_metric() -> None: def test_ranking_qid_df(): import pandas as pd import scipy.sparse - from sklearn.model_selection import cross_val_score + from sklearn.model_selection import StratifiedGroupKFold, cross_val_score - X, y, q, w = tm.make_ltr(n_samples=128, n_features=2, n_query_groups=3, max_rel=3) + X, y, q, w = tm.make_ltr(n_samples=128, n_features=2, n_query_groups=8, max_rel=3) # pack qid into x using dataframe df = pd.DataFrame(X) @@ -207,6 +207,11 @@ def test_ranking_qid_df(): s1 = ranker.score(df, y) assert np.isclose(s, s1) + # Works with standard sklearn cv + kfold = StratifiedGroupKFold(shuffle=False) + results = cross_val_score(ranker, df, y, cv=kfold, groups=df.qid) + assert len(results) == 5 + # Works with sparse data X_csr = scipy.sparse.csr_matrix(X) df = pd.DataFrame.sparse.from_spmatrix( @@ -218,10 +223,6 @@ def test_ranking_qid_df(): s2 = ranker.score(df, y) assert np.isclose(s2, s) - # Works with standard sklearn cv - results = cross_val_score(ranker, df, y) - assert len(results) == 5 - with pytest.raises(ValueError, match="Either `group` or `qid`."): ranker.fit(df, y, eval_set=[(X, y)]) From f2f55eb77f5158954bce837025e2711eb7e97f2a Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 2 Mar 2023 09:47:36 +0800 Subject: [PATCH 4/8] Support custom metric as well. --- python-package/xgboost/sklearn.py | 11 ++++++++--- tests/python/test_with_sklearn.py | 14 ++++++++++++-- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 92ae43509367..ad57aee8214c 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -2108,7 +2108,7 @@ def apply( return super().apply(X, ntree_limit, iteration_range) def score(self, X: ArrayLike, y: ArrayLike) -> float: - """Evaluate score for data using the first evaluation metric. + """Evaluate score for data using the last evaluation metric. Parameters ---------- @@ -2126,6 +2126,11 @@ def score(self, X: ArrayLike, y: ArrayLike) -> float: """ X, qid = _get_qid(X, None) Xyq = DMatrix(X, y, qid=qid) - result_str = self.get_booster().eval(Xyq) + if callable(self.eval_metric): + metric = ltr_metric_decorator(self.eval_metric, self.n_jobs) + result_str = self.get_booster().eval_set([(Xyq, "eval")], feval=metric) + else: + result_str = self.get_booster().eval(Xyq) + metric_score = _parse_eval_str(result_str) - return metric_score[0][1] + return metric_score[-1][1] diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index 6312437cd25b..6ab2024f6e24 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -3,7 +3,7 @@ import pickle import random import tempfile -from typing import Callable, Optional +from typing import Any, Callable, Optional import numpy as np import pytest @@ -185,6 +185,7 @@ def test_ranking_qid_df(): import pandas as pd import scipy.sparse from sklearn.model_selection import StratifiedGroupKFold, cross_val_score + from sklearn.metrics import mean_squared_error X, y, q, w = tm.make_ltr(n_samples=128, n_features=2, n_query_groups=8, max_rel=3) @@ -212,13 +213,22 @@ def test_ranking_qid_df(): results = cross_val_score(ranker, df, y, cv=kfold, groups=df.qid) assert len(results) == 5 + # Works with custom metric + def neg_mse(*args: Any, **kwargs: Any) -> float: + return -mean_squared_error(*args, **kwargs) + + ranker = xgb.XGBRanker(n_estimators=3, eval_metric=neg_mse) + ranker.fit(df, y, eval_set=[(valid_df, y)]) + score = ranker.score(valid_df, y) + assert np.isclose(score, ranker.evals_result()["validation_0"]["neg_mse"][-1]) + # Works with sparse data X_csr = scipy.sparse.csr_matrix(X) df = pd.DataFrame.sparse.from_spmatrix( X_csr, columns=[str(i) for i in range(X.shape[1])] ) df["qid"] = q - ranker = xgb.XGBRanker(n_estimators=3) + ranker = xgb.XGBRanker(n_estimators=3, eval_metric="ndcg") ranker.fit(df, y) s2 = ranker.score(df, y) assert np.isclose(s2, s) From b335e13aa606f5195d446b7f1dbd3e9045092a69 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 2 Mar 2023 11:16:00 +0800 Subject: [PATCH 5/8] lint. --- tests/python/test_with_sklearn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index 6ab2024f6e24..0ba16dde2795 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -184,8 +184,8 @@ def test_ranking_metric() -> None: def test_ranking_qid_df(): import pandas as pd import scipy.sparse - from sklearn.model_selection import StratifiedGroupKFold, cross_val_score from sklearn.metrics import mean_squared_error + from sklearn.model_selection import StratifiedGroupKFold, cross_val_score X, y, q, w = tm.make_ltr(n_samples=128, n_features=2, n_query_groups=8, max_rel=3) From 57215fcadc57e8d3dad2143ce520acf1d4d55b5a Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 2 Mar 2023 15:55:08 +0800 Subject: [PATCH 6/8] test cudf. --- python-package/xgboost/testing/ranking.py | 71 +++++++++++++++++++++++ tests/python-gpu/test_gpu_with_sklearn.py | 8 +++ tests/python/test_with_sklearn.py | 55 +----------------- 3 files changed, 82 insertions(+), 52 deletions(-) create mode 100644 python-package/xgboost/testing/ranking.py diff --git a/python-package/xgboost/testing/ranking.py b/python-package/xgboost/testing/ranking.py new file mode 100644 index 000000000000..2dbf45f5dae0 --- /dev/null +++ b/python-package/xgboost/testing/ranking.py @@ -0,0 +1,71 @@ +"""Tests for learning to rank.""" +from types import ModuleType +from typing import Any + +import numpy as np +import pytest + +import xgboost as xgb +from xgboost import testing as tm + + +def run_ranking_qid_df(impl: ModuleType, tree_method: str) -> None: + """Test ranking with qid packed into X.""" + import scipy.sparse + from sklearn.metrics import mean_squared_error + from sklearn.model_selection import StratifiedGroupKFold, cross_val_score + + X, y, q, w = tm.make_ltr(n_samples=128, n_features=2, n_query_groups=8, max_rel=3) + + # pack qid into x using dataframe + df = impl.DataFrame(X) + df["qid"] = q + ranker = xgb.XGBRanker(n_estimators=3, eval_metric="ndcg", tree_method=tree_method) + ranker.fit(df, y) + s = ranker.score(df, y) + assert s > 0.7 + + # works with validation datasets as well + valid_df = df.copy() + valid_df.iloc[0, 0] = 3.0 + ranker.fit(df, y, eval_set=[(valid_df, y)]) + + # same as passing qid directly + ranker = xgb.XGBRanker(n_estimators=3, eval_metric="ndcg", tree_method=tree_method) + ranker.fit(X, y, qid=q) + s1 = ranker.score(df, y) + assert np.isclose(s, s1) + + # Works with standard sklearn cv + if tree_method != "gpu_hist": + # we need cuML for this. + kfold = StratifiedGroupKFold(shuffle=False) + results = cross_val_score(ranker, df, y, cv=kfold, groups=df.qid) + assert len(results) == 5 + + # Works with custom metric + def neg_mse(*args: Any, **kwargs: Any) -> float: + return -mean_squared_error(*args, **kwargs) + + ranker = xgb.XGBRanker(n_estimators=3, eval_metric=neg_mse, tree_method=tree_method) + ranker.fit(df, y, eval_set=[(valid_df, y)]) + score = ranker.score(valid_df, y) + assert np.isclose(score, ranker.evals_result()["validation_0"]["neg_mse"][-1]) + + # Works with sparse data + if tree_method != "gpu_hist": + # no sparse with cuDF + X_csr = scipy.sparse.csr_matrix(X) + df = impl.DataFrame.sparse.from_spmatrix( + X_csr, columns=[str(i) for i in range(X.shape[1])] + ) + df["qid"] = q + ranker = xgb.XGBRanker( + n_estimators=3, eval_metric="ndcg", tree_method=tree_method + ) + ranker.fit(df, y) + s2 = ranker.score(df, y) + assert np.isclose(s2, s) + + with pytest.raises(ValueError, match="Either `group` or `qid`."): + ranker.fit(df, y, eval_set=[(X, y)]) diff --git a/tests/python-gpu/test_gpu_with_sklearn.py b/tests/python-gpu/test_gpu_with_sklearn.py index 8ecb4bdc77cc..f26f70b367f4 100644 --- a/tests/python-gpu/test_gpu_with_sklearn.py +++ b/tests/python-gpu/test_gpu_with_sklearn.py @@ -5,6 +5,7 @@ import numpy as np import pytest +from xgboost.testing.ranking import run_ranking_qid_df import xgboost as xgb from xgboost import testing as tm @@ -153,3 +154,10 @@ def test_classififer(): y *= 10 with pytest.raises(ValueError, match=r"Invalid classes.*"): clf.fit(X, y) + + +@pytest.mark.skipif(**tm.no_pandas()) +def test_ranking_qid_df(): + import cudf + + run_ranking_qid_df(cudf, "gpu_hist") diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index 0ba16dde2795..baef690ee32e 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -3,7 +3,7 @@ import pickle import random import tempfile -from typing import Any, Callable, Optional +from typing import Callable, Optional import numpy as np import pytest @@ -11,6 +11,7 @@ import xgboost as xgb from xgboost import testing as tm +from xgboost.testing.ranking import run_ranking_qid_df from xgboost.testing.shared import get_feature_weights, validate_data_initialization from xgboost.testing.updater import get_basescore @@ -183,58 +184,8 @@ def test_ranking_metric() -> None: @pytest.mark.skipif(**tm.no_pandas()) def test_ranking_qid_df(): import pandas as pd - import scipy.sparse - from sklearn.metrics import mean_squared_error - from sklearn.model_selection import StratifiedGroupKFold, cross_val_score - - X, y, q, w = tm.make_ltr(n_samples=128, n_features=2, n_query_groups=8, max_rel=3) - # pack qid into x using dataframe - df = pd.DataFrame(X) - df["qid"] = q - ranker = xgb.XGBRanker(n_estimators=3, eval_metric="ndcg") - ranker.fit(df, y) - s = ranker.score(df, y) - assert s > 0.7 - - # works with validation datasets as well - valid_df = df.copy() - valid_df.iloc[0, 0] = 3.0 - ranker.fit(df, y, eval_set=[(valid_df, y)]) - - # same as passing qid directly - ranker = xgb.XGBRanker(n_estimators=3, eval_metric="ndcg") - ranker.fit(X, y, qid=q) - s1 = ranker.score(df, y) - assert np.isclose(s, s1) - - # Works with standard sklearn cv - kfold = StratifiedGroupKFold(shuffle=False) - results = cross_val_score(ranker, df, y, cv=kfold, groups=df.qid) - assert len(results) == 5 - - # Works with custom metric - def neg_mse(*args: Any, **kwargs: Any) -> float: - return -mean_squared_error(*args, **kwargs) - - ranker = xgb.XGBRanker(n_estimators=3, eval_metric=neg_mse) - ranker.fit(df, y, eval_set=[(valid_df, y)]) - score = ranker.score(valid_df, y) - assert np.isclose(score, ranker.evals_result()["validation_0"]["neg_mse"][-1]) - - # Works with sparse data - X_csr = scipy.sparse.csr_matrix(X) - df = pd.DataFrame.sparse.from_spmatrix( - X_csr, columns=[str(i) for i in range(X.shape[1])] - ) - df["qid"] = q - ranker = xgb.XGBRanker(n_estimators=3, eval_metric="ndcg") - ranker.fit(df, y) - s2 = ranker.score(df, y) - assert np.isclose(s2, s) - - with pytest.raises(ValueError, match="Either `group` or `qid`."): - ranker.fit(df, y, eval_set=[(X, y)]) + run_ranking_qid_df(pd, "hist") def test_stacking_regression(): From ad78621211c00fff871bdc96048d1f1485029fe6 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 2 Mar 2023 19:18:28 +0800 Subject: [PATCH 7/8] pylint. --- python-package/xgboost/collective.py | 2 +- python-package/xgboost/rabit.py | 2 +- python-package/xgboost/testing/ranking.py | 5 +++-- tests/python-gpu/test_gpu_with_sklearn.py | 2 +- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/python-package/xgboost/collective.py b/python-package/xgboost/collective.py index 7c586cba71d3..4c67ccbfcad7 100644 --- a/python-package/xgboost/collective.py +++ b/python-package/xgboost/collective.py @@ -231,7 +231,7 @@ def allreduce(data: np.ndarray, op: Op) -> np.ndarray: # pylint:disable=invalid if buf.base is data.base: buf = buf.copy() if buf.dtype not in DTYPE_ENUM__: - raise Exception(f"data type {buf.dtype} not supported") + raise TypeError(f"data type {buf.dtype} not supported") _check_call( _LIB.XGCommunicatorAllreduce( buf.ctypes.data_as(ctypes.c_void_p), diff --git a/python-package/xgboost/rabit.py b/python-package/xgboost/rabit.py index 0b8f143ecd35..132d721787b1 100644 --- a/python-package/xgboost/rabit.py +++ b/python-package/xgboost/rabit.py @@ -136,7 +136,7 @@ def allreduce( # pylint:disable=invalid-name """ if prepare_fun is None: return collective.allreduce(data, collective.Op(op)) - raise Exception("preprocessing function is no longer supported") + raise ValueError("preprocessing function is no longer supported") def version_number() -> int: diff --git a/python-package/xgboost/testing/ranking.py b/python-package/xgboost/testing/ranking.py index 2dbf45f5dae0..fe4fc8404567 100644 --- a/python-package/xgboost/testing/ranking.py +++ b/python-package/xgboost/testing/ranking.py @@ -1,3 +1,4 @@ +# pylint: disable=too-many-locals """Tests for learning to rank.""" from types import ModuleType from typing import Any @@ -15,7 +16,7 @@ def run_ranking_qid_df(impl: ModuleType, tree_method: str) -> None: from sklearn.metrics import mean_squared_error from sklearn.model_selection import StratifiedGroupKFold, cross_val_score - X, y, q, w = tm.make_ltr(n_samples=128, n_features=2, n_query_groups=8, max_rel=3) + X, y, q, _ = tm.make_ltr(n_samples=128, n_features=2, n_query_groups=8, max_rel=3) # pack qid into x using dataframe df = impl.DataFrame(X) @@ -45,7 +46,7 @@ def run_ranking_qid_df(impl: ModuleType, tree_method: str) -> None: # Works with custom metric def neg_mse(*args: Any, **kwargs: Any) -> float: - return -mean_squared_error(*args, **kwargs) + return -float(mean_squared_error(*args, **kwargs)) ranker = xgb.XGBRanker(n_estimators=3, eval_metric=neg_mse, tree_method=tree_method) ranker.fit(df, y, eval_set=[(valid_df, y)]) diff --git a/tests/python-gpu/test_gpu_with_sklearn.py b/tests/python-gpu/test_gpu_with_sklearn.py index f26f70b367f4..c9d3ab4ebff7 100644 --- a/tests/python-gpu/test_gpu_with_sklearn.py +++ b/tests/python-gpu/test_gpu_with_sklearn.py @@ -5,10 +5,10 @@ import numpy as np import pytest -from xgboost.testing.ranking import run_ranking_qid_df import xgboost as xgb from xgboost import testing as tm +from xgboost.testing.ranking import run_ranking_qid_df sys.path.append("tests/python") import test_with_sklearn as twskl # noqa From 44e9dd5cd3fbc2071d43402bba85331aad108bb0 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 6 Mar 2023 23:10:22 +0800 Subject: [PATCH 8/8] duplicated check. --- python-package/xgboost/sklearn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index ad57aee8214c..3204f5a2a61e 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -1823,7 +1823,6 @@ def _get_qid( "Found both the special column `qid` in `X` and the `qid` from the" "`fit` method. Please remove one of them." ) - assert qid is None q_x = X.qid X = X.drop("qid", axis=1) return X, q_x