Skip to content

Commit

Permalink
Update sklearnex estimators to support sklearn 1.5 (#1794)
Browse files Browse the repository at this point in the history
* Update TSNE and PCA to support sklearn 1.5 pre-release

* Update PCA solver branching

* Add PCA parameter in test_n_jobs_support.py

* Fix n_iter check in TSNE for sklearn<1.2

* Apply isort

* Move to sklearnex.utils.get_namespace

* Fix PCA.fit_transform

* Update PCA solver selection

* Update PCA algorithm doc

* Fix solver name in tests

* Fix svd solver constraint

* Fix PCA solver constraint

* Fix check_feature_names warning and deselect RF feature importance tests

* Change KMeans estimator doc strings

* Update solver selection

* Update solver selection

* Revert test de-deselection
  • Loading branch information
Alexsandruss authored Apr 27, 2024
1 parent 8c7a928 commit ba8206d
Show file tree
Hide file tree
Showing 9 changed files with 132 additions and 175 deletions.
124 changes: 15 additions & 109 deletions daal4py/sklearn/cluster/k_means.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,23 +258,6 @@ def is_string(s, target_str):


def _fit(self, X, y=None, sample_weight=None):
"""Compute k-means clustering.
Parameters
----------
X : array-like or sparse matrix, shape=(n_samples, n_features)
Training instances to cluster. It must be noted that the data
will be converted to C ordering, which will cause a memory
copy if the given data is not C-contiguous.
y : Ignored
not used, present here for API consistency by convention.
sample_weight : array-like, shape (n_samples,), optional
The weights for each observation in X. If None, all observations
are assigned equal weight (default: None)
"""
init = self.init
if sklearn_check_version("1.1"):
if sklearn_check_version("1.2"):
Expand Down Expand Up @@ -447,26 +430,6 @@ def _daal4py_check_test_data(self, X):


def _predict(self, X, sample_weight=None):
"""Predict the closest cluster each sample in X belongs to.
In the vector quantization literature, `cluster_centers_` is called
the code book and each value returned by `predict` is the index of
the closest code in the code book.
Parameters
----------
X : {array-like, sparse matrix}, shape = [n_samples, n_features]
New data to predict.
sample_weight : array-like, shape (n_samples,), optional
The weights for each observation in X. If None, all observations
are assigned equal weight (default: None)
Returns
-------
labels : array, shape [n_samples,]
Index of the cluster each sample belongs to.
"""
check_is_fitted(self)

X = _daal4py_check_test_data(self, X)
Expand Down Expand Up @@ -614,86 +577,29 @@ def __init__(

@support_usm_ndarray()
def fit(self, X, y=None, sample_weight=None):
"""
Compute k-means clustering.
Parameters
----------
X : {array-like, sparse matrix} of shape (n_samples, n_features)
Training instances to cluster. It must be noted that the data
will be converted to C ordering, which will cause a memory
copy if the given data is not C-contiguous.
If a sparse matrix is passed, a copy will be made if it's not in
CSR format.
y : Ignored
Not used, present here for API consistency by convention.
sample_weight : array-like of shape (n_samples,), default=None
The weights for each observation in X. If None, all observations
are assigned equal weight.
.. versionadded:: 0.20
Returns
-------
self : object
Fitted estimator.
"""
return _fit(self, X, y=y, sample_weight=sample_weight)

@support_usm_ndarray()
def predict(
self, X, sample_weight="deprecated" if sklearn_check_version("1.3") else None
):
"""
Predict the closest cluster each sample in X belongs to.
if sklearn_check_version("1.5"):

In the vector quantization literature, `cluster_centers_` is called
the code book and each value returned by `predict` is the index of
the closest code in the code book.
@support_usm_ndarray()
def predict(self, X):
return _predict(self, X)

Parameters
----------
X : {array-like, sparse matrix} of shape (n_samples, n_features)
New data to predict.
sample_weight : array-like of shape (n_samples,), default=None
The weights for each observation in X. If None, all observations
are assigned equal weight.
else:

Returns
-------
labels : ndarray of shape (n_samples,)
Index of the cluster each sample belongs to.
"""
return _predict(self, X, sample_weight=sample_weight)
@support_usm_ndarray()
def predict(
self, X, sample_weight="deprecated" if sklearn_check_version("1.3") else None
):
return _predict(self, X, sample_weight=sample_weight)

@support_usm_ndarray()
def fit_predict(self, X, y=None, sample_weight=None):
"""
Compute cluster centers and predict cluster index for each sample.
Convenience method; equivalent to calling fit(X) followed by
predict(X).
Parameters
----------
X : {array-like, sparse matrix} of shape (n_samples, n_features)
New data to transform.
y : Ignored
Not used, present here for API consistency by convention.
sample_weight : array-like of shape (n_samples,), default=None
The weights for each observation in X. If None, all observations
are assigned equal weight.
Returns
-------
labels : ndarray of shape (n_samples,)
Index of the cluster each sample belongs to.
"""
return super().fit_predict(X, y, sample_weight)

score = support_usm_ndarray()(KMeans_original.score)

fit.__doc__ = KMeans_original.fit.__doc__
predict.__doc__ = KMeans_original.predict.__doc__
fit_predict.__doc__ = KMeans_original.fit_predict.__doc__
score.__doc__ = KMeans_original.score.__doc__
71 changes: 27 additions & 44 deletions daal4py/sklearn/manifold/_t_sne.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,52 +44,15 @@
class TSNE(BaseTSNE):
__doc__ = BaseTSNE.__doc__

if sklearn_check_version("1.2"):
_parameter_constraints: dict = {**BaseTSNE._parameter_constraints}

@support_usm_ndarray()
def fit_transform(self, X, y=None):
"""
Fit X into an embedded space and return that transformed output.
Parameters
----------
X : ndarray of shape (n_samples, n_features) or (n_samples, n_samples)
If the metric is 'precomputed' X must be a square distance
matrix. Otherwise it contains a sample per row. If the method
is 'exact', X may be a sparse matrix of type 'csr', 'csc'
or 'coo'. If the method is 'barnes_hut' and the metric is
'precomputed', X may be a precomputed sparse graph.
y : None
Ignored.
Returns
-------
X_new : ndarray of shape (n_samples, n_components)
Embedding of the training data in low-dimensional space.
"""
return super().fit_transform(X, y)

@support_usm_ndarray()
def fit(self, X, y=None):
"""
Fit X into an embedded space.
Parameters
----------
X : ndarray of shape (n_samples, n_features) or (n_samples, n_samples)
If the metric is 'precomputed' X must be a square distance
matrix. Otherwise it contains a sample per row. If the method
is 'exact', X may be a sparse matrix of type 'csr', 'csc'
or 'coo'. If the method is 'barnes_hut' and the metric is
'precomputed', X may be a precomputed sparse graph.
y : None
Ignored.
Returns
-------
X_new : array of shape (n_samples, n_components)
Embedding of the training data in low-dimensional space.
"""
return super().fit(X, y)

def _daal_tsne(self, P, n_samples, X_embedded):
Expand All @@ -101,11 +64,27 @@ def _daal_tsne(self, P, n_samples, X_embedded):
# * final optimization with momentum at 0.8

# N, nnz, n_iter_without_progress, n_iter
size_iter = [[n_samples], [P.nnz], [self.n_iter_without_progress], [self.n_iter]]
size_iter = [
[n_samples],
[P.nnz],
[self.n_iter_without_progress],
[self._max_iter if sklearn_check_version("1.5") else self.n_iter],
]

# Pass params to daal4py backend
if daal_check_version((2023, "P", 1)):
size_iter.extend([[self._EXPLORATION_N_ITER], [self._N_ITER_CHECK]])
size_iter.extend(
[
[
(
self._EXPLORATION_MAX_ITER
if sklearn_check_version("1.5")
else self._EXPLORATION_N_ITER
)
],
[self._N_ITER_CHECK],
]
)

size_iter = np.array(size_iter, dtype=P.dtype)

Expand Down Expand Up @@ -255,8 +234,9 @@ def _fit(self, X, skip_num_points=0):
)
)

if self.n_iter < 250:
raise ValueError("n_iter should be at least 250")
if not sklearn_check_version("1.2"):
if self.n_iter < 250:
raise ValueError("n_iter should be at least 250")

n_samples = X.shape[0]

Expand Down Expand Up @@ -423,3 +403,6 @@ def _fit(self, X, skip_num_points=0):
neighbors=neighbors_nn,
skip_num_points=skip_num_points,
)

fit.__doc__ = BaseTSNE.fit.__doc__
fit_transform.__doc__ = BaseTSNE.fit_transform.__doc__
8 changes: 8 additions & 0 deletions deselected_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,14 @@ deselected_tests:
- inspection/tests/test_partial_dependence.py::test_partial_dependence_easy_target[2-est2] >=0.23 darwin
- inspection/tests/test_partial_dependence.py::test_partial_dependence_easy_target[2-est3] >=0.23 darwin

# Sklearnex RandomForestClassifier RNG is different from scikit-learn and daal4py
# resulting in different feature importances for small number of trees (10).
# Issue disappears with bigger number of trees (>=20)
- inspection/tests/test_permutation_importance.py::test_permutation_importance_correlated_feature_regression_pandas[0.5-1]
- inspection/tests/test_permutation_importance.py::test_permutation_importance_correlated_feature_regression_pandas[0.5-2]
- inspection/tests/test_permutation_importance.py::test_permutation_importance_correlated_feature_regression_pandas[1.0-1]
- inspection/tests/test_permutation_importance.py::test_permutation_importance_correlated_feature_regression_pandas[1.0-2]

# Random forest classifier selects a different most-important feature
# Feature importances:
# scikit-learn-intelex [0. 0.00553064 0.71323666 0.2812327 ]
Expand Down
4 changes: 2 additions & 2 deletions doc/sources/algorithms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ Dimensionality reduction
* - `PCA`
- All parameters are supported except:

- ``svd_solver`` != `'full'`
- ``svd_solver`` not in [`'full'`, `'covariance_eigh'`]
- Sparse data is not supported
* - `TSNE`
- All parameters are supported except:
Expand Down Expand Up @@ -340,7 +340,7 @@ Dimensionality reduction
* - `PCA`
- All parameters are supported except:

- ``svd_solver`` != `'full'`
- ``svd_solver`` not in [`'full'`, `'covariance_eigh'`]
- Sparse data is not supported

Nearest Neighbors
Expand Down
Loading

0 comments on commit ba8206d

Please sign in to comment.