diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 6b112b705a3e..5f7784190e4b 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -28,12 +28,16 @@ from dask.array.utils import assert_eq from dask.distributed import Client, LocalCluster, default_client, wait from distributed.utils_test import client, cluster_fixture, gen_cluster, loop +from pkg_resources import parse_version from scipy.sparse import csr_matrix from scipy.stats import spearmanr +from sklearn import __version__ as sk_version from sklearn.datasets import make_blobs, make_regression from .utils import make_ranking +sk_version = parse_version(sk_version) + # time, in seconds, to wait for the Dask client to close. Used to avoid teardown errors # see https://distributed.dask.org/en/latest/api.html#distributed.Client.close CLIENT_CLOSE_TIMEOUT = 120 @@ -1253,5 +1257,9 @@ def test_sklearn_integration(estimator, check, client): # this test is separate because it takes a not-yet-constructed estimator @pytest.mark.parametrize("estimator", list(_tested_estimators())) def test_parameters_default_constructible(estimator): - name, Estimator = estimator.__class__.__name__, estimator.__class__ + name = estimator.__class__.__name__ + if sk_version >= parse_version("0.24"): + Estimator = estimator + else: + Estimator = estimator.__class__ sklearn_checks.check_parameters_default_constructible(name, Estimator)