diff --git a/python/cuml/benchmark/automated/dask/conftest.py b/python/cuml/benchmark/automated/dask/conftest.py index 4e406ed5a6..8d2bb4e49c 100644 --- a/python/cuml/benchmark/automated/dask/conftest.py +++ b/python/cuml/benchmark/automated/dask/conftest.py @@ -18,6 +18,7 @@ from dask_cuda import initialize from dask_cuda import LocalCUDACluster +from dask_cuda.utils_test import IncreasedCloseTimeoutNanny from dask.distributed import Client enable_tcp_over_ucx = True @@ -28,7 +29,11 @@ @pytest.fixture(scope="module") def cluster(): - cluster = LocalCUDACluster(protocol="tcp", scheduler_port=0) + cluster = LocalCUDACluster( + protocol="tcp", + scheduler_port=0, + worker_class=IncreasedCloseTimeoutNanny, + ) yield cluster cluster.close() @@ -54,6 +59,7 @@ def ucx_cluster(): enable_tcp_over_ucx=enable_tcp_over_ucx, enable_nvlink=enable_nvlink, enable_infiniband=enable_infiniband, + worker_class=IncreasedCloseTimeoutNanny, ) yield cluster cluster.close() diff --git a/python/cuml/tests/dask/conftest.py b/python/cuml/tests/dask/conftest.py index 29f09a44c9..3c6311dc03 100644 --- a/python/cuml/tests/dask/conftest.py +++ b/python/cuml/tests/dask/conftest.py @@ -4,6 +4,7 @@ from dask_cuda import initialize from dask_cuda import LocalCUDACluster +from dask_cuda.utils_test import IncreasedCloseTimeoutNanny from dask.distributed import Client enable_tcp_over_ucx = True @@ -14,7 +15,11 @@ @pytest.fixture(scope="module") def cluster(): - cluster = LocalCUDACluster(protocol="tcp", scheduler_port=0) + cluster = LocalCUDACluster( + protocol="tcp", + scheduler_port=0, + worker_class=IncreasedCloseTimeoutNanny, + ) yield cluster cluster.close() @@ -40,6 +45,7 @@ def ucx_cluster(): enable_tcp_over_ucx=enable_tcp_over_ucx, enable_nvlink=enable_nvlink, enable_infiniband=enable_infiniband, + worker_class=IncreasedCloseTimeoutNanny, ) yield cluster cluster.close()