diff --git a/examples/sklearnex/pca_spmd.py b/examples/sklearnex/pca_spmd.py new file mode 100644 index 0000000000..3972e9547e --- /dev/null +++ b/examples/sklearnex/pca_spmd.py @@ -0,0 +1,40 @@ +# =============================================================================== +# Copyright 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =============================================================================== + +import numpy as np +from mpi4py import MPI +import dpctl +from sklearnex.spmd.decomposition import PCA + + +def get_data(data_seed): + ns, nf = 300, 30 + drng = np.random.default_rng(data_seed) + X = drng.random(size=(ns, nf)) + return X + + +q = dpctl.SyclQueue("gpu") +comm = MPI.COMM_WORLD +rank = comm.Get_rank() +size = comm.Get_size() + +X = get_data(rank) + +pca = PCA(n_components=2).fit(X, q) + +print(f"Singular values on rank {rank}:\n", pca.singular_values_) +print(f"Explained variance Ratio on rank {rank}:\n", pca.explained_variance_ratio_) diff --git a/onedal/__init__.py b/onedal/__init__.py index d72b76887b..db30bf0334 100644 --- a/onedal/__init__.py +++ b/onedal/__init__.py @@ -49,4 +49,4 @@ __all__ += ['basic_statistics', 'linear_model'] if _is_dpc_backend: - __all__ += ['spmd.basic_statistics', 'spmd.linear_model'] + __all__ += ['spmd.basic_statistics', 'spmd.decomposition', 'spmd.linear_model',] diff --git a/onedal/decomposition/pca.cpp b/onedal/decomposition/pca.cpp index c509d9f09f..6498e7bf42 100644 --- a/onedal/decomposition/pca.cpp +++ b/onedal/decomposition/pca.cpp @@ -136,10 +136,12 @@ ONEDAL_PY_INIT_MODULE(decomposition) { using task_list = types; auto sub = m.def_submodule("decomposition"); - - ONEDAL_PY_INSTANTIATE(init_train_ops, sub, policy_list, task_list); + #ifdef ONEDAL_DATA_PARALLEL_SPMD + ONEDAL_PY_INSTANTIATE(init_train_ops, sub, policy_list_spmd, task_list); + #else + ONEDAL_PY_INSTANTIATE(init_train_ops, sub, policy_list, task_list); + #endif ONEDAL_PY_INSTANTIATE(init_infer_ops, sub, policy_list, task_list); - ONEDAL_PY_INSTANTIATE(init_model, sub, task_list); ONEDAL_PY_INSTANTIATE(init_train_result, sub, task_list); ONEDAL_PY_INSTANTIATE(init_infer_result, sub, task_list); diff --git a/onedal/decomposition/pca.py b/onedal/decomposition/pca.py index caa7d5f796..eda23a1ced 100644 --- a/onedal/decomposition/pca.py +++ b/onedal/decomposition/pca.py @@ -18,7 +18,8 @@ from onedal import _backend from ..common._policy import _get_policy -from ..datatypes._data_conversion import from_table, to_table, _convert_to_supported +from ..datatypes._data_conversion import from_table, to_table +from ..datatypes import _convert_to_supported from daal4py.sklearn._utils import sklearn_check_version @@ -43,17 +44,20 @@ def get_onedal_params(self, data): 'is_deterministic': self.is_deterministic } - def fit(self, X, y, queue): + def _get_policy(self, queue, *data): + return _get_policy(queue, *data) + + def fit(self, X, queue): n_samples, n_features = X.shape n_sf_min = min(n_samples, n_features) - policy = _get_policy(queue, X, y) - + policy = self._get_policy(queue, X) # TODO: investigate why np.ndarray with OWNDATA=FALSE flag # fails to be converted to oneDAL table if isinstance(X, np.ndarray) and not X.flags['OWNDATA']: X = X.copy() - X, y = _convert_to_supported(policy, X, y) + X = _convert_to_supported(policy, X) + params = self.get_onedal_params(X) cov_result = _backend.covariance.compute( policy, @@ -99,10 +103,11 @@ def fit(self, X, y, queue): def _create_model(self): m = _backend.decomposition.dim_reduction.model() m.eigenvectors = to_table(self.components_) + self._onedal_model = m return m def predict(self, X, queue): - policy = _get_policy(queue, X) + policy = self._get_policy(queue, X) model = self._create_model() X = _convert_to_supported(policy, X) diff --git a/onedal/primitives/covariance.cpp b/onedal/primitives/covariance.cpp index d154431d38..a2c989e729 100644 --- a/onedal/primitives/covariance.cpp +++ b/onedal/primitives/covariance.cpp @@ -78,7 +78,11 @@ ONEDAL_PY_INIT_MODULE(covariance) { using namespace dal::covariance; auto sub = m.def_submodule("covariance"); - ONEDAL_PY_INSTANTIATE(init_compute_ops, sub, policy_list, task::compute); + #ifdef ONEDAL_DATA_PARALLEL_SPMD + ONEDAL_PY_INSTANTIATE(init_compute_ops, sub, policy_list_spmd, task::compute); + #else + ONEDAL_PY_INSTANTIATE(init_compute_ops, sub, policy_list, task::compute); + #endif ONEDAL_PY_INSTANTIATE(init_compute_result, sub, task::compute); } diff --git a/onedal/spmd/__init__.py b/onedal/spmd/__init__.py index 659ff7de3c..9ac25b4370 100644 --- a/onedal/spmd/__init__.py +++ b/onedal/spmd/__init__.py @@ -14,4 +14,4 @@ # limitations under the License. #=============================================================================== -__all__ = ['linear_model', 'basic_statistics'] +__all__ = ['basic_statistics', 'decomposition', 'linear_model'] diff --git a/onedal/spmd/decomposition/__init__.py b/onedal/spmd/decomposition/__init__.py new file mode 100644 index 0000000000..eda7b9fc14 --- /dev/null +++ b/onedal/spmd/decomposition/__init__.py @@ -0,0 +1,19 @@ +#=============================================================================== +# Copyright 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#=============================================================================== + +from .pca import PCA + +__all__ = ['PCA'] diff --git a/onedal/spmd/decomposition/pca.py b/onedal/spmd/decomposition/pca.py new file mode 100644 index 0000000000..5790089313 --- /dev/null +++ b/onedal/spmd/decomposition/pca.py @@ -0,0 +1,28 @@ +#=============================================================================== +# Copyright 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#=============================================================================== + + +from ...common._spmd_policy import _get_spmd_policy +from onedal.decomposition.pca import PCA as PCABatch + + +class BasePCASPMD: + def _get_policy(self, queue, *data): + return _get_spmd_policy(queue) + + +class PCA(BasePCASPMD, PCABatch): + pass diff --git a/setup.py b/setup.py index 0b4293f225..dff6ffe3a3 100644 --- a/setup.py +++ b/setup.py @@ -475,8 +475,9 @@ def run(self): ] if ONEDAL_VERSION >= 20230100 else [] ) + ( ['onedal.spmd', - 'onedal.spmd.linear_model', - 'onedal.spmd.basic_statistics' + 'onedal.spmd.basic_statistics', + 'onedal.spmd.decomposition', + 'onedal.spmd.linear_model' ] if build_distribute else [])), package_data={ 'daal4py.oneapi': [ diff --git a/sklearnex/preview/decomposition/pca.py b/sklearnex/preview/decomposition/pca.py index dd595d3f76..72bddc4ab1 100755 --- a/sklearnex/preview/decomposition/pca.py +++ b/sklearnex/preview/decomposition/pca.py @@ -214,7 +214,7 @@ def _onedal_fit(self, X, y=None, queue=None): 'method': "precomputed", } self._onedal_estimator = onedal_PCA(**onedal_params) - self._onedal_estimator.fit(X, y, queue=queue) + self._onedal_estimator.fit(X, queue=queue) self._save_attributes() U = None diff --git a/sklearnex/spmd/__init__.py b/sklearnex/spmd/__init__.py index 659ff7de3c..9ac25b4370 100644 --- a/sklearnex/spmd/__init__.py +++ b/sklearnex/spmd/__init__.py @@ -14,4 +14,4 @@ # limitations under the License. #=============================================================================== -__all__ = ['linear_model', 'basic_statistics'] +__all__ = ['basic_statistics', 'decomposition', 'linear_model'] diff --git a/sklearnex/spmd/decomposition/__init__.py b/sklearnex/spmd/decomposition/__init__.py new file mode 100644 index 0000000000..eda7b9fc14 --- /dev/null +++ b/sklearnex/spmd/decomposition/__init__.py @@ -0,0 +1,19 @@ +#=============================================================================== +# Copyright 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#=============================================================================== + +from .pca import PCA + +__all__ = ['PCA'] diff --git a/sklearnex/spmd/decomposition/pca.py b/sklearnex/spmd/decomposition/pca.py new file mode 100644 index 0000000000..5bf6eb63ab --- /dev/null +++ b/sklearnex/spmd/decomposition/pca.py @@ -0,0 +1,21 @@ +#=============================================================================== +# Copyright 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#=============================================================================== + +from onedal.spmd.decomposition import PCA + +# TODO: +# Currently it uses `onedal` module interface. +# Add sklearnex dispatching. diff --git a/tests/run_examples.py b/tests/run_examples.py index c861b18b43..a5f43b7f07 100755 --- a/tests/run_examples.py +++ b/tests/run_examples.py @@ -144,6 +144,7 @@ def check_library(rule): req_device = defaultdict(lambda: []) req_device['basic_statistics_spmd.py'] = ["gpu"] req_device['linear_regression_spmd.py'] = ["gpu"] +req_device['pca_spmd.py'] = ["gpu"] req_device['sycl/gradient_boosted_regression_batch.py'] = ["gpu"] req_library = defaultdict(lambda: []) @@ -152,6 +153,7 @@ def check_library(rule): req_library['gbt_cls_model_create_from_xgboost_batch.py'] = ['xgboost'] req_library['gbt_cls_model_create_from_catboost_batch.py'] = ['catboost'] req_library['linear_regression_spmd.py'] = ['dpctl', 'mpi4py'] +req_library['pca_spmd.py'] = ['dpctl', 'mpi4py'] req_os = defaultdict(lambda: [])