Skip to content

Commit

Permalink
Update of "Gracefully accept 'n_jobs', a common sklearn parameter, in…
Browse files Browse the repository at this point in the history
… NearestNeighbors Estimator" (rapidsai#4267)

This pull request partially solves [[FEA] rapidsai#3461](rapidsai#3461)

This quick-fix has been created to enable cuML's NearestNeighbor estimator to gracefully accept sklearns 'n_jobs' parameter as a pass-through.

The purpose of making this quick fix is to allow Imbalanced-Learn samplers to rely on cuML's NearestNeighbor estimator, without producing an error when setting the estimators n_jobs parameter .set_params(**{"n_jobs": self.n_jobs})

The[ original PR ](rapidsai#4178 address this issue was not sufficient, as [`set_params()`](https://github.com/rapidsai/cuml/blob/2fee231ac28d982f64c4a746c25be19750812e81/python/cuml/common/base.pyx#L248) will still raise a ValueError if "n_jobs" is not returned by [`get_param_names()`](https://github.com/rapidsai/cuml/blob/2fee231ac28d982f64c4a746c25be19750812e81/python/cuml/neighbors/nearest_neighbors.pyx#L453)

Authors:
  - https://github.com/NV-jpt
  - Dante Gama Dessavre (https://github.com/dantegd)

Approvers:
  - William Hicks (https://github.com/wphicks)

URL: rapidsai#4267
  • Loading branch information
NV-jpt authored Nov 15, 2021
1 parent 9f00a11 commit 4512078
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/cuml/neighbors/nearest_neighbors.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ class NearestNeighbors(Base,
def get_param_names(self):
return super().get_param_names() + \
["n_neighbors", "algorithm", "metric",
"p", "metric_params", "algo_params"]
"p", "metric_params", "algo_params", "n_jobs"]

@staticmethod
def _build_metric_type(metric):
Expand Down
5 changes: 5 additions & 0 deletions python/cuml/test/test_nearest_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,3 +666,8 @@ def test_haversine_fails_high_dimensions():
algorithm='brute')

cunn.fit(data).kneighbors(data)


def test_n_jobs_parameter_passthrough():
cunn = cuKNN()
cunn.set_params(n_jobs=12)

0 comments on commit 4512078

Please sign in to comment.