Skip to content

Commit

Permalink
Update.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Mar 1, 2023
1 parent c67de49 commit 2a48072
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions tests/python/test_with_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,9 @@ def test_ranking_metric() -> None:
def test_ranking_qid_df():
import pandas as pd
import scipy.sparse
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import StratifiedGroupKFold, cross_val_score

X, y, q, w = tm.make_ltr(n_samples=128, n_features=2, n_query_groups=3, max_rel=3)
X, y, q, w = tm.make_ltr(n_samples=128, n_features=2, n_query_groups=8, max_rel=3)

# pack qid into x using dataframe
df = pd.DataFrame(X)
Expand All @@ -207,6 +207,11 @@ def test_ranking_qid_df():
s1 = ranker.score(df, y)
assert np.isclose(s, s1)

# Works with standard sklearn cv
kfold = StratifiedGroupKFold(shuffle=False)
results = cross_val_score(ranker, df, y, cv=kfold, groups=df.qid)
assert len(results) == 5

# Works with sparse data
X_csr = scipy.sparse.csr_matrix(X)
df = pd.DataFrame.sparse.from_spmatrix(
Expand All @@ -218,10 +223,6 @@ def test_ranking_qid_df():
s2 = ranker.score(df, y)
assert np.isclose(s2, s)

# Works with standard sklearn cv
results = cross_val_score(ranker, df, y)
assert len(results) == 5

with pytest.raises(ValueError, match="Either `group` or `qid`."):
ranker.fit(df, y, eval_set=[(X, y)])

Expand Down

0 comments on commit 2a48072

Please sign in to comment.