Skip to content

Commit

Permalink
PCA SPMD python interfaces (#1211)
Browse files Browse the repository at this point in the history
* Initial pca spmd changes

* pca spmd example

* flake8 fixes

* simplifying examples

* flake8 fixes

* flake8 changes in examples

* fixing pca call from sklearnex than onedal

* removing y from fit params in pca

* y remove from examples

* flake8 changes to examples

* flake8 changes after pull

* Fixes for PR comments on examples and class names

* Update runexamples with pca_spmd

* Update onedal/__init__.py

Co-authored-by: Samir Nasibli <samir.nasibli@intel.com>

* Update setup.py

Co-authored-by: Samir Nasibli <samir.nasibli@intel.com>

* Add featues and samples in example

---------

Co-authored-by: jui.mhatre <jmhatre@jflmkl116.jf.intel.com>
Co-authored-by: Alexander Andreev <alexander.andreev@intel.com>
Co-authored-by: Samir Nasibli <samir.nasibli@intel.com>
  • Loading branch information
4 people authored Mar 22, 2023
1 parent 5bb4c31 commit 2405630
Show file tree
Hide file tree
Showing 14 changed files with 157 additions and 16 deletions.
40 changes: 40 additions & 0 deletions examples/sklearnex/pca_spmd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# ===============================================================================
# Copyright 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===============================================================================

import numpy as np
from mpi4py import MPI
import dpctl
from sklearnex.spmd.decomposition import PCA


def get_data(data_seed):
ns, nf = 300, 30
drng = np.random.default_rng(data_seed)
X = drng.random(size=(ns, nf))
return X


q = dpctl.SyclQueue("gpu")
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

X = get_data(rank)

pca = PCA(n_components=2).fit(X, q)

print(f"Singular values on rank {rank}:\n", pca.singular_values_)
print(f"Explained variance Ratio on rank {rank}:\n", pca.explained_variance_ratio_)
2 changes: 1 addition & 1 deletion onedal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@
__all__ += ['basic_statistics', 'linear_model']

if _is_dpc_backend:
__all__ += ['spmd.basic_statistics', 'spmd.linear_model']
__all__ += ['spmd.basic_statistics', 'spmd.decomposition', 'spmd.linear_model',]
8 changes: 5 additions & 3 deletions onedal/decomposition/pca.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,12 @@ ONEDAL_PY_INIT_MODULE(decomposition) {

using task_list = types<task::dim_reduction>;
auto sub = m.def_submodule("decomposition");

ONEDAL_PY_INSTANTIATE(init_train_ops, sub, policy_list, task_list);
#ifdef ONEDAL_DATA_PARALLEL_SPMD
ONEDAL_PY_INSTANTIATE(init_train_ops, sub, policy_list_spmd, task_list);
#else
ONEDAL_PY_INSTANTIATE(init_train_ops, sub, policy_list, task_list);
#endif
ONEDAL_PY_INSTANTIATE(init_infer_ops, sub, policy_list, task_list);

ONEDAL_PY_INSTANTIATE(init_model, sub, task_list);
ONEDAL_PY_INSTANTIATE(init_train_result, sub, task_list);
ONEDAL_PY_INSTANTIATE(init_infer_result, sub, task_list);
Expand Down
17 changes: 11 additions & 6 deletions onedal/decomposition/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

from onedal import _backend
from ..common._policy import _get_policy
from ..datatypes._data_conversion import from_table, to_table, _convert_to_supported
from ..datatypes._data_conversion import from_table, to_table
from ..datatypes import _convert_to_supported
from daal4py.sklearn._utils import sklearn_check_version


Expand All @@ -43,17 +44,20 @@ def get_onedal_params(self, data):
'is_deterministic': self.is_deterministic
}

def fit(self, X, y, queue):
def _get_policy(self, queue, *data):
return _get_policy(queue, *data)

def fit(self, X, queue):
n_samples, n_features = X.shape
n_sf_min = min(n_samples, n_features)

policy = _get_policy(queue, X, y)

policy = self._get_policy(queue, X)
# TODO: investigate why np.ndarray with OWNDATA=FALSE flag
# fails to be converted to oneDAL table
if isinstance(X, np.ndarray) and not X.flags['OWNDATA']:
X = X.copy()
X, y = _convert_to_supported(policy, X, y)
X = _convert_to_supported(policy, X)

params = self.get_onedal_params(X)
cov_result = _backend.covariance.compute(
policy,
Expand Down Expand Up @@ -99,10 +103,11 @@ def fit(self, X, y, queue):
def _create_model(self):
m = _backend.decomposition.dim_reduction.model()
m.eigenvectors = to_table(self.components_)
self._onedal_model = m
return m

def predict(self, X, queue):
policy = _get_policy(queue, X)
policy = self._get_policy(queue, X)
model = self._create_model()

X = _convert_to_supported(policy, X)
Expand Down
6 changes: 5 additions & 1 deletion onedal/primitives/covariance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,11 @@ ONEDAL_PY_INIT_MODULE(covariance) {
using namespace dal::covariance;

auto sub = m.def_submodule("covariance");
ONEDAL_PY_INSTANTIATE(init_compute_ops, sub, policy_list, task::compute);
#ifdef ONEDAL_DATA_PARALLEL_SPMD
ONEDAL_PY_INSTANTIATE(init_compute_ops, sub, policy_list_spmd, task::compute);
#else
ONEDAL_PY_INSTANTIATE(init_compute_ops, sub, policy_list, task::compute);
#endif
ONEDAL_PY_INSTANTIATE(init_compute_result, sub, task::compute);
}

Expand Down
2 changes: 1 addition & 1 deletion onedal/spmd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
# limitations under the License.
#===============================================================================

__all__ = ['linear_model', 'basic_statistics']
__all__ = ['basic_statistics', 'decomposition', 'linear_model']
19 changes: 19 additions & 0 deletions onedal/spmd/decomposition/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#===============================================================================
# Copyright 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#===============================================================================

from .pca import PCA

__all__ = ['PCA']
28 changes: 28 additions & 0 deletions onedal/spmd/decomposition/pca.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#===============================================================================
# Copyright 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#===============================================================================


from ...common._spmd_policy import _get_spmd_policy
from onedal.decomposition.pca import PCA as PCABatch


class BasePCASPMD:
def _get_policy(self, queue, *data):
return _get_spmd_policy(queue)


class PCA(BasePCASPMD, PCABatch):
pass
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,8 +475,9 @@ def run(self):
] if ONEDAL_VERSION >= 20230100 else []
) + (
['onedal.spmd',
'onedal.spmd.linear_model',
'onedal.spmd.basic_statistics'
'onedal.spmd.basic_statistics',
'onedal.spmd.decomposition',
'onedal.spmd.linear_model'
] if build_distribute else [])),
package_data={
'daal4py.oneapi': [
Expand Down
2 changes: 1 addition & 1 deletion sklearnex/preview/decomposition/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def _onedal_fit(self, X, y=None, queue=None):
'method': "precomputed",
}
self._onedal_estimator = onedal_PCA(**onedal_params)
self._onedal_estimator.fit(X, y, queue=queue)
self._onedal_estimator.fit(X, queue=queue)
self._save_attributes()

U = None
Expand Down
2 changes: 1 addition & 1 deletion sklearnex/spmd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
# limitations under the License.
#===============================================================================

__all__ = ['linear_model', 'basic_statistics']
__all__ = ['basic_statistics', 'decomposition', 'linear_model']
19 changes: 19 additions & 0 deletions sklearnex/spmd/decomposition/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#===============================================================================
# Copyright 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#===============================================================================

from .pca import PCA

__all__ = ['PCA']
21 changes: 21 additions & 0 deletions sklearnex/spmd/decomposition/pca.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#===============================================================================
# Copyright 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#===============================================================================

from onedal.spmd.decomposition import PCA

# TODO:
# Currently it uses `onedal` module interface.
# Add sklearnex dispatching.
2 changes: 2 additions & 0 deletions tests/run_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def check_library(rule):
req_device = defaultdict(lambda: [])
req_device['basic_statistics_spmd.py'] = ["gpu"]
req_device['linear_regression_spmd.py'] = ["gpu"]
req_device['pca_spmd.py'] = ["gpu"]
req_device['sycl/gradient_boosted_regression_batch.py'] = ["gpu"]

req_library = defaultdict(lambda: [])
Expand All @@ -152,6 +153,7 @@ def check_library(rule):
req_library['gbt_cls_model_create_from_xgboost_batch.py'] = ['xgboost']
req_library['gbt_cls_model_create_from_catboost_batch.py'] = ['catboost']
req_library['linear_regression_spmd.py'] = ['dpctl', 'mpi4py']
req_library['pca_spmd.py'] = ['dpctl', 'mpi4py']

req_os = defaultdict(lambda: [])

Expand Down

0 comments on commit 2405630

Please sign in to comment.