diff --git a/python/cuml/neighbors/nearest_neighbors.pyx b/python/cuml/neighbors/nearest_neighbors.pyx index d63ff3ba00..4d63c8b0ac 100644 --- a/python/cuml/neighbors/nearest_neighbors.pyx +++ b/python/cuml/neighbors/nearest_neighbors.pyx @@ -453,7 +453,7 @@ class NearestNeighbors(Base, def get_param_names(self): return super().get_param_names() + \ ["n_neighbors", "algorithm", "metric", - "p", "metric_params", "algo_params"] + "p", "metric_params", "algo_params", "n_jobs"] @staticmethod def _build_metric_type(metric): diff --git a/python/cuml/test/test_nearest_neighbors.py b/python/cuml/test/test_nearest_neighbors.py index 35f1c5229e..18e3d9505d 100644 --- a/python/cuml/test/test_nearest_neighbors.py +++ b/python/cuml/test/test_nearest_neighbors.py @@ -657,3 +657,8 @@ def test_haversine_fails_high_dimensions(): algorithm='brute') cunn.fit(data).kneighbors(data) + + +def test_n_jobs_parameter_passthrough(): + cunn = cuKNN() + cunn.set_params(n_jobs=12)