Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dask] [ci] add support for scikit-learn 0.24+ in tests (fixes #4031) #4032

Merged
merged 8 commits into from
Mar 2, 2021
Merged
10 changes: 9 additions & 1 deletion tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,16 @@
from dask.array.utils import assert_eq
from dask.distributed import Client, LocalCluster, default_client, wait
from distributed.utils_test import client, cluster_fixture, gen_cluster, loop
from pkg_resources import parse_version
from scipy.sparse import csr_matrix
from scipy.stats import spearmanr
from sklearn import __version__ as sk_version
from sklearn.datasets import make_blobs, make_regression

from .utils import make_ranking

sk_version = parse_version(sk_version)

# time, in seconds, to wait for the Dask client to close. Used to avoid teardown errors
# see https://distributed.dask.org/en/latest/api.html#distributed.Client.close
CLIENT_CLOSE_TIMEOUT = 120
Expand Down Expand Up @@ -1253,5 +1257,9 @@ def test_sklearn_integration(estimator, check, client):
# this test is separate because it takes a not-yet-constructed estimator
@pytest.mark.parametrize("estimator", list(_tested_estimators()))
def test_parameters_default_constructible(estimator):
name, Estimator = estimator.__class__.__name__, estimator.__class__
name = estimator.__class__.__name__
if sk_version > parse_version("0.23"):
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
Estimator = estimator
else:
Estimator = estimator.__class__
sklearn_checks.check_parameters_default_constructible(name, Estimator)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did try just passing an instance always, so we could avoid parsing the scikit-learn version, but it looks like scikit-learn 0.22.x only supported passing a class in this method 😆

Here's what I got changing the code to the following on scikit-learn==0.22.0

@pytest.mark.parametrize("estimator", list(_tested_estimators()))
def test_parameters_default_constructible(estimator):
    name = estimator.__class__.__name__
    sklearn_checks.check_parameters_default_constructible(name, estimator)

E TypeError: 'DaskLGBMRegressor' object is not callable