Skip to content

Commit

Permalink
Support custom metric as well.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Mar 2, 2023
1 parent 2a48072 commit ee85812
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
11 changes: 8 additions & 3 deletions python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2108,7 +2108,7 @@ def apply(
return super().apply(X, ntree_limit, iteration_range)

def score(self, X: ArrayLike, y: ArrayLike) -> float:
"""Evaluate score for data using the first evaluation metric.
"""Evaluate score for data using the last evaluation metric.
Parameters
----------
Expand All @@ -2126,6 +2126,11 @@ def score(self, X: ArrayLike, y: ArrayLike) -> float:
"""
X, qid = _get_qid(X, None)
Xyq = DMatrix(X, y, qid=qid)
result_str = self.get_booster().eval(Xyq)
if callable(self.eval_metric):
metric = ltr_metric_decorator(self.eval_metric, self.n_jobs)
result_str = self.get_booster().eval_set([(Xyq, "eval")], feval=metric)
else:
result_str = self.get_booster().eval(Xyq)

metric_score = _parse_eval_str(result_str)
return metric_score[0][1]
return metric_score[-1][1]
14 changes: 12 additions & 2 deletions tests/python/test_with_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pickle
import random
import tempfile
from typing import Callable, Optional
from typing import Any, Callable, Optional

import numpy as np
import pytest
Expand Down Expand Up @@ -185,6 +185,7 @@ def test_ranking_qid_df():
import pandas as pd
import scipy.sparse
from sklearn.model_selection import StratifiedGroupKFold, cross_val_score
from sklearn.metrics import mean_squared_error

X, y, q, w = tm.make_ltr(n_samples=128, n_features=2, n_query_groups=8, max_rel=3)

Expand Down Expand Up @@ -212,13 +213,22 @@ def test_ranking_qid_df():
results = cross_val_score(ranker, df, y, cv=kfold, groups=df.qid)
assert len(results) == 5

# Works with custom metric
def neg_mse(*args: Any, **kwargs: Any) -> float:
return -mean_squared_error(*args, **kwargs)

ranker = xgb.XGBRanker(n_estimators=3, eval_metric=neg_mse)
ranker.fit(df, y, eval_set=[(valid_df, y)])
score = ranker.score(valid_df, y)
assert np.isclose(score, ranker.evals_result()["validation_0"]["neg_mse"][-1])

# Works with sparse data
X_csr = scipy.sparse.csr_matrix(X)
df = pd.DataFrame.sparse.from_spmatrix(
X_csr, columns=[str(i) for i in range(X.shape[1])]
)
df["qid"] = q
ranker = xgb.XGBRanker(n_estimators=3)
ranker = xgb.XGBRanker(n_estimators=3, eval_metric="ndcg")
ranker.fit(df, y)
s2 = ranker.score(df, y)
assert np.isclose(s2, s)
Expand Down

0 comments on commit ee85812

Please sign in to comment.