Skip to content

Commit

Permalink
Remove direct Scikit-learn imports (#4431)
Browse files: browse the repository at this point in the history.
* FIX Remove hard sklearn imports

* FIX Missing whitespace

* FIX minor error

* FIX PEP8 fixes
  • Loading branch information
dantegd authored Dec 8, 2021
1 parent e9fb48c commit 4ce5bd6
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
import functools
import timeit
import numbers
from sklearn.base import clone
from sklearn.utils import Bunch
from cuml.common.import_utils import has_sklearn

if has_sklearn():
from sklearn.base import clone
from sklearn.utils import Bunch
from contextlib import contextmanager
from collections import defaultdict
import warnings
Expand All @@ -30,6 +33,7 @@
import cuml
from cuml.internals.global_settings import _global_settings_data
from cuml.common.array_sparse import SparseCumlArray
from cuml.common.import_utils import has_sklearn
from cuml.internals import _deprecate_pos_args
from ..utils.skl_dependencies import TransformerMixin, BaseComposition, \
BaseEstimator
Expand Down Expand Up @@ -549,6 +553,9 @@ def __init__(self,
n_jobs=None,
transformer_weights=None,
verbose=False):
if not has_sklearn():
raise ImportError("Scikit-learn is needed to use the "
"Column Transformer")
if not transformers:
warnings.warn('Transformers are required')
self.transformers = transformers
Expand Down
16 changes: 9 additions & 7 deletions python/cuml/model_selection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,16 @@
#

from cuml.model_selection._split import train_test_split
from sklearn.model_selection import GridSearchCV
from cuml.common.import_utils import has_sklearn

if has_sklearn():
from sklearn.model_selection import GridSearchCV

GridSearchCV.__doc__ = """
This code is developed and maintained by scikit-learn and imported
by cuML to maintain the familiar sklearn namespace structure.
cuML includes tests to ensure full compatibility of these wrappers
with CUDA-based data and cuML estimators, but all of the underlying code
is due to the scikit-learn developers.\n\n""" + GridSearchCV.__doc__
GridSearchCV.__doc__ = """
This code is developed and maintained by scikit-learn and imported
by cuML to maintain the familiar sklearn namespace structure.
cuML includes tests to ensure full compatibility of these wrappers
with CUDA-based data and cuML estimators, but all of the underlying code
is due to the scikit-learn developers.\n\n""" + GridSearchCV.__doc__

__all__ = ['train_test_split', 'GridSearchCV']
6 changes: 5 additions & 1 deletion python/cuml/multiclass/multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
#

import cuml.internals
import sklearn.multiclass

from cuml.common.array import CumlArray
from cuml.common.base import Base
from cuml.common.import_utils import has_sklearn
from cuml.common.mixins import ClassifierMixin
from cuml.common.doc_utils import generate_docstring
from cuml.common import input_to_host_array
Expand Down Expand Up @@ -112,6 +112,10 @@ def fit(self, X, y) -> 'MulticlassClassifier':
"""
Fit a multiclass classifier.
"""
if not has_sklearn():
raise ImportError("Scikit-learn is needed to use "
"MulticlassClassifier derived classes.")
import sklearn.multiclass
if self.strategy == 'ovr':
self.multiclass_estimator = sklearn.multiclass.\
OneVsRestClassifier(self.estimator, n_jobs=None)
Expand Down
5 changes: 4 additions & 1 deletion python/cuml/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
# limitations under the License.
#

from sklearn.pipeline import Pipeline, make_pipeline
from cuml.common.import_utils import has_sklearn

if has_sklearn():
from sklearn.pipeline import Pipeline, make_pipeline

disclaimer = """
This code is developed and maintained by scikit-learn and imported
Expand Down

0 comments on commit 4ce5bd6

Please sign in to comment.